From 1ae5654eba0163759388df904b0fe17474e786aa Mon Sep 17 00:00:00 2001 From: JT Olio Date: Thu, 13 Jun 2019 08:51:50 -0600 Subject: [PATCH] kademlia/routing: add contexts to more places so monkit works (#2188) --- pkg/dht/dht.go | 12 +- pkg/kademlia/endpoint.go | 6 +- pkg/kademlia/inspector.go | 6 +- pkg/kademlia/kademlia.go | 24 +-- pkg/kademlia/kademlia_test.go | 18 +-- pkg/kademlia/replacement_cache_test.go | 4 +- pkg/kademlia/routing.go | 40 ++--- pkg/kademlia/routing_helpers.go | 67 ++++----- pkg/kademlia/routing_helpers_test.go | 83 ++++++----- pkg/kademlia/routing_integration_test.go | 181 ++++++++++++----------- pkg/kademlia/routing_test.go | 40 ++--- pkg/kademlia/testrouting/testrouting.go | 22 ++- pkg/kademlia/testrouting/utils.go | 4 +- 13 files changed, 254 insertions(+), 253 deletions(-) diff --git a/pkg/dht/dht.go b/pkg/dht/dht.go index 6234525e7..89475e0f1 100644 --- a/pkg/dht/dht.go +++ b/pkg/dht/dht.go @@ -28,13 +28,13 @@ type RoutingTable interface { Local() overlay.NodeDossier K() int CacheSize() int - GetBucketIds() (storage.Keys, error) - FindNear(id storj.NodeID, limit int) ([]*pb.Node, error) - ConnectionSuccess(node *pb.Node) error - ConnectionFailed(node *pb.Node) error + GetBucketIds(context.Context) (storage.Keys, error) + FindNear(ctx context.Context, id storj.NodeID, limit int) ([]*pb.Node, error) + ConnectionSuccess(ctx context.Context, node *pb.Node) error + ConnectionFailed(ctx context.Context, node *pb.Node) error // these are for refreshing - SetBucketTimestamp(id []byte, now time.Time) error - GetBucketTimestamp(id []byte) (time.Time, error) + SetBucketTimestamp(ctx context.Context, id []byte, now time.Time) error + GetBucketTimestamp(ctx context.Context, id []byte) (time.Time, error) Close() error } diff --git a/pkg/kademlia/endpoint.go b/pkg/kademlia/endpoint.go index a1a7fc1aa..a66e26593 100644 --- a/pkg/kademlia/endpoint.go +++ b/pkg/kademlia/endpoint.go @@ -42,7 +42,7 @@ func (endpoint *Endpoint) Query(ctx context.Context, req *pb.QueryRequest) (_ *p endpoint.pingback(ctx, req.Sender) } - nodes, err := endpoint.routingTable.FindNear(req.Target.Id, int(req.Limit)) + nodes, err := endpoint.routingTable.FindNear(ctx, req.Target.Id, int(req.Limit)) if err != nil { return &pb.QueryResponse{}, EndpointError.New("could not find near endpoint: %v", err) } @@ -57,12 +57,12 @@ func (endpoint *Endpoint) pingback(ctx context.Context, target *pb.Node) { _, err = endpoint.service.Ping(ctx, *target) if err != nil { endpoint.log.Debug("connection to node failed", zap.Error(err), zap.String("nodeID", target.Id.String())) - err = endpoint.routingTable.ConnectionFailed(target) + err = endpoint.routingTable.ConnectionFailed(ctx, target) if err != nil { endpoint.log.Error("could not respond to connection failed", zap.Error(err)) } } else { - err = endpoint.routingTable.ConnectionSuccess(target) + err = endpoint.routingTable.ConnectionSuccess(ctx, target) if err != nil { endpoint.log.Error("could not respond to connection success", zap.Error(err)) } else { diff --git a/pkg/kademlia/inspector.go b/pkg/kademlia/inspector.go index d8327560f..cbc457c11 100644 --- a/pkg/kademlia/inspector.go +++ b/pkg/kademlia/inspector.go @@ -43,7 +43,7 @@ func (srv *Inspector) CountNodes(ctx context.Context, req *pb.CountNodesRequest) // GetBuckets returns all kademlia buckets for current kademlia instance func (srv *Inspector) GetBuckets(ctx context.Context, req *pb.GetBucketsRequest) (_ *pb.GetBucketsResponse, err error) { defer mon.Task()(&ctx)(&err) - b, err := srv.dht.GetBucketIds() + b, err := srv.dht.GetBucketIds(ctx) if err != nil { return nil, err } @@ -142,7 +142,7 @@ func (srv *Inspector) NodeInfo(ctx context.Context, req *pb.NodeInfoRequest) (_ // GetBucketList returns the list of buckets with their routing nodes and their cached nodes func (srv *Inspector) GetBucketList(ctx context.Context, req *pb.GetBucketListRequest) (_ *pb.GetBucketListResponse, err error) { defer mon.Task()(&ctx)(&err) - bucketIds, err := srv.dht.GetBucketIds() + bucketIds, err := srv.dht.GetBucketIds(ctx) if err != nil { return nil, err } @@ -151,7 +151,7 @@ func (srv *Inspector) GetBucketList(ctx context.Context, req *pb.GetBucketListRe for i, b := range bucketIds { bucketID := keyToBucketID(b) - routingNodes, err := srv.dht.GetNodesWithinKBucket(bucketID) + routingNodes, err := srv.dht.GetNodesWithinKBucket(ctx, bucketID) if err != nil { return nil, err } diff --git a/pkg/kademlia/kademlia.go b/pkg/kademlia/kademlia.go index c738c36e2..cf33bff25 100644 --- a/pkg/kademlia/kademlia.go +++ b/pkg/kademlia/kademlia.go @@ -120,12 +120,13 @@ func (k *Kademlia) Queried() { // stored in the local routing table. func (k *Kademlia) FindNear(ctx context.Context, start storj.NodeID, limit int) (_ []*pb.Node, err error) { defer mon.Task()(&ctx)(&err) - return k.routingTable.FindNear(start, limit) + return k.routingTable.FindNear(ctx, start, limit) } // GetBucketIds returns a storage.Keys type of bucket ID's in the Kademlia instance -func (k *Kademlia) GetBucketIds() (storage.Keys, error) { - return k.routingTable.GetBucketIds() +func (k *Kademlia) GetBucketIds(ctx context.Context) (_ storage.Keys, err error) { + defer mon.Task()(&ctx)(&err) + return k.routingTable.GetBucketIds(ctx) } // Local returns the local node @@ -141,8 +142,9 @@ func (k *Kademlia) SetBootstrapNodes(nodes []pb.Node) { k.bootstrapNodes = nodes func (k *Kademlia) GetBootstrapNodes() []pb.Node { return k.bootstrapNodes } // DumpNodes returns all the nodes in the node database -func (k *Kademlia) DumpNodes(ctx context.Context) ([]*pb.Node, error) { - return k.routingTable.DumpNodes() +func (k *Kademlia) DumpNodes(ctx context.Context) (_ []*pb.Node, err error) { + defer mon.Task()(&ctx)(&err) + return k.routingTable.DumpNodes(ctx) } // Bootstrap contacts one of a set of pre defined trusted nodes on the network and @@ -298,7 +300,7 @@ func (k *Kademlia) lookup(ctx context.Context, nodeID storj.NodeID, isBootstrap } } else { var err error - nodes, err = k.routingTable.FindNear(nodeID, kb) + nodes, err = k.routingTable.FindNear(ctx, nodeID, kb) if err != nil { return pb.Node{}, err } @@ -314,7 +316,7 @@ func (k *Kademlia) lookup(ctx context.Context, nodeID storj.NodeID, isBootstrap if err != nil { k.log.Warn("Error getting getKBucketID in kad lookup") } else { - err = k.routingTable.SetBucketTimestamp(bucket[:], time.Now()) + err = k.routingTable.SetBucketTimestamp(ctx, bucket[:], time.Now()) if err != nil { k.log.Warn("Error updating bucket timestamp in kad lookup") } @@ -329,8 +331,8 @@ func (k *Kademlia) lookup(ctx context.Context, nodeID storj.NodeID, isBootstrap } // GetNodesWithinKBucket returns all the routing nodes in the specified k-bucket -func (k *Kademlia) GetNodesWithinKBucket(bID bucketID) ([]*pb.Node, error) { - return k.routingTable.getUnmarshaledNodesFromBucket(bID) +func (k *Kademlia) GetNodesWithinKBucket(ctx context.Context, bID bucketID) (_ []*pb.Node, err error) { + return k.routingTable.getUnmarshaledNodesFromBucket(ctx, bID) } // GetCachedNodesWithinKBucket returns all the cached nodes in the specified k-bucket @@ -366,7 +368,7 @@ func (k *Kademlia) Run(ctx context.Context) (err error) { // refresh updates each Kademlia bucket not contacted in the last hour func (k *Kademlia) refresh(ctx context.Context, threshold time.Duration) (err error) { defer mon.Task()(&ctx)(&err) - bIDs, err := k.routingTable.GetBucketIds() + bIDs, err := k.routingTable.GetBucketIds(ctx) if err != nil { return Error.Wrap(err) } @@ -375,7 +377,7 @@ func (k *Kademlia) refresh(ctx context.Context, threshold time.Duration) (err er var errors errs.Group for _, bID := range bIDs { endID := keyToBucketID(bID) - ts, tErr := k.routingTable.GetBucketTimestamp(bID) + ts, tErr := k.routingTable.GetBucketTimestamp(ctx, bID) if tErr != nil { errors.Add(tErr) } else if now.After(ts.Add(threshold)) { diff --git a/pkg/kademlia/kademlia_test.go b/pkg/kademlia/kademlia_test.go index 478663983..1c52fdc36 100644 --- a/pkg/kademlia/kademlia_test.go +++ b/pkg/kademlia/kademlia_test.go @@ -66,7 +66,7 @@ func TestNewKademlia(t *testing.T) { } for i, v := range cases { - kad, err := newKademlia(zaptest.NewLogger(t), pb.NodeType_STORAGE, v.bn, v.addr, pb.NodeOperator{}, v.id, ctx.Dir(strconv.Itoa(i)), defaultAlpha) + kad, err := newKademlia(ctx, zaptest.NewLogger(t), pb.NodeType_STORAGE, v.bn, v.addr, pb.NodeOperator{}, v.id, ctx.Dir(strconv.Itoa(i)), defaultAlpha) require.NoError(t, err) assert.Equal(t, v.expectedErr, err) assert.Equal(t, kad.bootstrapNodes, v.bn) @@ -93,7 +93,7 @@ func TestPeerDiscovery(t *testing.T) { operator := pb.NodeOperator{ Wallet: "OperatorWallet", } - k, err := newKademlia(zaptest.NewLogger(t), pb.NodeType_STORAGE, bootstrapNodes, testAddress, operator, testID, ctx.Dir("test"), defaultAlpha) + k, err := newKademlia(ctx, zaptest.NewLogger(t), pb.NodeType_STORAGE, bootstrapNodes, testAddress, operator, testID, ctx.Dir("test"), defaultAlpha) require.NoError(t, err) rt := k.routingTable assert.Equal(t, rt.Local().Operator.Wallet, "OperatorWallet") @@ -161,7 +161,7 @@ func testNode(ctx *testcontext.Context, name string, t *testing.T, bn []pb.Node) // new kademlia logger := zaptest.NewLogger(t) - k, err := newKademlia(logger, pb.NodeType_STORAGE, bn, lis.Addr().String(), pb.NodeOperator{}, fid, ctx.Dir(name), defaultAlpha) + k, err := newKademlia(ctx, logger, pb.NodeType_STORAGE, bn, lis.Addr().String(), pb.NodeOperator{}, fid, ctx.Dir(name), defaultAlpha) require.NoError(t, err) s := NewEndpoint(logger, k, k.routingTable) @@ -200,18 +200,18 @@ func TestRefresh(t *testing.T) { rt := k.routingTable now := time.Now().UTC() bID := firstBucketID //always exists - err := rt.SetBucketTimestamp(bID[:], now.Add(-2*time.Hour)) + err := rt.SetBucketTimestamp(ctx, bID[:], now.Add(-2*time.Hour)) require.NoError(t, err) //refresh should call FindNode, updating the time err = k.refresh(ctx, time.Minute) require.NoError(t, err) - ts1, err := rt.GetBucketTimestamp(bID[:]) + ts1, err := rt.GetBucketTimestamp(ctx, bID[:]) require.NoError(t, err) assert.True(t, now.Add(-5*time.Minute).Before(ts1)) //refresh should not call FindNode, leaving the previous time err = k.refresh(ctx, time.Minute) require.NoError(t, err) - ts2, err := rt.GetBucketTimestamp(bID[:]) + ts2, err := rt.GetBucketTimestamp(ctx, bID[:]) require.NoError(t, err) assert.True(t, ts1.Equal(ts2)) s.GracefulStop() @@ -243,7 +243,7 @@ func TestFindNear(t *testing.T) { }) bootstrap := []pb.Node{{Id: fid2.ID, Address: &pb.NodeAddress{Address: lis.Addr().String()}}} - k, err := newKademlia(zaptest.NewLogger(t), pb.NodeType_STORAGE, bootstrap, + k, err := newKademlia(ctx, zaptest.NewLogger(t), pb.NodeType_STORAGE, bootstrap, lis.Addr().String(), pb.NodeOperator{}, fid, ctx.Dir("kademlia"), defaultAlpha) require.NoError(t, err) defer ctx.Check(k.Close) @@ -254,7 +254,7 @@ func TestFindNear(t *testing.T) { nodeID := teststorj.NodeIDFromString(id) n := &pb.Node{Id: nodeID} nodes = append(nodes, n) - err = k.routingTable.ConnectionSuccess(n) + err = k.routingTable.ConnectionSuccess(ctx, n) require.NoError(t, err) return *n } @@ -408,7 +408,7 @@ func (mn *mockNodesServer) RequestInfo(ctx context.Context, req *pb.InfoRequest) } // newKademlia returns a newly configured Kademlia instance -func newKademlia(log *zap.Logger, nodeType pb.NodeType, bootstrapNodes []pb.Node, address string, operator pb.NodeOperator, identity *identity.FullIdentity, path string, alpha int) (*Kademlia, error) { +func newKademlia(ctx context.Context, log *zap.Logger, nodeType pb.NodeType, bootstrapNodes []pb.Node, address string, operator pb.NodeOperator, identity *identity.FullIdentity, path string, alpha int) (*Kademlia, error) { self := &overlay.NodeDossier{ Node: pb.Node{ Id: identity.ID, diff --git a/pkg/kademlia/replacement_cache_test.go b/pkg/kademlia/replacement_cache_test.go index fc44afc30..7acaeaeab 100644 --- a/pkg/kademlia/replacement_cache_test.go +++ b/pkg/kademlia/replacement_cache_test.go @@ -17,7 +17,7 @@ import ( func TestAddToReplacementCache(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(storj.NodeID{244, 255}) + rt := createRoutingTable(ctx, storj.NodeID{244, 255}) defer ctx.Check(rt.Close) kadBucketID := bucketID{255, 255} @@ -40,7 +40,7 @@ func TestAddToReplacementCache(t *testing.T) { func TestRemoveFromReplacementCache(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTableWith(storj.NodeID{244, 255}, routingTableOpts{cacheSize: 3}) + rt := createRoutingTableWith(ctx, storj.NodeID{244, 255}, routingTableOpts{cacheSize: 3}) defer ctx.Check(rt.Close) kadBucketID2 := bucketID{127, 255} diff --git a/pkg/kademlia/routing.go b/pkg/kademlia/routing.go index 4e3d30462..74b091419 100644 --- a/pkg/kademlia/routing.go +++ b/pkg/kademlia/routing.go @@ -68,7 +68,7 @@ type RoutingTable struct { } // NewRoutingTable returns a newly configured instance of a RoutingTable -func NewRoutingTable(logger *zap.Logger, localNode *overlay.NodeDossier, kdb, ndb storage.KeyValueStore, config *RoutingTableConfig) (*RoutingTable, error) { +func NewRoutingTable(logger *zap.Logger, localNode *overlay.NodeDossier, kdb, ndb storage.KeyValueStore, config *RoutingTableConfig) (_ *RoutingTable, err error) { if config == nil || config.BucketSize == 0 || config.ReplacementCacheSize == 0 { // TODO: handle this more nicely config = &RoutingTableConfig{ @@ -91,7 +91,7 @@ func NewRoutingTable(logger *zap.Logger, localNode *overlay.NodeDossier, kdb, nd bucketSize: config.BucketSize, rcBucketSize: config.ReplacementCacheSize, } - ok, err := rt.addNode(&localNode.Node) + ok, err := rt.addNode(context.TODO(), &localNode.Node) if !ok || err != nil { return nil, RoutingErr.New("could not add localNode to routing table: %s", err) } @@ -131,8 +131,7 @@ func (rt *RoutingTable) CacheSize() int { // GetNodes retrieves nodes within the same kbucket as the given node id // Note: id doesn't need to be stored at time of search -func (rt *RoutingTable) GetNodes(id storj.NodeID) ([]*pb.Node, bool) { - ctx := context.TODO() +func (rt *RoutingTable) GetNodes(ctx context.Context, id storj.NodeID) ([]*pb.Node, bool) { defer mon.Task()(&ctx)(nil) bID, err := rt.getKBucketID(ctx, id) if err != nil { @@ -141,7 +140,7 @@ func (rt *RoutingTable) GetNodes(id storj.NodeID) ([]*pb.Node, bool) { if bID == (bucketID{}) { return nil, false } - unmarshaledNodes, err := rt.getUnmarshaledNodesFromBucket(bID) + unmarshaledNodes, err := rt.getUnmarshaledNodesFromBucket(ctx, bID) if err != nil { return nil, false } @@ -149,8 +148,7 @@ func (rt *RoutingTable) GetNodes(id storj.NodeID) ([]*pb.Node, bool) { } // GetBucketIds returns a storage.Keys type of bucket ID's in the Kademlia instance -func (rt *RoutingTable) GetBucketIds() (_ storage.Keys, err error) { - ctx := context.TODO() +func (rt *RoutingTable) GetBucketIds(ctx context.Context) (_ storage.Keys, err error) { defer mon.Task()(&ctx)(&err) kbuckets, err := rt.kadBucketDB.List(ctx, nil, 0) @@ -161,8 +159,7 @@ func (rt *RoutingTable) GetBucketIds() (_ storage.Keys, err error) { } // DumpNodes iterates through all nodes in the nodeBucketDB and marshals them to &pb.Nodes, then returns them -func (rt *RoutingTable) DumpNodes() (_ []*pb.Node, err error) { - ctx := context.TODO() +func (rt *RoutingTable) DumpNodes(ctx context.Context) (_ []*pb.Node, err error) { defer mon.Task()(&ctx)(&err) var nodes []*pb.Node @@ -187,8 +184,7 @@ func (rt *RoutingTable) DumpNodes() (_ []*pb.Node, err error) { // FindNear returns the node corresponding to the provided nodeID // returns all Nodes (excluding self) closest via XOR to the provided nodeID up to the provided limit -func (rt *RoutingTable) FindNear(target storj.NodeID, limit int) (_ []*pb.Node, err error) { - ctx := context.TODO() +func (rt *RoutingTable) FindNear(ctx context.Context, target storj.NodeID, limit int) (_ []*pb.Node, err error) { defer mon.Task()(&ctx)(&err) closestNodes := make([]*pb.Node, 0, limit+1) err = rt.iterateNodes(ctx, storj.NodeID{}, func(ctx context.Context, newID storj.NodeID, protoNode []byte) error { @@ -217,8 +213,7 @@ func (rt *RoutingTable) FindNear(target storj.NodeID, limit int) (_ []*pb.Node, // ConnectionSuccess updates or adds a node to the routing table when // a successful connection is made to the node on the network -func (rt *RoutingTable) ConnectionSuccess(node *pb.Node) (err error) { - ctx := context.TODO() +func (rt *RoutingTable) ConnectionSuccess(ctx context.Context, node *pb.Node) (err error) { defer mon.Task()(&ctx)(&err) // valid to connect to node without ID but don't store connection if node.Id == (storj.NodeID{}) { @@ -230,13 +225,13 @@ func (rt *RoutingTable) ConnectionSuccess(node *pb.Node) (err error) { return RoutingErr.New("could not get node %s", err) } if v != nil { - err = rt.updateNode(node) + err = rt.updateNode(ctx, node) if err != nil { return RoutingErr.New("could not update node %s", err) } return nil } - _, err = rt.addNode(node) + _, err = rt.addNode(ctx, node) if err != nil { return RoutingErr.New("could not add node %s", err) } @@ -246,10 +241,9 @@ func (rt *RoutingTable) ConnectionSuccess(node *pb.Node) (err error) { // ConnectionFailed removes a node from the routing table when // a connection fails for the node on the network -func (rt *RoutingTable) ConnectionFailed(node *pb.Node) (err error) { - ctx := context.TODO() +func (rt *RoutingTable) ConnectionFailed(ctx context.Context, node *pb.Node) (err error) { defer mon.Task()(&ctx)(&err) - err = rt.removeNode(node) + err = rt.removeNode(ctx, node) if err != nil { return RoutingErr.New("could not remove node %s", err) } @@ -257,8 +251,7 @@ func (rt *RoutingTable) ConnectionFailed(node *pb.Node) (err error) { } // SetBucketTimestamp records the time of the last node lookup for a bucket -func (rt *RoutingTable) SetBucketTimestamp(bIDBytes []byte, now time.Time) (err error) { - ctx := context.TODO() +func (rt *RoutingTable) SetBucketTimestamp(ctx context.Context, bIDBytes []byte, now time.Time) (err error) { defer mon.Task()(&ctx)(&err) rt.mutex.Lock() defer rt.mutex.Unlock() @@ -270,8 +263,7 @@ func (rt *RoutingTable) SetBucketTimestamp(bIDBytes []byte, now time.Time) (err } // GetBucketTimestamp retrieves time of the last node lookup for a bucket -func (rt *RoutingTable) GetBucketTimestamp(bIDBytes []byte) (_ time.Time, err error) { - ctx := context.TODO() +func (rt *RoutingTable) GetBucketTimestamp(ctx context.Context, bIDBytes []byte) (_ time.Time, err error) { defer mon.Task()(&ctx)(&err) t, err := rt.kadBucketDB.Get(ctx, bIDBytes) @@ -307,7 +299,7 @@ func (rt *RoutingTable) iterateNodes(ctx context.Context, start storj.NodeID, f // ConnFailure implements the Transport failure function func (rt *RoutingTable) ConnFailure(ctx context.Context, node *pb.Node, err error) { - err2 := rt.ConnectionFailed(node) + err2 := rt.ConnectionFailed(ctx, node) if err2 != nil { zap.L().Debug(fmt.Sprintf("error with ConnFailure hook %+v : %+v", err, err2)) } @@ -315,7 +307,7 @@ func (rt *RoutingTable) ConnFailure(ctx context.Context, node *pb.Node, err erro // ConnSuccess implements the Transport success function func (rt *RoutingTable) ConnSuccess(ctx context.Context, node *pb.Node) { - err := rt.ConnectionSuccess(node) + err := rt.ConnectionSuccess(ctx, node) if err != nil { zap.L().Debug("connection success error:", zap.Error(err)) } diff --git a/pkg/kademlia/routing_helpers.go b/pkg/kademlia/routing_helpers.go index 7831db146..6f8b682a2 100644 --- a/pkg/kademlia/routing_helpers.go +++ b/pkg/kademlia/routing_helpers.go @@ -18,8 +18,7 @@ import ( // addNode attempts to add a new contact to the routing table // Requires node not already in table // Returns true if node was added successfully -func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) { - ctx := context.TODO() +func (rt *RoutingTable) addNode(ctx context.Context, node *pb.Node) (_ bool, err error) { defer mon.Task()(&ctx)(&err) rt.mutex.Lock() defer rt.mutex.Unlock() @@ -29,7 +28,7 @@ func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) { if err != nil { return false, RoutingErr.New("could not create initial K bucket: %s", err) } - err = rt.putNode(node) + err = rt.putNode(ctx, node) if err != nil { return false, RoutingErr.New("could not add initial node to nodeBucketDB: %s", err) } @@ -39,22 +38,22 @@ func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) { if err != nil { return false, RoutingErr.New("could not getKBucketID: %s", err) } - hasRoom, err := rt.kadBucketHasRoom(kadBucketID) + hasRoom, err := rt.kadBucketHasRoom(ctx, kadBucketID) if err != nil { return false, err } - containsLocal, err := rt.kadBucketContainsLocalNode(kadBucketID) + containsLocal, err := rt.kadBucketContainsLocalNode(ctx, kadBucketID) if err != nil { return false, err } - withinK, err := rt.wouldBeInNearestK(node.Id) + withinK, err := rt.wouldBeInNearestK(ctx, node.Id) if err != nil { return false, RoutingErr.New("could not determine if node is within k: %s", err) } for !hasRoom { if containsLocal || withinK { - depth, err := rt.determineLeafDepth(kadBucketID) + depth, err := rt.determineLeafDepth(ctx, kadBucketID) if err != nil { return false, RoutingErr.New("could not determine leaf depth: %s", err) } @@ -67,11 +66,11 @@ func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) { if err != nil { return false, RoutingErr.New("could not get k bucket Id within add node split bucket checks: %s", err) } - hasRoom, err = rt.kadBucketHasRoom(kadBucketID) + hasRoom, err = rt.kadBucketHasRoom(ctx, kadBucketID) if err != nil { return false, err } - containsLocal, err = rt.kadBucketContainsLocalNode(kadBucketID) + containsLocal, err = rt.kadBucketContainsLocalNode(ctx, kadBucketID) if err != nil { return false, err } @@ -81,7 +80,7 @@ func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) { return false, nil } } - err = rt.putNode(node) + err = rt.putNode(ctx, node) if err != nil { return false, RoutingErr.New("could not add node to nodeBucketDB: %s", err) } @@ -94,18 +93,16 @@ func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) { // updateNode will update the node information given that // the node is already in the routing table. -func (rt *RoutingTable) updateNode(node *pb.Node) (err error) { - ctx := context.TODO() +func (rt *RoutingTable) updateNode(ctx context.Context, node *pb.Node) (err error) { defer mon.Task()(&ctx)(&err) - if err := rt.putNode(node); err != nil { + if err := rt.putNode(ctx, node); err != nil { return RoutingErr.New("could not update node: %v", err) } return nil } // removeNode will remove churned nodes and replace those entries with nodes from the replacement cache. -func (rt *RoutingTable) removeNode(node *pb.Node) (err error) { - ctx := context.TODO() +func (rt *RoutingTable) removeNode(ctx context.Context, node *pb.Node) (err error) { defer mon.Task()(&ctx)(&err) rt.mutex.Lock() defer rt.mutex.Unlock() @@ -142,7 +139,7 @@ func (rt *RoutingTable) removeNode(node *pb.Node) (err error) { if len(nodes) == 0 { return nil } - err = rt.putNode(nodes[len(nodes)-1]) + err = rt.putNode(ctx, nodes[len(nodes)-1]) if err != nil { return err } @@ -152,8 +149,7 @@ func (rt *RoutingTable) removeNode(node *pb.Node) (err error) { } // putNode: helper, adds or updates Node and ID to nodeBucketDB -func (rt *RoutingTable) putNode(node *pb.Node) (err error) { - ctx := context.TODO() +func (rt *RoutingTable) putNode(ctx context.Context, node *pb.Node) (err error) { defer mon.Task()(&ctx)(&err) v, err := proto.Marshal(node) if err != nil { @@ -203,8 +199,9 @@ func (rt *RoutingTable) getKBucketID(ctx context.Context, nodeID storj.NodeID) ( } // wouldBeInNearestK: helper, returns true if the node in question is within the nearest k from local node -func (rt *RoutingTable) wouldBeInNearestK(nodeID storj.NodeID) (bool, error) { - closestNodes, err := rt.FindNear(rt.self.Id, rt.bucketSize) +func (rt *RoutingTable) wouldBeInNearestK(ctx context.Context, nodeID storj.NodeID) (_ bool, err error) { + defer mon.Task()(&ctx)(&err) + closestNodes, err := rt.FindNear(ctx, rt.self.Id, rt.bucketSize) if err != nil { return false, RoutingErr.Wrap(err) } @@ -224,8 +221,7 @@ func (rt *RoutingTable) wouldBeInNearestK(nodeID storj.NodeID) (bool, error) { } // kadBucketContainsLocalNode returns true if the kbucket in question contains the local node -func (rt *RoutingTable) kadBucketContainsLocalNode(queryID bucketID) (_ bool, err error) { - ctx := context.TODO() +func (rt *RoutingTable) kadBucketContainsLocalNode(ctx context.Context, queryID bucketID) (_ bool, err error) { defer mon.Task()(&ctx)(&err) bID, err := rt.getKBucketID(ctx, rt.self.Id) if err != nil { @@ -235,8 +231,8 @@ func (rt *RoutingTable) kadBucketContainsLocalNode(queryID bucketID) (_ bool, er } // kadBucketHasRoom: helper, returns true if it has fewer than k nodes -func (rt *RoutingTable) kadBucketHasRoom(bID bucketID) (bool, error) { - nodes, err := rt.getNodeIDsWithinKBucket(bID) +func (rt *RoutingTable) kadBucketHasRoom(ctx context.Context, bID bucketID) (_ bool, err error) { + nodes, err := rt.getNodeIDsWithinKBucket(ctx, bID) if err != nil { return false, err } @@ -247,10 +243,9 @@ func (rt *RoutingTable) kadBucketHasRoom(bID bucketID) (bool, error) { } // getNodeIDsWithinKBucket: helper, returns a collection of all the node ids contained within the kbucket -func (rt *RoutingTable) getNodeIDsWithinKBucket(bID bucketID) (_ storj.NodeIDList, err error) { - ctx := context.TODO() +func (rt *RoutingTable) getNodeIDsWithinKBucket(ctx context.Context, bID bucketID) (_ storj.NodeIDList, err error) { defer mon.Task()(&ctx)(&err) - endpoints, err := rt.getKBucketRange(bID) + endpoints, err := rt.getKBucketRange(ctx, bID) if err != nil { return nil, err } @@ -271,8 +266,7 @@ func (rt *RoutingTable) getNodeIDsWithinKBucket(bID bucketID) (_ storj.NodeIDLis } // getNodesFromIDsBytes: helper, returns array of encoded nodes from node ids -func (rt *RoutingTable) getNodesFromIDsBytes(nodeIDs storj.NodeIDList) (_ []*pb.Node, err error) { - ctx := context.TODO() +func (rt *RoutingTable) getNodesFromIDsBytes(ctx context.Context, nodeIDs storj.NodeIDList) (_ []*pb.Node, err error) { defer mon.Task()(&ctx)(&err) var marshaledNodes []storage.Value for _, v := range nodeIDs { @@ -300,12 +294,13 @@ func unmarshalNodes(nodes []storage.Value) ([]*pb.Node, error) { } // getUnmarshaledNodesFromBucket: helper, gets nodes within kbucket -func (rt *RoutingTable) getUnmarshaledNodesFromBucket(bID bucketID) ([]*pb.Node, error) { - nodeIDsBytes, err := rt.getNodeIDsWithinKBucket(bID) +func (rt *RoutingTable) getUnmarshaledNodesFromBucket(ctx context.Context, bID bucketID) (_ []*pb.Node, err error) { + defer mon.Task()(&ctx)(&err) + nodeIDsBytes, err := rt.getNodeIDsWithinKBucket(ctx, bID) if err != nil { return []*pb.Node{}, RoutingErr.New("could not get nodeIds within kbucket %s", err) } - nodes, err := rt.getNodesFromIDsBytes(nodeIDsBytes) + nodes, err := rt.getNodesFromIDsBytes(ctx, nodeIDsBytes) if err != nil { return []*pb.Node{}, RoutingErr.New("could not get node values %s", err) } @@ -313,8 +308,7 @@ func (rt *RoutingTable) getUnmarshaledNodesFromBucket(bID bucketID) ([]*pb.Node, } // getKBucketRange: helper, returns the left and right endpoints of the range of node ids contained within the bucket -func (rt *RoutingTable) getKBucketRange(bID bucketID) (_ []bucketID, err error) { - ctx := context.TODO() +func (rt *RoutingTable) getKBucketRange(ctx context.Context, bID bucketID) (_ []bucketID, err error) { defer mon.Task()(&ctx)(&err) previousBucket := bucketID{} endpoints := []bucketID{} @@ -340,8 +334,9 @@ func (rt *RoutingTable) getKBucketRange(bID bucketID) (_ []bucketID, err error) // determineLeafDepth determines the level of the bucket id in question. // Eg level 0 means there is only 1 bucket, level 1 means the bucket has been split once, and so on -func (rt *RoutingTable) determineLeafDepth(bID bucketID) (int, error) { - bucketRange, err := rt.getKBucketRange(bID) +func (rt *RoutingTable) determineLeafDepth(ctx context.Context, bID bucketID) (_ int, err error) { + defer mon.Task()(&ctx)(&err) + bucketRange, err := rt.getKBucketRange(ctx, bID) if err != nil { return -1, RoutingErr.New("could not get k bucket range %s", err) } diff --git a/pkg/kademlia/routing_helpers_test.go b/pkg/kademlia/routing_helpers_test.go index 7e0fcaa25..0c6a8931f 100644 --- a/pkg/kademlia/routing_helpers_test.go +++ b/pkg/kademlia/routing_helpers_test.go @@ -5,6 +5,7 @@ package kademlia import ( "bytes" + "context" "sync" "testing" "time" @@ -30,7 +31,7 @@ type routingTableOpts struct { } // newTestRoutingTable returns a newly configured instance of a RoutingTable -func newTestRoutingTable(local *overlay.NodeDossier, opts routingTableOpts) (*RoutingTable, error) { +func newTestRoutingTable(ctx context.Context, local *overlay.NodeDossier, opts routingTableOpts) (*RoutingTable, error) { if opts.bucketSize == 0 { opts.bucketSize = 6 } @@ -50,34 +51,34 @@ func newTestRoutingTable(local *overlay.NodeDossier, opts routingTableOpts) (*Ro bucketSize: opts.bucketSize, rcBucketSize: opts.cacheSize, } - ok, err := rt.addNode(&local.Node) + ok, err := rt.addNode(ctx, &local.Node) if !ok || err != nil { return nil, RoutingErr.New("could not add localNode to routing table: %s", err) } return rt, nil } -func createRoutingTableWith(localNodeID storj.NodeID, opts routingTableOpts) *RoutingTable { +func createRoutingTableWith(ctx context.Context, localNodeID storj.NodeID, opts routingTableOpts) *RoutingTable { if localNodeID == (storj.NodeID{}) { panic("empty local node id") } local := &overlay.NodeDossier{Node: pb.Node{Id: localNodeID}} - rt, err := newTestRoutingTable(local, opts) + rt, err := newTestRoutingTable(ctx, local, opts) if err != nil { panic(err) } return rt } -func createRoutingTable(localNodeID storj.NodeID) *RoutingTable { - return createRoutingTableWith(localNodeID, routingTableOpts{}) +func createRoutingTable(ctx context.Context, localNodeID storj.NodeID) *RoutingTable { + return createRoutingTableWith(ctx, localNodeID, routingTableOpts{}) } func TestAddNode(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(teststorj.NodeIDFromString("OO")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("OO")) defer ctx.Check(rt.Close) cases := []struct { @@ -204,14 +205,14 @@ func TestAddNode(t *testing.T) { for _, c := range cases { testCase := c t.Run(testCase.testID, func(t *testing.T) { - ok, err := rt.addNode(testCase.node) + ok, err := rt.addNode(ctx, testCase.node) require.NoError(t, err) require.Equal(t, testCase.added, ok) kadKeys, err := rt.kadBucketDB.List(ctx, nil, 0) require.NoError(t, err) for i, v := range kadKeys { require.True(t, bytes.Equal(testCase.kadIDs[i], v[:2])) - ids, err := rt.getNodeIDsWithinKBucket(keyToBucketID(v)) + ids, err := rt.getNodeIDsWithinKBucket(ctx, keyToBucketID(v)) require.NoError(t, err) require.True(t, len(ids) == len(testCase.nodeIDs[i])) for j, id := range ids { @@ -232,10 +233,10 @@ func TestAddNode(t *testing.T) { func TestUpdateNode(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(teststorj.NodeIDFromString("AA")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA")) defer ctx.Check(rt.Close) node := teststorj.MockNode("BB") - ok, err := rt.addNode(node) + ok, err := rt.addNode(ctx, node) assert.True(t, ok) assert.NoError(t, err) val, err := rt.nodeBucketDB.Get(ctx, node.Id.Bytes()) @@ -246,7 +247,7 @@ func TestUpdateNode(t *testing.T) { assert.Nil(t, x) node.Address = &pb.NodeAddress{Address: "BB"} - err = rt.updateNode(node) + err = rt.updateNode(ctx, node) assert.NoError(t, err) val, err = rt.nodeBucketDB.Get(ctx, node.Id.Bytes()) assert.NoError(t, err) @@ -259,11 +260,11 @@ func TestUpdateNode(t *testing.T) { func TestRemoveNode(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(teststorj.NodeIDFromString("AA")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA")) defer ctx.Check(rt.Close) kadBucketID := firstBucketID node := teststorj.MockNode("BB") - ok, err := rt.addNode(node) + ok, err := rt.addNode(ctx, node) assert.True(t, ok) assert.NoError(t, err) val, err := rt.nodeBucketDB.Get(ctx, node.Id.Bytes()) @@ -271,7 +272,7 @@ func TestRemoveNode(t *testing.T) { assert.NotNil(t, val) node2 := teststorj.MockNode("CC") rt.addToReplacementCache(kadBucketID, node2) - err = rt.removeNode(node) + err = rt.removeNode(ctx, node) assert.NoError(t, err) val, err = rt.nodeBucketDB.Get(ctx, node.Id.Bytes()) assert.Nil(t, val) @@ -282,7 +283,7 @@ func TestRemoveNode(t *testing.T) { assert.Equal(t, 0, len(rt.replacementCache[kadBucketID])) //try to remove node not in rt - err = rt.removeNode(&pb.Node{ + err = rt.removeNode(ctx, &pb.Node{ Id: teststorj.NodeIDFromString("DD"), Address: &pb.NodeAddress{Address: "address:1"}, }) @@ -293,7 +294,7 @@ func TestCreateOrUpdateKBucket(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() id := bucketID{255, 255} - rt := createRoutingTable(teststorj.NodeIDFromString("AA")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA")) defer ctx.Check(rt.Close) err := rt.createOrUpdateKBucket(ctx, id, time.Now()) assert.NoError(t, err) @@ -308,7 +309,7 @@ func TestGetKBucketID(t *testing.T) { defer ctx.Cleanup() kadIDA := bucketID{255, 255} nodeIDA := teststorj.NodeIDFromString("AA") - rt := createRoutingTable(nodeIDA) + rt := createRoutingTable(ctx, nodeIDA) defer ctx.Check(rt.Close) keyA, err := rt.getKBucketID(ctx, nodeIDA) assert.NoError(t, err) @@ -318,7 +319,7 @@ func TestGetKBucketID(t *testing.T) { func TestWouldBeInNearestK(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTableWith(storj.NodeID{127, 255}, routingTableOpts{bucketSize: 2}) + rt := createRoutingTableWith(ctx, storj.NodeID{127, 255}, routingTableOpts{bucketSize: 2}) defer ctx.Check(rt.Close) cases := []struct { @@ -350,7 +351,7 @@ func TestWouldBeInNearestK(t *testing.T) { for _, c := range cases { testCase := c t.Run(testCase.testID, func(t *testing.T) { - result, err := rt.wouldBeInNearestK(testCase.nodeID) + result, err := rt.wouldBeInNearestK(ctx, testCase.nodeID) assert.NoError(t, err) assert.Equal(t, testCase.closest, result) assert.NoError(t, rt.nodeBucketDB.Put(ctx, testCase.nodeID.Bytes(), []byte(""))) @@ -362,7 +363,7 @@ func TestKadBucketContainsLocalNode(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() nodeIDA := storj.NodeID{183, 255} //[10110111, 1111111] - rt := createRoutingTable(nodeIDA) + rt := createRoutingTable(ctx, nodeIDA) defer ctx.Check(rt.Close) kadIDA := firstBucketID var kadIDB bucketID @@ -371,9 +372,9 @@ func TestKadBucketContainsLocalNode(t *testing.T) { now := time.Now() err := rt.createOrUpdateKBucket(ctx, kadIDB, now) assert.NoError(t, err) - resultTrue, err := rt.kadBucketContainsLocalNode(kadIDA) + resultTrue, err := rt.kadBucketContainsLocalNode(ctx, kadIDA) assert.NoError(t, err) - resultFalse, err := rt.kadBucketContainsLocalNode(kadIDB) + resultFalse, err := rt.kadBucketContainsLocalNode(ctx, kadIDB) assert.NoError(t, err) assert.True(t, resultTrue) assert.False(t, resultFalse) @@ -383,7 +384,7 @@ func TestKadBucketHasRoom(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() node1 := storj.NodeID{255, 255} - rt := createRoutingTable(node1) + rt := createRoutingTable(ctx, node1) defer ctx.Check(rt.Close) kadIDA := firstBucketID node2 := storj.NodeID{191, 255} @@ -391,7 +392,7 @@ func TestKadBucketHasRoom(t *testing.T) { node4 := storj.NodeID{63, 255} node5 := storj.NodeID{159, 255} node6 := storj.NodeID{0, 127} - resultA, err := rt.kadBucketHasRoom(kadIDA) + resultA, err := rt.kadBucketHasRoom(ctx, kadIDA) assert.NoError(t, err) assert.True(t, resultA) assert.NoError(t, rt.nodeBucketDB.Put(ctx, node2.Bytes(), []byte(""))) @@ -399,7 +400,7 @@ func TestKadBucketHasRoom(t *testing.T) { assert.NoError(t, rt.nodeBucketDB.Put(ctx, node4.Bytes(), []byte(""))) assert.NoError(t, rt.nodeBucketDB.Put(ctx, node5.Bytes(), []byte(""))) assert.NoError(t, rt.nodeBucketDB.Put(ctx, node6.Bytes(), []byte(""))) - resultB, err := rt.kadBucketHasRoom(kadIDA) + resultB, err := rt.kadBucketHasRoom(ctx, kadIDA) assert.NoError(t, err) assert.False(t, resultB) } @@ -408,7 +409,7 @@ func TestGetNodeIDsWithinKBucket(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() nodeIDA := storj.NodeID{183, 255} //[10110111, 1111111] - rt := createRoutingTable(nodeIDA) + rt := createRoutingTable(ctx, nodeIDA) defer ctx.Check(rt.Close) kadIDA := firstBucketID var kadIDB bucketID @@ -440,7 +441,7 @@ func TestGetNodeIDsWithinKBucket(t *testing.T) { for _, c := range cases { testCase := c t.Run(testCase.testID, func(t *testing.T) { - n, err := rt.getNodeIDsWithinKBucket(testCase.kadID) + n, err := rt.getNodeIDsWithinKBucket(ctx, testCase.kadID) assert.NoError(t, err) for i, id := range testCase.expected { assert.True(t, id.Equal(n[i].Bytes())) @@ -461,7 +462,7 @@ func TestGetNodesFromIDs(t *testing.T) { assert.NoError(t, err) c, err := proto.Marshal(nodeC) assert.NoError(t, err) - rt := createRoutingTable(nodeA.Id) + rt := createRoutingTable(ctx, nodeA.Id) defer ctx.Check(rt.Close) assert.NoError(t, rt.nodeBucketDB.Put(ctx, nodeA.Id.Bytes(), a)) @@ -471,7 +472,7 @@ func TestGetNodesFromIDs(t *testing.T) { nodeKeys, err := rt.nodeBucketDB.List(ctx, nil, 0) assert.NoError(t, err) - values, err := rt.getNodesFromIDsBytes(teststorj.NodeIDsFromBytes(nodeKeys.ByteSlices()...)) + values, err := rt.getNodesFromIDsBytes(ctx, teststorj.NodeIDsFromBytes(nodeKeys.ByteSlices()...)) assert.NoError(t, err) for i, n := range expected { assert.True(t, bytes.Equal(n.Id.Bytes(), values[i].Id.Bytes())) @@ -491,14 +492,14 @@ func TestUnmarshalNodes(t *testing.T) { assert.NoError(t, err) c, err := proto.Marshal(nodeC) assert.NoError(t, err) - rt := createRoutingTable(nodeA.Id) + rt := createRoutingTable(ctx, nodeA.Id) defer ctx.Check(rt.Close) assert.NoError(t, rt.nodeBucketDB.Put(ctx, nodeA.Id.Bytes(), a)) assert.NoError(t, rt.nodeBucketDB.Put(ctx, nodeB.Id.Bytes(), b)) assert.NoError(t, rt.nodeBucketDB.Put(ctx, nodeC.Id.Bytes(), c)) nodeKeys, err := rt.nodeBucketDB.List(ctx, nil, 0) assert.NoError(t, err) - nodes, err := rt.getNodesFromIDsBytes(teststorj.NodeIDsFromBytes(nodeKeys.ByteSlices()...)) + nodes, err := rt.getNodesFromIDsBytes(ctx, teststorj.NodeIDsFromBytes(nodeKeys.ByteSlices()...)) assert.NoError(t, err) expected := []*pb.Node{nodeA, nodeB, nodeC} for i, v := range expected { @@ -510,17 +511,17 @@ func TestGetUnmarshaledNodesFromBucket(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() nodeA := teststorj.MockNode("AA") - rt := createRoutingTable(nodeA.Id) + rt := createRoutingTable(ctx, nodeA.Id) defer ctx.Check(rt.Close) bucketID := firstBucketID nodeB := teststorj.MockNode("BB") nodeC := teststorj.MockNode("CC") var err error - _, err = rt.addNode(nodeB) + _, err = rt.addNode(ctx, nodeB) assert.NoError(t, err) - _, err = rt.addNode(nodeC) + _, err = rt.addNode(ctx, nodeC) assert.NoError(t, err) - nodes, err := rt.getUnmarshaledNodesFromBucket(bucketID) + nodes, err := rt.getUnmarshaledNodesFromBucket(ctx, bucketID) expected := []*pb.Node{nodeA, nodeB, nodeC} assert.NoError(t, err) for i, v := range expected { @@ -531,7 +532,7 @@ func TestGetUnmarshaledNodesFromBucket(t *testing.T) { func TestGetKBucketRange(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(teststorj.NodeIDFromString("AA")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA")) defer ctx.Check(rt.Close) idA := storj.NodeID{255, 255} idB := storj.NodeID{127, 255} @@ -560,7 +561,7 @@ func TestGetKBucketRange(t *testing.T) { for _, c := range cases { testCase := c t.Run(testCase.testID, func(t *testing.T) { - ep, err := rt.getKBucketRange(keyToBucketID(testCase.id.Bytes())) + ep, err := rt.getKBucketRange(ctx, keyToBucketID(testCase.id.Bytes())) assert.NoError(t, err) for i, k := range testCase.expected { assert.True(t, k.Equal(ep[i][:])) @@ -578,7 +579,7 @@ func TestBucketIDZeroValue(t *testing.T) { func TestDetermineLeafDepth(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(teststorj.NodeIDFromString("AA")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA")) defer ctx.Check(rt.Close) idA, idB, idC := firstBucketID, firstBucketID, firstBucketID idA[0] = 255 @@ -630,7 +631,7 @@ func TestDetermineLeafDepth(t *testing.T) { testCase := c t.Run(testCase.testID, func(t *testing.T) { testCase.addNode() - d, err := rt.determineLeafDepth(testCase.id) + d, err := rt.determineLeafDepth(ctx, testCase.id) assert.NoError(t, err) assert.Equal(t, testCase.depth, d) }) @@ -640,7 +641,7 @@ func TestDetermineLeafDepth(t *testing.T) { func TestSplitBucket(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(teststorj.NodeIDFromString("AA")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA")) defer ctx.Check(rt.Close) cases := []struct { testID string diff --git a/pkg/kademlia/routing_integration_test.go b/pkg/kademlia/routing_integration_test.go index 629991379..6b5e278f4 100644 --- a/pkg/kademlia/routing_integration_test.go +++ b/pkg/kademlia/routing_integration_test.go @@ -4,6 +4,7 @@ package kademlia import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -15,19 +16,19 @@ import ( "storj.io/storj/pkg/storj" ) -type routingCtor func(storj.NodeID, int, int, int) dht.RoutingTable +type routingCtor func(context.Context, storj.NodeID, int, int, int) dht.RoutingTable -func newRouting(self storj.NodeID, bucketSize, cacheSize, allowedFailures int) dht.RoutingTable { +func newRouting(ctx context.Context, self storj.NodeID, bucketSize, cacheSize, allowedFailures int) dht.RoutingTable { if allowedFailures != 0 { panic("failure counting currently unsupported") } - return createRoutingTableWith(self, routingTableOpts{ + return createRoutingTableWith(ctx, self, routingTableOpts{ bucketSize: bucketSize, cacheSize: cacheSize, }) } -func newTestRouting(self storj.NodeID, bucketSize, cacheSize, allowedFailures int) dht.RoutingTable { +func newTestRouting(ctx context.Context, self storj.NodeID, bucketSize, cacheSize, allowedFailures int) dht.RoutingTable { return testrouting.New(self, bucketSize, cacheSize, allowedFailures) } @@ -39,12 +40,12 @@ func testTableInit(t *testing.T, routingCtor routingCtor) { bucketSize := 5 cacheSize := 3 - table := routingCtor(PadID("55", "5"), bucketSize, cacheSize, 0) + table := routingCtor(ctx, PadID("55", "5"), bucketSize, cacheSize, 0) defer ctx.Check(table.Close) require.Equal(t, bucketSize, table.K()) require.Equal(t, cacheSize, table.CacheSize()) - nodes, err := table.FindNear(PadID("21", "0"), 3) + nodes, err := table.FindNear(ctx, PadID("21", "0"), 3) require.NoError(t, err) require.Equal(t, 0, len(nodes)) } @@ -55,13 +56,13 @@ func testTableBasic(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("5555", "5"), 5, 3, 0) + table := routingCtor(ctx, PadID("5555", "5"), 5, 3, 0) defer ctx.Check(table.Close) - err := table.ConnectionSuccess(Node(PadID("5556", "5"), "address:1")) + err := table.ConnectionSuccess(ctx, Node(PadID("5556", "5"), "address:1")) require.NoError(t, err) - nodes, err := table.FindNear(PadID("21", "0"), 3) + nodes, err := table.FindNear(ctx, PadID("21", "0"), 3) require.NoError(t, err) require.Equal(t, 1, len(nodes)) require.Equal(t, PadID("5556", "5"), nodes[0].Id) @@ -74,12 +75,12 @@ func testNoSelf(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("55", "5"), 5, 3, 0) + table := routingCtor(ctx, PadID("55", "5"), 5, 3, 0) defer ctx.Check(table.Close) - err := table.ConnectionSuccess(Node(PadID("55", "5"), "address:2")) + err := table.ConnectionSuccess(ctx, Node(PadID("55", "5"), "address:2")) require.NoError(t, err) - nodes, err := table.FindNear(PadID("21", "0"), 3) + nodes, err := table.FindNear(ctx, PadID("21", "0"), 3) require.NoError(t, err) require.Equal(t, 0, len(nodes)) } @@ -90,12 +91,12 @@ func testSplits(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("55", "5"), 5, 2, 0) + table := routingCtor(ctx, PadID("55", "5"), 5, 2, 0) defer ctx.Check(table.Close) for _, prefix2 := range "18" { for _, prefix1 := range "a69c23f1d7eb5408" { - require.NoError(t, table.ConnectionSuccess( + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(string([]rune{prefix1, prefix2}), "0"))) } } @@ -108,7 +109,7 @@ func testSplits(t *testing.T, routingCtor routingCtor) { // three bits should also not be full and have 4 nodes // (40..., 48..., 50..., 58...). So we should be able to get no more than // 18 nodes back - nodes, err := table.FindNear(PadID("55", "5"), 19) + nodes, err := table.FindNear(ctx, PadID("55", "5"), 19) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ // bucket 010 (same first three bits) @@ -140,23 +141,23 @@ func testSplits(t *testing.T, routingCtor routingCtor) { // the gaps // bucket 010 shouldn't have anything in its replacement cache - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("41", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("41", "0"))) // bucket 011 shouldn't have anything in its replacement cache - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("68", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("68", "0"))) // bucket 00 should have two things in its replacement cache, 18... is one of them - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("18", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("18", "0"))) // now just one thing in its replacement cache - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("31", "0"))) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("28", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("31", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("28", "0"))) // bucket 1 should have two things in its replacement cache - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("a1", "0"))) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("d1", "0"))) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("91", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("a1", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("d1", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("91", "0"))) - nodes, err = table.FindNear(PadID("55", "5"), 19) + nodes, err = table.FindNear(ctx, PadID("55", "5"), 19) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ // bucket 010 @@ -187,12 +188,12 @@ func testUnbalanced(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("ff", "f"), 5, 2, 0) + table := routingCtor(ctx, PadID("ff", "f"), 5, 2, 0) defer ctx.Check(table.Close) for _, prefix1 := range "0123456789abcdef" { for _, prefix2 := range "18" { - require.NoError(t, table.ConnectionSuccess( + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(string([]rune{prefix1, prefix2}), "0"))) } } @@ -202,7 +203,7 @@ func testUnbalanced(t *testing.T, routingCtor routingCtor) { // would have forced every bucket to split, and we should have stored all // possible nodes. - nodes, err := table.FindNear(PadID("ff", "f"), 33) + nodes, err := table.FindNear(ctx, PadID("ff", "f"), 33) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("f8", "0"), NodeFromPrefix("f1", "0"), @@ -230,24 +231,24 @@ func testQuery(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("a3", "3"), 5, 2, 0) + table := routingCtor(ctx, PadID("a3", "3"), 5, 2, 0) defer ctx.Check(table.Close) for _, prefix2 := range "18" { for _, prefix1 := range "b4f25c896de03a71" { - require.NoError(t, table.ConnectionSuccess( + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(string([]rune{prefix1, prefix2}), "f"))) } } - nodes, err := table.FindNear(PadID("c7139", "1"), 2) + nodes, err := table.FindNear(ctx, PadID("c7139", "1"), 2) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("c1", "f"), NodeFromPrefix("d1", "f"), }, nodes) - nodes, err = table.FindNear(PadID("c7139", "1"), 7) + nodes, err = table.FindNear(ctx, PadID("c7139", "1"), 7) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("c1", "f"), @@ -259,7 +260,7 @@ func testQuery(t *testing.T, routingCtor routingCtor) { NodeFromPrefix("88", "f"), }, nodes) - nodes, err = table.FindNear(PadID("c7139", "1"), 10) + nodes, err = table.FindNear(ctx, PadID("c7139", "1"), 10) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("c1", "f"), @@ -281,18 +282,18 @@ func testFailureCounting(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("a3", "3"), 5, 2, 2) + table := routingCtor(ctx, PadID("a3", "3"), 5, 2, 2) defer ctx.Check(table.Close) for _, prefix2 := range "18" { for _, prefix1 := range "b4f25c896de03a71" { - require.NoError(t, table.ConnectionSuccess( + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(string([]rune{prefix1, prefix2}), "f"))) } } nochange := func() { - nodes, err := table.FindNear(PadID("c7139", "1"), 7) + nodes, err := table.FindNear(ctx, PadID("c7139", "1"), 7) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("c1", "f"), @@ -306,13 +307,13 @@ func testFailureCounting(t *testing.T, routingCtor routingCtor) { } nochange() - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("d1", "f"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("d1", "f"))) nochange() - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("d1", "f"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("d1", "f"))) nochange() - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("d1", "f"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("d1", "f"))) - nodes, err := table.FindNear(PadID("c7139", "1"), 7) + nodes, err := table.FindNear(ctx, PadID("c7139", "1"), 7) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("c1", "f"), @@ -331,26 +332,26 @@ func testUpdateBucket(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("a3", "3"), 5, 2, 0) + table := routingCtor(ctx, PadID("a3", "3"), 5, 2, 0) defer ctx.Check(table.Close) for _, prefix2 := range "18" { for _, prefix1 := range "b4f25c896de03a71" { - require.NoError(t, table.ConnectionSuccess( + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(string([]rune{prefix1, prefix2}), "f"))) } } - nodes, err := table.FindNear(PadID("c7139", "1"), 1) + nodes, err := table.FindNear(ctx, PadID("c7139", "1"), 1) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("c1", "f"), }, nodes) - require.NoError(t, table.ConnectionSuccess( + require.NoError(t, table.ConnectionSuccess(ctx, Node(PadID("c1", "f"), "new-address:3"))) - nodes, err = table.FindNear(PadID("c7139", "1"), 1) + nodes, err = table.FindNear(ctx, PadID("c7139", "1"), 1) require.NoError(t, err) require.Equal(t, 1, len(nodes)) require.Equal(t, PadID("c1", "f"), nodes[0].Id) @@ -363,18 +364,18 @@ func testUpdateCache(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("a3", "3"), 1, 1, 0) + table := routingCtor(ctx, PadID("a3", "3"), 1, 1, 0) defer ctx.Check(table.Close) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("81", "0"))) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("c1", "0"))) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("41", "0"))) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("01", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("81", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("c1", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("41", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("01", "0"))) - require.NoError(t, table.ConnectionSuccess(Node(PadID("01", "0"), "new-address:6"))) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("41", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, Node(PadID("01", "0"), "new-address:6"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("41", "0"))) - nodes, err := table.FindNear(PadID("01", "0"), 4) + nodes, err := table.FindNear(ctx, PadID("01", "0"), 4) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ @@ -390,16 +391,16 @@ func testFailureUnknownAddress(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("a3", "3"), 1, 1, 0) + table := routingCtor(ctx, PadID("a3", "3"), 1, 1, 0) defer ctx.Check(table.Close) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("81", "0"))) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("c1", "0"))) - require.NoError(t, table.ConnectionSuccess(Node(PadID("41", "0"), "address:2"))) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("01", "0"))) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("41", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("81", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("c1", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, Node(PadID("41", "0"), "address:2"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("01", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("41", "0"))) - nodes, err := table.FindNear(PadID("01", "0"), 4) + nodes, err := table.FindNear(ctx, PadID("01", "0"), 4) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ @@ -415,13 +416,13 @@ func testShrink(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("ff", "f"), 2, 2, 0) + table := routingCtor(ctx, PadID("ff", "f"), 2, 2, 0) defer ctx.Check(table.Close) // blow out the routing table for _, prefix1 := range "0123456789abcdef" { for _, prefix2 := range "18" { - require.NoError(t, table.ConnectionSuccess( + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(string([]rune{prefix1, prefix2}), "0"))) } } @@ -429,7 +430,7 @@ func testShrink(t *testing.T, routingCtor routingCtor) { // delete some of the bad ones for _, prefix1 := range "0123456789abcd" { for _, prefix2 := range "18" { - require.NoError(t, table.ConnectionFailed( + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix(string([]rune{prefix1, prefix2}), "0"))) } } @@ -437,13 +438,13 @@ func testShrink(t *testing.T, routingCtor routingCtor) { // add back some nodes more balanced for _, prefix1 := range "3a50" { for _, prefix2 := range "19" { - require.NoError(t, table.ConnectionSuccess( + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(string([]rune{prefix1, prefix2}), "0"))) } } // make sure table filled in alright - nodes, err := table.FindNear(PadID("ff", "f"), 13) + nodes, err := table.FindNear(ctx, PadID("ff", "f"), 13) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("f8", "0"), @@ -467,17 +468,17 @@ func testReplacementCacheOrder(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("a3", "3"), 1, 2, 0) + table := routingCtor(ctx, PadID("a3", "3"), 1, 2, 0) defer ctx.Check(table.Close) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("81", "0"))) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("21", "0"))) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("c1", "0"))) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("41", "0"))) - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("01", "0"))) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("21", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("81", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("21", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("c1", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("41", "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("01", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("21", "0"))) - nodes, err := table.FindNear(PadID("55", "5"), 4) + nodes, err := table.FindNear(ctx, PadID("55", "5"), 4) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ @@ -493,16 +494,16 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("55", "55"), 2, 2, 0) + table := routingCtor(ctx, PadID("55", "55"), 2, 2, 0) defer ctx.Check(table.Close) for _, pad := range []string{"0", "1"} { for _, prefix := range []string{"ff", "e1", "c1", "54", "56", "57"} { - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix(prefix, pad))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(prefix, pad))) } } - nodes, err := table.FindNear(PadID("55", "55"), 9) + nodes, err := table.FindNear(ctx, PadID("55", "55"), 9) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("54", "1"), @@ -515,9 +516,9 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) { NodeFromPrefix("e1", "0"), }, nodes) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("c1", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("c1", "0"))) - nodes, err = table.FindNear(PadID("55", "55"), 9) + nodes, err = table.FindNear(ctx, PadID("55", "55"), 9) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("54", "1"), @@ -529,8 +530,8 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) { NodeFromPrefix("e1", "0"), }, nodes) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("ff", "0"))) - nodes, err = table.FindNear(PadID("55", "55"), 9) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("ff", "0"))) + nodes, err = table.FindNear(ctx, PadID("55", "55"), 9) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("54", "1"), @@ -542,8 +543,8 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) { NodeFromPrefix("e1", "0"), }, nodes) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("e1", "0"))) - nodes, err = table.FindNear(PadID("55", "55"), 9) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("e1", "0"))) + nodes, err = table.FindNear(ctx, PadID("55", "55"), 9) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("54", "1"), @@ -555,8 +556,8 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) { NodeFromPrefix("e1", "1"), }, nodes) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("e1", "1"))) - nodes, err = table.FindNear(PadID("55", "55"), 9) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("e1", "1"))) + nodes, err = table.FindNear(ctx, PadID("55", "55"), 9) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("54", "1"), @@ -568,10 +569,10 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) { }, nodes) for _, prefix := range []string{"ff", "e1", "c1", "54", "56", "57"} { - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix(prefix, "2"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(prefix, "2"))) } - nodes, err = table.FindNear(PadID("55", "55"), 9) + nodes, err = table.FindNear(ctx, PadID("55", "55"), 9) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("54", "1"), @@ -591,23 +592,23 @@ func testFullDissimilarBucket(t *testing.T, routingCtor routingCtor) { ctx := testcontext.New(t) defer ctx.Cleanup() - table := routingCtor(PadID("55", "55"), 2, 2, 0) + table := routingCtor(ctx, PadID("55", "55"), 2, 2, 0) defer ctx.Check(table.Close) for _, prefix := range []string{"d1", "c1", "f1", "e1"} { - require.NoError(t, table.ConnectionSuccess(NodeFromPrefix(prefix, "0"))) + require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(prefix, "0"))) } - nodes, err := table.FindNear(PadID("55", "55"), 9) + nodes, err := table.FindNear(ctx, PadID("55", "55"), 9) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("d1", "0"), NodeFromPrefix("c1", "0"), }, nodes) - require.NoError(t, table.ConnectionFailed(NodeFromPrefix("c1", "0"))) + require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("c1", "0"))) - nodes, err = table.FindNear(PadID("55", "55"), 9) + nodes, err = table.FindNear(ctx, PadID("55", "55"), 9) require.NoError(t, err) requireNodesEqual(t, []*pb.Node{ NodeFromPrefix("d1", "0"), diff --git a/pkg/kademlia/routing_test.go b/pkg/kademlia/routing_test.go index cc87c4420..64ca5b569 100644 --- a/pkg/kademlia/routing_test.go +++ b/pkg/kademlia/routing_test.go @@ -5,6 +5,7 @@ package kademlia import ( "bytes" + "context" "fmt" "math/rand" "sort" @@ -25,7 +26,7 @@ func TestLocal(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(teststorj.NodeIDFromString("AA")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA")) defer ctx.Check(rt.Close) assert.Equal(t, rt.Local().Id.Bytes()[:2], []byte("AA")) } @@ -34,7 +35,7 @@ func TestK(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(teststorj.NodeIDFromString("AA")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA")) defer ctx.Check(rt.Close) k := rt.K() assert.Equal(t, rt.bucketSize, k) @@ -45,7 +46,7 @@ func TestCacheSize(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(teststorj.NodeIDFromString("AA")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA")) defer ctx.Check(rt.Close) expected := rt.rcBucketSize result := rt.CacheSize() @@ -56,11 +57,11 @@ func TestGetBucket(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - rt := createRoutingTable(teststorj.NodeIDFromString("AA")) + rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA")) defer ctx.Check(rt.Close) node := teststorj.MockNode("AA") node2 := teststorj.MockNode("BB") - ok, err := rt.addNode(node2) + ok, err := rt.addNode(ctx, node2) assert.True(t, ok) assert.NoError(t, err) @@ -79,7 +80,7 @@ func TestGetBucket(t *testing.T) { }, } for i, v := range cases { - b, e := rt.GetNodes(node2.Id) + b, e := rt.GetNodes(ctx, node2.Id) for j, w := range v.expected { if !assert.True(t, bytes.Equal(w.Id.Bytes(), b[j].Id.Bytes())) { t.Logf("case %v failed expected: ", i) @@ -97,14 +98,15 @@ func RandomNode() pb.Node { return node } func TestKademliaFindNear(t *testing.T) { + ctx := context.Background() testFunc := func(t *testing.T, testNodeCount, limit int) { selfNode := RandomNode() - rt := createRoutingTable(selfNode.Id) + rt := createRoutingTable(ctx, selfNode.Id) expectedIDs := make([]storj.NodeID, 0) for x := 0; x < testNodeCount; x++ { n := RandomNode() - ok, err := rt.addNode(&n) + ok, err := rt.addNode(ctx, &n) require.NoError(t, err) if ok { // buckets were full expectedIDs = append(expectedIDs, n.Id) @@ -118,7 +120,7 @@ func TestKademliaFindNear(t *testing.T) { targetNode.Id[storj.NodeIDSize-1] ^= 1 //flip lowest bit sortByXOR(expectedIDs, targetNode.Id) - results, err := rt.FindNear(targetNode.Id, limit) + results, err := rt.FindNear(ctx, targetNode.Id, limit) require.NoError(t, err) counts := []int{len(expectedIDs), limit} sort.Ints(counts) @@ -142,7 +144,7 @@ func TestConnectionSuccess(t *testing.T) { defer ctx.Cleanup() id := teststorj.NodeIDFromString("AA") - rt := createRoutingTable(id) + rt := createRoutingTable(ctx, id) defer ctx.Check(rt.Close) id2 := teststorj.NodeIDFromString("BB") address1 := &pb.NodeAddress{Address: "a"} @@ -169,7 +171,7 @@ func TestConnectionSuccess(t *testing.T) { for _, c := range cases { testCase := c t.Run(testCase.testID, func(t *testing.T) { - err := rt.ConnectionSuccess(testCase.node) + err := rt.ConnectionSuccess(ctx, testCase.node) assert.NoError(t, err) v, err := rt.nodeBucketDB.Get(ctx, testCase.id.Bytes()) assert.NoError(t, err) @@ -186,9 +188,9 @@ func TestConnectionFailed(t *testing.T) { id := teststorj.NodeIDFromString("AA") node := &pb.Node{Id: id} - rt := createRoutingTable(id) + rt := createRoutingTable(ctx, id) defer ctx.Check(rt.Close) - err := rt.ConnectionFailed(node) + err := rt.ConnectionFailed(ctx, node) assert.NoError(t, err) v, err := rt.nodeBucketDB.Get(ctx, id.Bytes()) assert.Error(t, err) @@ -200,19 +202,19 @@ func TestSetBucketTimestamp(t *testing.T) { defer ctx.Cleanup() id := teststorj.NodeIDFromString("AA") - rt := createRoutingTable(id) + rt := createRoutingTable(ctx, id) defer ctx.Check(rt.Close) now := time.Now().UTC() err := rt.createOrUpdateKBucket(ctx, keyToBucketID(id.Bytes()), now) assert.NoError(t, err) - ti, err := rt.GetBucketTimestamp(id.Bytes()) + ti, err := rt.GetBucketTimestamp(ctx, id.Bytes()) assert.Equal(t, now, ti) assert.NoError(t, err) now = time.Now().UTC() - err = rt.SetBucketTimestamp(id.Bytes(), now) + err = rt.SetBucketTimestamp(ctx, id.Bytes(), now) assert.NoError(t, err) - ti, err = rt.GetBucketTimestamp(id.Bytes()) + ti, err = rt.GetBucketTimestamp(ctx, id.Bytes()) assert.Equal(t, now, ti) assert.NoError(t, err) } @@ -222,12 +224,12 @@ func TestGetBucketTimestamp(t *testing.T) { defer ctx.Cleanup() id := teststorj.NodeIDFromString("AA") - rt := createRoutingTable(id) + rt := createRoutingTable(ctx, id) defer ctx.Check(rt.Close) now := time.Now().UTC() err := rt.createOrUpdateKBucket(ctx, keyToBucketID(id.Bytes()), now) assert.NoError(t, err) - ti, err := rt.GetBucketTimestamp(id.Bytes()) + ti, err := rt.GetBucketTimestamp(ctx, id.Bytes()) assert.Equal(t, now, ti) assert.NoError(t, err) } diff --git a/pkg/kademlia/testrouting/testrouting.go b/pkg/kademlia/testrouting/testrouting.go index 9d5b997ba..75e3804d9 100644 --- a/pkg/kademlia/testrouting/testrouting.go +++ b/pkg/kademlia/testrouting/testrouting.go @@ -4,10 +4,13 @@ package testrouting import ( + "context" "sort" "sync" "time" + monkit "gopkg.in/spacemonkeygo/monkit.v2" + "storj.io/storj/pkg/dht" "storj.io/storj/pkg/overlay" "storj.io/storj/pkg/pb" @@ -15,6 +18,10 @@ import ( "storj.io/storj/storage" ) +var ( + mon = monkit.Package() +) + type nodeData struct { node *pb.Node ordering int64 @@ -63,7 +70,8 @@ func (t *Table) CacheSize() int { return t.cacheSize } // ConnectionSuccess should be called whenever a node is successfully connected // to. It will add or update the node's entry in the routing table. -func (t *Table) ConnectionSuccess(node *pb.Node) error { +func (t *Table) ConnectionSuccess(ctx context.Context, node *pb.Node) (err error) { + defer mon.Task()(&ctx)(&err) t.mu.Lock() defer t.mu.Unlock() @@ -100,7 +108,8 @@ func (t *Table) ConnectionSuccess(node *pb.Node) error { // ConnectionFailed should be called whenever a node can't be contacted. // If a node fails more than allowedFailures times, it will be removed from // the routing table. The failure count is reset every successful connection. -func (t *Table) ConnectionFailed(node *pb.Node) error { +func (t *Table) ConnectionFailed(ctx context.Context, node *pb.Node) (err error) { + defer mon.Task()(&ctx)(&err) t.mu.Lock() defer t.mu.Unlock() @@ -122,7 +131,8 @@ func (t *Table) ConnectionFailed(node *pb.Node) error { // FindNear will return up to limit nodes in the routing table ordered by // kademlia xor distance from the given id. -func (t *Table) FindNear(id storj.NodeID, limit int) ([]*pb.Node, error) { +func (t *Table) FindNear(ctx context.Context, id storj.NodeID, limit int) (_ []*pb.Node, err error) { + defer mon.Task()(&ctx)(&err) t.mu.Lock() defer t.mu.Unlock() @@ -180,17 +190,17 @@ func (t *Table) GetNodes(id storj.NodeID) (nodes []*pb.Node, ok bool) { } // GetBucketIds returns a storage.Keys type of bucket ID's in the Kademlia instance -func (t *Table) GetBucketIds() (storage.Keys, error) { +func (t *Table) GetBucketIds(context.Context) (storage.Keys, error) { panic("TODO") } // SetBucketTimestamp records the time of the last node lookup for a bucket -func (t *Table) SetBucketTimestamp(id []byte, now time.Time) error { +func (t *Table) SetBucketTimestamp(context.Context, []byte, time.Time) error { panic("TODO") } // GetBucketTimestamp retrieves time of the last node lookup for a bucket -func (t *Table) GetBucketTimestamp(id []byte) (time.Time, error) { +func (t *Table) GetBucketTimestamp(context.Context, []byte) (time.Time, error) { panic("TODO") } diff --git a/pkg/kademlia/testrouting/utils.go b/pkg/kademlia/testrouting/utils.go index 1b5b8d9d6..4f0d1890b 100644 --- a/pkg/kademlia/testrouting/utils.go +++ b/pkg/kademlia/testrouting/utils.go @@ -3,9 +3,7 @@ package testrouting -import ( - "storj.io/storj/pkg/storj" -) +import "storj.io/storj/pkg/storj" type nodeDataDistanceSorter struct { self storj.NodeID