kademlia/routing: add contexts to more places so monkit works (#2188)

JT Olio authored 2019-06-13 08:51:50 -06:00; committed by Stefan Benten
parent ec63204ab1
commit 1ae5654eba
13 changed files with 254 additions and 253 deletions
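The change applies one mechanical pattern across the kademlia routing code: each function gains a context.Context as its first parameter and starts with defer mon.Task()(&ctx)(&err), replacing the detached ctx := context.TODO() calls so monkit can attach timing and error data to the caller's trace. A minimal sketch of that pattern follows; it is not code from this commit, and it assumes the gopkg.in/spacemonkeygo/monkit.v2 import path vendored at the time.

package example

import (
	"context"

	monkit "gopkg.in/spacemonkeygo/monkit.v2"
)

var mon = monkit.Package()

type table struct{}

// The instrumented method takes ctx first and names its error result so the
// deferred monkit task can observe both when the function returns.
func (t *table) GetBucketIds(ctx context.Context) (_ []string, err error) {
	defer mon.Task()(&ctx)(&err) // records duration and success/failure on return
	// ... a real implementation would list bucket IDs from storage using ctx ...
	return nil, nil
}

Callers then hand their request context straight through, which is why the RoutingTable interface below gains a context.Context argument on every method.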


@ -28,13 +28,13 @@ type RoutingTable interface {
Local() overlay.NodeDossier
K() int
CacheSize() int
GetBucketIds() (storage.Keys, error)
FindNear(id storj.NodeID, limit int) ([]*pb.Node, error)
ConnectionSuccess(node *pb.Node) error
ConnectionFailed(node *pb.Node) error
GetBucketIds(context.Context) (storage.Keys, error)
FindNear(ctx context.Context, id storj.NodeID, limit int) ([]*pb.Node, error)
ConnectionSuccess(ctx context.Context, node *pb.Node) error
ConnectionFailed(ctx context.Context, node *pb.Node) error
// these are for refreshing
SetBucketTimestamp(id []byte, now time.Time) error
GetBucketTimestamp(id []byte) (time.Time, error)
SetBucketTimestamp(ctx context.Context, id []byte, now time.Time) error
GetBucketTimestamp(ctx context.Context, id []byte) (time.Time, error)
Close() error
}
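With the interface carrying contexts, call sites stop minting their own context.TODO() and simply forward whatever context their RPC handler received, as in this hypothetical caller (the pkg/ import paths are assumed from the repository layout of the time, not taken from this commit):

package example

import (
	"context"
	"log"

	"storj.io/storj/pkg/dht"
	"storj.io/storj/pkg/pb"
)

// recordContact forwards the request context so the routing-table update is
// timed under the same monkit trace as the RPC that triggered it.
func recordContact(ctx context.Context, rt dht.RoutingTable, node *pb.Node) {
	if err := rt.ConnectionSuccess(ctx, node); err != nil {
		log.Printf("could not record contact: %v", err)
	}
}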


@ -42,7 +42,7 @@ func (endpoint *Endpoint) Query(ctx context.Context, req *pb.QueryRequest) (_ *p
endpoint.pingback(ctx, req.Sender)
}
nodes, err := endpoint.routingTable.FindNear(req.Target.Id, int(req.Limit))
nodes, err := endpoint.routingTable.FindNear(ctx, req.Target.Id, int(req.Limit))
if err != nil {
return &pb.QueryResponse{}, EndpointError.New("could not find near endpoint: %v", err)
}
@ -57,12 +57,12 @@ func (endpoint *Endpoint) pingback(ctx context.Context, target *pb.Node) {
_, err = endpoint.service.Ping(ctx, *target)
if err != nil {
endpoint.log.Debug("connection to node failed", zap.Error(err), zap.String("nodeID", target.Id.String()))
err = endpoint.routingTable.ConnectionFailed(target)
err = endpoint.routingTable.ConnectionFailed(ctx, target)
if err != nil {
endpoint.log.Error("could not respond to connection failed", zap.Error(err))
}
} else {
err = endpoint.routingTable.ConnectionSuccess(target)
err = endpoint.routingTable.ConnectionSuccess(ctx, target)
if err != nil {
endpoint.log.Error("could not respond to connection success", zap.Error(err))
} else {


@ -43,7 +43,7 @@ func (srv *Inspector) CountNodes(ctx context.Context, req *pb.CountNodesRequest)
// GetBuckets returns all kademlia buckets for current kademlia instance
func (srv *Inspector) GetBuckets(ctx context.Context, req *pb.GetBucketsRequest) (_ *pb.GetBucketsResponse, err error) {
defer mon.Task()(&ctx)(&err)
b, err := srv.dht.GetBucketIds()
b, err := srv.dht.GetBucketIds(ctx)
if err != nil {
return nil, err
}
@ -142,7 +142,7 @@ func (srv *Inspector) NodeInfo(ctx context.Context, req *pb.NodeInfoRequest) (_
// GetBucketList returns the list of buckets with their routing nodes and their cached nodes
func (srv *Inspector) GetBucketList(ctx context.Context, req *pb.GetBucketListRequest) (_ *pb.GetBucketListResponse, err error) {
defer mon.Task()(&ctx)(&err)
bucketIds, err := srv.dht.GetBucketIds()
bucketIds, err := srv.dht.GetBucketIds(ctx)
if err != nil {
return nil, err
}
@ -151,7 +151,7 @@ func (srv *Inspector) GetBucketList(ctx context.Context, req *pb.GetBucketListRe
for i, b := range bucketIds {
bucketID := keyToBucketID(b)
routingNodes, err := srv.dht.GetNodesWithinKBucket(bucketID)
routingNodes, err := srv.dht.GetNodesWithinKBucket(ctx, bucketID)
if err != nil {
return nil, err
}


@ -120,12 +120,13 @@ func (k *Kademlia) Queried() {
// stored in the local routing table.
func (k *Kademlia) FindNear(ctx context.Context, start storj.NodeID, limit int) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
return k.routingTable.FindNear(start, limit)
return k.routingTable.FindNear(ctx, start, limit)
}
// GetBucketIds returns a storage.Keys type of bucket ID's in the Kademlia instance
func (k *Kademlia) GetBucketIds() (storage.Keys, error) {
return k.routingTable.GetBucketIds()
func (k *Kademlia) GetBucketIds(ctx context.Context) (_ storage.Keys, err error) {
defer mon.Task()(&ctx)(&err)
return k.routingTable.GetBucketIds(ctx)
}
// Local returns the local node
@ -141,8 +142,9 @@ func (k *Kademlia) SetBootstrapNodes(nodes []pb.Node) { k.bootstrapNodes = nodes
func (k *Kademlia) GetBootstrapNodes() []pb.Node { return k.bootstrapNodes }
// DumpNodes returns all the nodes in the node database
func (k *Kademlia) DumpNodes(ctx context.Context) ([]*pb.Node, error) {
return k.routingTable.DumpNodes()
func (k *Kademlia) DumpNodes(ctx context.Context) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
return k.routingTable.DumpNodes(ctx)
}
// Bootstrap contacts one of a set of pre defined trusted nodes on the network and
@ -298,7 +300,7 @@ func (k *Kademlia) lookup(ctx context.Context, nodeID storj.NodeID, isBootstrap
}
} else {
var err error
nodes, err = k.routingTable.FindNear(nodeID, kb)
nodes, err = k.routingTable.FindNear(ctx, nodeID, kb)
if err != nil {
return pb.Node{}, err
}
@ -314,7 +316,7 @@ func (k *Kademlia) lookup(ctx context.Context, nodeID storj.NodeID, isBootstrap
if err != nil {
k.log.Warn("Error getting getKBucketID in kad lookup")
} else {
err = k.routingTable.SetBucketTimestamp(bucket[:], time.Now())
err = k.routingTable.SetBucketTimestamp(ctx, bucket[:], time.Now())
if err != nil {
k.log.Warn("Error updating bucket timestamp in kad lookup")
}
@ -329,8 +331,8 @@ func (k *Kademlia) lookup(ctx context.Context, nodeID storj.NodeID, isBootstrap
}
// GetNodesWithinKBucket returns all the routing nodes in the specified k-bucket
func (k *Kademlia) GetNodesWithinKBucket(bID bucketID) ([]*pb.Node, error) {
return k.routingTable.getUnmarshaledNodesFromBucket(bID)
func (k *Kademlia) GetNodesWithinKBucket(ctx context.Context, bID bucketID) (_ []*pb.Node, err error) {
return k.routingTable.getUnmarshaledNodesFromBucket(ctx, bID)
}
// GetCachedNodesWithinKBucket returns all the cached nodes in the specified k-bucket
@ -366,7 +368,7 @@ func (k *Kademlia) Run(ctx context.Context) (err error) {
// refresh updates each Kademlia bucket not contacted in the last hour
func (k *Kademlia) refresh(ctx context.Context, threshold time.Duration) (err error) {
defer mon.Task()(&ctx)(&err)
bIDs, err := k.routingTable.GetBucketIds()
bIDs, err := k.routingTable.GetBucketIds(ctx)
if err != nil {
return Error.Wrap(err)
}
@ -375,7 +377,7 @@ func (k *Kademlia) refresh(ctx context.Context, threshold time.Duration) (err er
var errors errs.Group
for _, bID := range bIDs {
endID := keyToBucketID(bID)
ts, tErr := k.routingTable.GetBucketTimestamp(bID)
ts, tErr := k.routingTable.GetBucketTimestamp(ctx, bID)
if tErr != nil {
errors.Add(tErr)
} else if now.After(ts.Add(threshold)) {


@ -66,7 +66,7 @@ func TestNewKademlia(t *testing.T) {
}
for i, v := range cases {
kad, err := newKademlia(zaptest.NewLogger(t), pb.NodeType_STORAGE, v.bn, v.addr, pb.NodeOperator{}, v.id, ctx.Dir(strconv.Itoa(i)), defaultAlpha)
kad, err := newKademlia(ctx, zaptest.NewLogger(t), pb.NodeType_STORAGE, v.bn, v.addr, pb.NodeOperator{}, v.id, ctx.Dir(strconv.Itoa(i)), defaultAlpha)
require.NoError(t, err)
assert.Equal(t, v.expectedErr, err)
assert.Equal(t, kad.bootstrapNodes, v.bn)
@ -93,7 +93,7 @@ func TestPeerDiscovery(t *testing.T) {
operator := pb.NodeOperator{
Wallet: "OperatorWallet",
}
k, err := newKademlia(zaptest.NewLogger(t), pb.NodeType_STORAGE, bootstrapNodes, testAddress, operator, testID, ctx.Dir("test"), defaultAlpha)
k, err := newKademlia(ctx, zaptest.NewLogger(t), pb.NodeType_STORAGE, bootstrapNodes, testAddress, operator, testID, ctx.Dir("test"), defaultAlpha)
require.NoError(t, err)
rt := k.routingTable
assert.Equal(t, rt.Local().Operator.Wallet, "OperatorWallet")
@ -161,7 +161,7 @@ func testNode(ctx *testcontext.Context, name string, t *testing.T, bn []pb.Node)
// new kademlia
logger := zaptest.NewLogger(t)
k, err := newKademlia(logger, pb.NodeType_STORAGE, bn, lis.Addr().String(), pb.NodeOperator{}, fid, ctx.Dir(name), defaultAlpha)
k, err := newKademlia(ctx, logger, pb.NodeType_STORAGE, bn, lis.Addr().String(), pb.NodeOperator{}, fid, ctx.Dir(name), defaultAlpha)
require.NoError(t, err)
s := NewEndpoint(logger, k, k.routingTable)
@ -200,18 +200,18 @@ func TestRefresh(t *testing.T) {
rt := k.routingTable
now := time.Now().UTC()
bID := firstBucketID //always exists
err := rt.SetBucketTimestamp(bID[:], now.Add(-2*time.Hour))
err := rt.SetBucketTimestamp(ctx, bID[:], now.Add(-2*time.Hour))
require.NoError(t, err)
//refresh should call FindNode, updating the time
err = k.refresh(ctx, time.Minute)
require.NoError(t, err)
ts1, err := rt.GetBucketTimestamp(bID[:])
ts1, err := rt.GetBucketTimestamp(ctx, bID[:])
require.NoError(t, err)
assert.True(t, now.Add(-5*time.Minute).Before(ts1))
//refresh should not call FindNode, leaving the previous time
err = k.refresh(ctx, time.Minute)
require.NoError(t, err)
ts2, err := rt.GetBucketTimestamp(bID[:])
ts2, err := rt.GetBucketTimestamp(ctx, bID[:])
require.NoError(t, err)
assert.True(t, ts1.Equal(ts2))
s.GracefulStop()
@ -243,7 +243,7 @@ func TestFindNear(t *testing.T) {
})
bootstrap := []pb.Node{{Id: fid2.ID, Address: &pb.NodeAddress{Address: lis.Addr().String()}}}
k, err := newKademlia(zaptest.NewLogger(t), pb.NodeType_STORAGE, bootstrap,
k, err := newKademlia(ctx, zaptest.NewLogger(t), pb.NodeType_STORAGE, bootstrap,
lis.Addr().String(), pb.NodeOperator{}, fid, ctx.Dir("kademlia"), defaultAlpha)
require.NoError(t, err)
defer ctx.Check(k.Close)
@ -254,7 +254,7 @@ func TestFindNear(t *testing.T) {
nodeID := teststorj.NodeIDFromString(id)
n := &pb.Node{Id: nodeID}
nodes = append(nodes, n)
err = k.routingTable.ConnectionSuccess(n)
err = k.routingTable.ConnectionSuccess(ctx, n)
require.NoError(t, err)
return *n
}
@ -408,7 +408,7 @@ func (mn *mockNodesServer) RequestInfo(ctx context.Context, req *pb.InfoRequest)
}
// newKademlia returns a newly configured Kademlia instance
func newKademlia(log *zap.Logger, nodeType pb.NodeType, bootstrapNodes []pb.Node, address string, operator pb.NodeOperator, identity *identity.FullIdentity, path string, alpha int) (*Kademlia, error) {
func newKademlia(ctx context.Context, log *zap.Logger, nodeType pb.NodeType, bootstrapNodes []pb.Node, address string, operator pb.NodeOperator, identity *identity.FullIdentity, path string, alpha int) (*Kademlia, error) {
self := &overlay.NodeDossier{
Node: pb.Node{
Id: identity.ID,


@ -17,7 +17,7 @@ import (
func TestAddToReplacementCache(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(storj.NodeID{244, 255})
rt := createRoutingTable(ctx, storj.NodeID{244, 255})
defer ctx.Check(rt.Close)
kadBucketID := bucketID{255, 255}
@ -40,7 +40,7 @@ func TestAddToReplacementCache(t *testing.T) {
func TestRemoveFromReplacementCache(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTableWith(storj.NodeID{244, 255}, routingTableOpts{cacheSize: 3})
rt := createRoutingTableWith(ctx, storj.NodeID{244, 255}, routingTableOpts{cacheSize: 3})
defer ctx.Check(rt.Close)
kadBucketID2 := bucketID{127, 255}


@ -68,7 +68,7 @@ type RoutingTable struct {
}
// NewRoutingTable returns a newly configured instance of a RoutingTable
func NewRoutingTable(logger *zap.Logger, localNode *overlay.NodeDossier, kdb, ndb storage.KeyValueStore, config *RoutingTableConfig) (*RoutingTable, error) {
func NewRoutingTable(logger *zap.Logger, localNode *overlay.NodeDossier, kdb, ndb storage.KeyValueStore, config *RoutingTableConfig) (_ *RoutingTable, err error) {
if config == nil || config.BucketSize == 0 || config.ReplacementCacheSize == 0 {
// TODO: handle this more nicely
config = &RoutingTableConfig{
@ -91,7 +91,7 @@ func NewRoutingTable(logger *zap.Logger, localNode *overlay.NodeDossier, kdb, nd
bucketSize: config.BucketSize,
rcBucketSize: config.ReplacementCacheSize,
}
ok, err := rt.addNode(&localNode.Node)
ok, err := rt.addNode(context.TODO(), &localNode.Node)
if !ok || err != nil {
return nil, RoutingErr.New("could not add localNode to routing table: %s", err)
}
@ -131,8 +131,7 @@ func (rt *RoutingTable) CacheSize() int {
// GetNodes retrieves nodes within the same kbucket as the given node id
// Note: id doesn't need to be stored at time of search
func (rt *RoutingTable) GetNodes(id storj.NodeID) ([]*pb.Node, bool) {
ctx := context.TODO()
func (rt *RoutingTable) GetNodes(ctx context.Context, id storj.NodeID) ([]*pb.Node, bool) {
defer mon.Task()(&ctx)(nil)
bID, err := rt.getKBucketID(ctx, id)
if err != nil {
@ -141,7 +140,7 @@ func (rt *RoutingTable) GetNodes(id storj.NodeID) ([]*pb.Node, bool) {
if bID == (bucketID{}) {
return nil, false
}
unmarshaledNodes, err := rt.getUnmarshaledNodesFromBucket(bID)
unmarshaledNodes, err := rt.getUnmarshaledNodesFromBucket(ctx, bID)
if err != nil {
return nil, false
}
@ -149,8 +148,7 @@ func (rt *RoutingTable) GetNodes(id storj.NodeID) ([]*pb.Node, bool) {
}
// GetBucketIds returns a storage.Keys type of bucket ID's in the Kademlia instance
func (rt *RoutingTable) GetBucketIds() (_ storage.Keys, err error) {
ctx := context.TODO()
func (rt *RoutingTable) GetBucketIds(ctx context.Context) (_ storage.Keys, err error) {
defer mon.Task()(&ctx)(&err)
kbuckets, err := rt.kadBucketDB.List(ctx, nil, 0)
@ -161,8 +159,7 @@ func (rt *RoutingTable) GetBucketIds() (_ storage.Keys, err error) {
}
// DumpNodes iterates through all nodes in the nodeBucketDB and marshals them to &pb.Nodes, then returns them
func (rt *RoutingTable) DumpNodes() (_ []*pb.Node, err error) {
ctx := context.TODO()
func (rt *RoutingTable) DumpNodes(ctx context.Context) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
var nodes []*pb.Node
@ -187,8 +184,7 @@ func (rt *RoutingTable) DumpNodes() (_ []*pb.Node, err error) {
// FindNear returns the node corresponding to the provided nodeID
// returns all Nodes (excluding self) closest via XOR to the provided nodeID up to the provided limit
func (rt *RoutingTable) FindNear(target storj.NodeID, limit int) (_ []*pb.Node, err error) {
ctx := context.TODO()
func (rt *RoutingTable) FindNear(ctx context.Context, target storj.NodeID, limit int) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
closestNodes := make([]*pb.Node, 0, limit+1)
err = rt.iterateNodes(ctx, storj.NodeID{}, func(ctx context.Context, newID storj.NodeID, protoNode []byte) error {
@ -217,8 +213,7 @@ func (rt *RoutingTable) FindNear(target storj.NodeID, limit int) (_ []*pb.Node,
// ConnectionSuccess updates or adds a node to the routing table when
// a successful connection is made to the node on the network
func (rt *RoutingTable) ConnectionSuccess(node *pb.Node) (err error) {
ctx := context.TODO()
func (rt *RoutingTable) ConnectionSuccess(ctx context.Context, node *pb.Node) (err error) {
defer mon.Task()(&ctx)(&err)
// valid to connect to node without ID but don't store connection
if node.Id == (storj.NodeID{}) {
@ -230,13 +225,13 @@ func (rt *RoutingTable) ConnectionSuccess(node *pb.Node) (err error) {
return RoutingErr.New("could not get node %s", err)
}
if v != nil {
err = rt.updateNode(node)
err = rt.updateNode(ctx, node)
if err != nil {
return RoutingErr.New("could not update node %s", err)
}
return nil
}
_, err = rt.addNode(node)
_, err = rt.addNode(ctx, node)
if err != nil {
return RoutingErr.New("could not add node %s", err)
}
@ -246,10 +241,9 @@ func (rt *RoutingTable) ConnectionSuccess(node *pb.Node) (err error) {
// ConnectionFailed removes a node from the routing table when
// a connection fails for the node on the network
func (rt *RoutingTable) ConnectionFailed(node *pb.Node) (err error) {
ctx := context.TODO()
func (rt *RoutingTable) ConnectionFailed(ctx context.Context, node *pb.Node) (err error) {
defer mon.Task()(&ctx)(&err)
err = rt.removeNode(node)
err = rt.removeNode(ctx, node)
if err != nil {
return RoutingErr.New("could not remove node %s", err)
}
@ -257,8 +251,7 @@ func (rt *RoutingTable) ConnectionFailed(node *pb.Node) (err error) {
}
// SetBucketTimestamp records the time of the last node lookup for a bucket
func (rt *RoutingTable) SetBucketTimestamp(bIDBytes []byte, now time.Time) (err error) {
ctx := context.TODO()
func (rt *RoutingTable) SetBucketTimestamp(ctx context.Context, bIDBytes []byte, now time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
rt.mutex.Lock()
defer rt.mutex.Unlock()
@ -270,8 +263,7 @@ func (rt *RoutingTable) SetBucketTimestamp(bIDBytes []byte, now time.Time) (err
}
// GetBucketTimestamp retrieves time of the last node lookup for a bucket
func (rt *RoutingTable) GetBucketTimestamp(bIDBytes []byte) (_ time.Time, err error) {
ctx := context.TODO()
func (rt *RoutingTable) GetBucketTimestamp(ctx context.Context, bIDBytes []byte) (_ time.Time, err error) {
defer mon.Task()(&ctx)(&err)
t, err := rt.kadBucketDB.Get(ctx, bIDBytes)
@ -307,7 +299,7 @@ func (rt *RoutingTable) iterateNodes(ctx context.Context, start storj.NodeID, f
// ConnFailure implements the Transport failure function
func (rt *RoutingTable) ConnFailure(ctx context.Context, node *pb.Node, err error) {
err2 := rt.ConnectionFailed(node)
err2 := rt.ConnectionFailed(ctx, node)
if err2 != nil {
zap.L().Debug(fmt.Sprintf("error with ConnFailure hook %+v : %+v", err, err2))
}
@ -315,7 +307,7 @@ func (rt *RoutingTable) ConnFailure(ctx context.Context, node *pb.Node, err erro
// ConnSuccess implements the Transport success function
func (rt *RoutingTable) ConnSuccess(ctx context.Context, node *pb.Node) {
err := rt.ConnectionSuccess(node)
err := rt.ConnectionSuccess(ctx, node)
if err != nil {
zap.L().Debug("connection success error:", zap.Error(err))
}


@ -18,8 +18,7 @@ import (
// addNode attempts to add a new contact to the routing table
// Requires node not already in table
// Returns true if node was added successfully
func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) {
ctx := context.TODO()
func (rt *RoutingTable) addNode(ctx context.Context, node *pb.Node) (_ bool, err error) {
defer mon.Task()(&ctx)(&err)
rt.mutex.Lock()
defer rt.mutex.Unlock()
@ -29,7 +28,7 @@ func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) {
if err != nil {
return false, RoutingErr.New("could not create initial K bucket: %s", err)
}
err = rt.putNode(node)
err = rt.putNode(ctx, node)
if err != nil {
return false, RoutingErr.New("could not add initial node to nodeBucketDB: %s", err)
}
@ -39,22 +38,22 @@ func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) {
if err != nil {
return false, RoutingErr.New("could not getKBucketID: %s", err)
}
hasRoom, err := rt.kadBucketHasRoom(kadBucketID)
hasRoom, err := rt.kadBucketHasRoom(ctx, kadBucketID)
if err != nil {
return false, err
}
containsLocal, err := rt.kadBucketContainsLocalNode(kadBucketID)
containsLocal, err := rt.kadBucketContainsLocalNode(ctx, kadBucketID)
if err != nil {
return false, err
}
withinK, err := rt.wouldBeInNearestK(node.Id)
withinK, err := rt.wouldBeInNearestK(ctx, node.Id)
if err != nil {
return false, RoutingErr.New("could not determine if node is within k: %s", err)
}
for !hasRoom {
if containsLocal || withinK {
depth, err := rt.determineLeafDepth(kadBucketID)
depth, err := rt.determineLeafDepth(ctx, kadBucketID)
if err != nil {
return false, RoutingErr.New("could not determine leaf depth: %s", err)
}
@ -67,11 +66,11 @@ func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) {
if err != nil {
return false, RoutingErr.New("could not get k bucket Id within add node split bucket checks: %s", err)
}
hasRoom, err = rt.kadBucketHasRoom(kadBucketID)
hasRoom, err = rt.kadBucketHasRoom(ctx, kadBucketID)
if err != nil {
return false, err
}
containsLocal, err = rt.kadBucketContainsLocalNode(kadBucketID)
containsLocal, err = rt.kadBucketContainsLocalNode(ctx, kadBucketID)
if err != nil {
return false, err
}
@ -81,7 +80,7 @@ func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) {
return false, nil
}
}
err = rt.putNode(node)
err = rt.putNode(ctx, node)
if err != nil {
return false, RoutingErr.New("could not add node to nodeBucketDB: %s", err)
}
@ -94,18 +93,16 @@ func (rt *RoutingTable) addNode(node *pb.Node) (_ bool, err error) {
// updateNode will update the node information given that
// the node is already in the routing table.
func (rt *RoutingTable) updateNode(node *pb.Node) (err error) {
ctx := context.TODO()
func (rt *RoutingTable) updateNode(ctx context.Context, node *pb.Node) (err error) {
defer mon.Task()(&ctx)(&err)
if err := rt.putNode(node); err != nil {
if err := rt.putNode(ctx, node); err != nil {
return RoutingErr.New("could not update node: %v", err)
}
return nil
}
// removeNode will remove churned nodes and replace those entries with nodes from the replacement cache.
func (rt *RoutingTable) removeNode(node *pb.Node) (err error) {
ctx := context.TODO()
func (rt *RoutingTable) removeNode(ctx context.Context, node *pb.Node) (err error) {
defer mon.Task()(&ctx)(&err)
rt.mutex.Lock()
defer rt.mutex.Unlock()
@ -142,7 +139,7 @@ func (rt *RoutingTable) removeNode(node *pb.Node) (err error) {
if len(nodes) == 0 {
return nil
}
err = rt.putNode(nodes[len(nodes)-1])
err = rt.putNode(ctx, nodes[len(nodes)-1])
if err != nil {
return err
}
@ -152,8 +149,7 @@ func (rt *RoutingTable) removeNode(node *pb.Node) (err error) {
}
// putNode: helper, adds or updates Node and ID to nodeBucketDB
func (rt *RoutingTable) putNode(node *pb.Node) (err error) {
ctx := context.TODO()
func (rt *RoutingTable) putNode(ctx context.Context, node *pb.Node) (err error) {
defer mon.Task()(&ctx)(&err)
v, err := proto.Marshal(node)
if err != nil {
@ -203,8 +199,9 @@ func (rt *RoutingTable) getKBucketID(ctx context.Context, nodeID storj.NodeID) (
}
// wouldBeInNearestK: helper, returns true if the node in question is within the nearest k from local node
func (rt *RoutingTable) wouldBeInNearestK(nodeID storj.NodeID) (bool, error) {
closestNodes, err := rt.FindNear(rt.self.Id, rt.bucketSize)
func (rt *RoutingTable) wouldBeInNearestK(ctx context.Context, nodeID storj.NodeID) (_ bool, err error) {
defer mon.Task()(&ctx)(&err)
closestNodes, err := rt.FindNear(ctx, rt.self.Id, rt.bucketSize)
if err != nil {
return false, RoutingErr.Wrap(err)
}
@ -224,8 +221,7 @@ func (rt *RoutingTable) wouldBeInNearestK(nodeID storj.NodeID) (bool, error) {
}
// kadBucketContainsLocalNode returns true if the kbucket in question contains the local node
func (rt *RoutingTable) kadBucketContainsLocalNode(queryID bucketID) (_ bool, err error) {
ctx := context.TODO()
func (rt *RoutingTable) kadBucketContainsLocalNode(ctx context.Context, queryID bucketID) (_ bool, err error) {
defer mon.Task()(&ctx)(&err)
bID, err := rt.getKBucketID(ctx, rt.self.Id)
if err != nil {
@ -235,8 +231,8 @@ func (rt *RoutingTable) kadBucketContainsLocalNode(queryID bucketID) (_ bool, er
}
// kadBucketHasRoom: helper, returns true if it has fewer than k nodes
func (rt *RoutingTable) kadBucketHasRoom(bID bucketID) (bool, error) {
nodes, err := rt.getNodeIDsWithinKBucket(bID)
func (rt *RoutingTable) kadBucketHasRoom(ctx context.Context, bID bucketID) (_ bool, err error) {
nodes, err := rt.getNodeIDsWithinKBucket(ctx, bID)
if err != nil {
return false, err
}
@ -247,10 +243,9 @@ func (rt *RoutingTable) kadBucketHasRoom(bID bucketID) (bool, error) {
}
// getNodeIDsWithinKBucket: helper, returns a collection of all the node ids contained within the kbucket
func (rt *RoutingTable) getNodeIDsWithinKBucket(bID bucketID) (_ storj.NodeIDList, err error) {
ctx := context.TODO()
func (rt *RoutingTable) getNodeIDsWithinKBucket(ctx context.Context, bID bucketID) (_ storj.NodeIDList, err error) {
defer mon.Task()(&ctx)(&err)
endpoints, err := rt.getKBucketRange(bID)
endpoints, err := rt.getKBucketRange(ctx, bID)
if err != nil {
return nil, err
}
@ -271,8 +266,7 @@ func (rt *RoutingTable) getNodeIDsWithinKBucket(bID bucketID) (_ storj.NodeIDLis
}
// getNodesFromIDsBytes: helper, returns array of encoded nodes from node ids
func (rt *RoutingTable) getNodesFromIDsBytes(nodeIDs storj.NodeIDList) (_ []*pb.Node, err error) {
ctx := context.TODO()
func (rt *RoutingTable) getNodesFromIDsBytes(ctx context.Context, nodeIDs storj.NodeIDList) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
var marshaledNodes []storage.Value
for _, v := range nodeIDs {
@ -300,12 +294,13 @@ func unmarshalNodes(nodes []storage.Value) ([]*pb.Node, error) {
}
// getUnmarshaledNodesFromBucket: helper, gets nodes within kbucket
func (rt *RoutingTable) getUnmarshaledNodesFromBucket(bID bucketID) ([]*pb.Node, error) {
nodeIDsBytes, err := rt.getNodeIDsWithinKBucket(bID)
func (rt *RoutingTable) getUnmarshaledNodesFromBucket(ctx context.Context, bID bucketID) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
nodeIDsBytes, err := rt.getNodeIDsWithinKBucket(ctx, bID)
if err != nil {
return []*pb.Node{}, RoutingErr.New("could not get nodeIds within kbucket %s", err)
}
nodes, err := rt.getNodesFromIDsBytes(nodeIDsBytes)
nodes, err := rt.getNodesFromIDsBytes(ctx, nodeIDsBytes)
if err != nil {
return []*pb.Node{}, RoutingErr.New("could not get node values %s", err)
}
@ -313,8 +308,7 @@ func (rt *RoutingTable) getUnmarshaledNodesFromBucket(bID bucketID) ([]*pb.Node,
}
// getKBucketRange: helper, returns the left and right endpoints of the range of node ids contained within the bucket
func (rt *RoutingTable) getKBucketRange(bID bucketID) (_ []bucketID, err error) {
ctx := context.TODO()
func (rt *RoutingTable) getKBucketRange(ctx context.Context, bID bucketID) (_ []bucketID, err error) {
defer mon.Task()(&ctx)(&err)
previousBucket := bucketID{}
endpoints := []bucketID{}
@ -340,8 +334,9 @@ func (rt *RoutingTable) getKBucketRange(bID bucketID) (_ []bucketID, err error)
// determineLeafDepth determines the level of the bucket id in question.
// Eg level 0 means there is only 1 bucket, level 1 means the bucket has been split once, and so on
func (rt *RoutingTable) determineLeafDepth(bID bucketID) (int, error) {
bucketRange, err := rt.getKBucketRange(bID)
func (rt *RoutingTable) determineLeafDepth(ctx context.Context, bID bucketID) (_ int, err error) {
defer mon.Task()(&ctx)(&err)
bucketRange, err := rt.getKBucketRange(ctx, bID)
if err != nil {
return -1, RoutingErr.New("could not get k bucket range %s", err)
}


@ -5,6 +5,7 @@ package kademlia
import (
"bytes"
"context"
"sync"
"testing"
"time"
@ -30,7 +31,7 @@ type routingTableOpts struct {
}
// newTestRoutingTable returns a newly configured instance of a RoutingTable
func newTestRoutingTable(local *overlay.NodeDossier, opts routingTableOpts) (*RoutingTable, error) {
func newTestRoutingTable(ctx context.Context, local *overlay.NodeDossier, opts routingTableOpts) (*RoutingTable, error) {
if opts.bucketSize == 0 {
opts.bucketSize = 6
}
@ -50,34 +51,34 @@ func newTestRoutingTable(local *overlay.NodeDossier, opts routingTableOpts) (*Ro
bucketSize: opts.bucketSize,
rcBucketSize: opts.cacheSize,
}
ok, err := rt.addNode(&local.Node)
ok, err := rt.addNode(ctx, &local.Node)
if !ok || err != nil {
return nil, RoutingErr.New("could not add localNode to routing table: %s", err)
}
return rt, nil
}
func createRoutingTableWith(localNodeID storj.NodeID, opts routingTableOpts) *RoutingTable {
func createRoutingTableWith(ctx context.Context, localNodeID storj.NodeID, opts routingTableOpts) *RoutingTable {
if localNodeID == (storj.NodeID{}) {
panic("empty local node id")
}
local := &overlay.NodeDossier{Node: pb.Node{Id: localNodeID}}
rt, err := newTestRoutingTable(local, opts)
rt, err := newTestRoutingTable(ctx, local, opts)
if err != nil {
panic(err)
}
return rt
}
func createRoutingTable(localNodeID storj.NodeID) *RoutingTable {
return createRoutingTableWith(localNodeID, routingTableOpts{})
func createRoutingTable(ctx context.Context, localNodeID storj.NodeID) *RoutingTable {
return createRoutingTableWith(ctx, localNodeID, routingTableOpts{})
}
func TestAddNode(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(teststorj.NodeIDFromString("OO"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("OO"))
defer ctx.Check(rt.Close)
cases := []struct {
@ -204,14 +205,14 @@ func TestAddNode(t *testing.T) {
for _, c := range cases {
testCase := c
t.Run(testCase.testID, func(t *testing.T) {
ok, err := rt.addNode(testCase.node)
ok, err := rt.addNode(ctx, testCase.node)
require.NoError(t, err)
require.Equal(t, testCase.added, ok)
kadKeys, err := rt.kadBucketDB.List(ctx, nil, 0)
require.NoError(t, err)
for i, v := range kadKeys {
require.True(t, bytes.Equal(testCase.kadIDs[i], v[:2]))
ids, err := rt.getNodeIDsWithinKBucket(keyToBucketID(v))
ids, err := rt.getNodeIDsWithinKBucket(ctx, keyToBucketID(v))
require.NoError(t, err)
require.True(t, len(ids) == len(testCase.nodeIDs[i]))
for j, id := range ids {
@ -232,10 +233,10 @@ func TestAddNode(t *testing.T) {
func TestUpdateNode(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(teststorj.NodeIDFromString("AA"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA"))
defer ctx.Check(rt.Close)
node := teststorj.MockNode("BB")
ok, err := rt.addNode(node)
ok, err := rt.addNode(ctx, node)
assert.True(t, ok)
assert.NoError(t, err)
val, err := rt.nodeBucketDB.Get(ctx, node.Id.Bytes())
@ -246,7 +247,7 @@ func TestUpdateNode(t *testing.T) {
assert.Nil(t, x)
node.Address = &pb.NodeAddress{Address: "BB"}
err = rt.updateNode(node)
err = rt.updateNode(ctx, node)
assert.NoError(t, err)
val, err = rt.nodeBucketDB.Get(ctx, node.Id.Bytes())
assert.NoError(t, err)
@ -259,11 +260,11 @@ func TestUpdateNode(t *testing.T) {
func TestRemoveNode(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(teststorj.NodeIDFromString("AA"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA"))
defer ctx.Check(rt.Close)
kadBucketID := firstBucketID
node := teststorj.MockNode("BB")
ok, err := rt.addNode(node)
ok, err := rt.addNode(ctx, node)
assert.True(t, ok)
assert.NoError(t, err)
val, err := rt.nodeBucketDB.Get(ctx, node.Id.Bytes())
@ -271,7 +272,7 @@ func TestRemoveNode(t *testing.T) {
assert.NotNil(t, val)
node2 := teststorj.MockNode("CC")
rt.addToReplacementCache(kadBucketID, node2)
err = rt.removeNode(node)
err = rt.removeNode(ctx, node)
assert.NoError(t, err)
val, err = rt.nodeBucketDB.Get(ctx, node.Id.Bytes())
assert.Nil(t, val)
@ -282,7 +283,7 @@ func TestRemoveNode(t *testing.T) {
assert.Equal(t, 0, len(rt.replacementCache[kadBucketID]))
//try to remove node not in rt
err = rt.removeNode(&pb.Node{
err = rt.removeNode(ctx, &pb.Node{
Id: teststorj.NodeIDFromString("DD"),
Address: &pb.NodeAddress{Address: "address:1"},
})
@ -293,7 +294,7 @@ func TestCreateOrUpdateKBucket(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
id := bucketID{255, 255}
rt := createRoutingTable(teststorj.NodeIDFromString("AA"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA"))
defer ctx.Check(rt.Close)
err := rt.createOrUpdateKBucket(ctx, id, time.Now())
assert.NoError(t, err)
@ -308,7 +309,7 @@ func TestGetKBucketID(t *testing.T) {
defer ctx.Cleanup()
kadIDA := bucketID{255, 255}
nodeIDA := teststorj.NodeIDFromString("AA")
rt := createRoutingTable(nodeIDA)
rt := createRoutingTable(ctx, nodeIDA)
defer ctx.Check(rt.Close)
keyA, err := rt.getKBucketID(ctx, nodeIDA)
assert.NoError(t, err)
@ -318,7 +319,7 @@ func TestGetKBucketID(t *testing.T) {
func TestWouldBeInNearestK(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTableWith(storj.NodeID{127, 255}, routingTableOpts{bucketSize: 2})
rt := createRoutingTableWith(ctx, storj.NodeID{127, 255}, routingTableOpts{bucketSize: 2})
defer ctx.Check(rt.Close)
cases := []struct {
@ -350,7 +351,7 @@ func TestWouldBeInNearestK(t *testing.T) {
for _, c := range cases {
testCase := c
t.Run(testCase.testID, func(t *testing.T) {
result, err := rt.wouldBeInNearestK(testCase.nodeID)
result, err := rt.wouldBeInNearestK(ctx, testCase.nodeID)
assert.NoError(t, err)
assert.Equal(t, testCase.closest, result)
assert.NoError(t, rt.nodeBucketDB.Put(ctx, testCase.nodeID.Bytes(), []byte("")))
@ -362,7 +363,7 @@ func TestKadBucketContainsLocalNode(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
nodeIDA := storj.NodeID{183, 255} //[10110111, 1111111]
rt := createRoutingTable(nodeIDA)
rt := createRoutingTable(ctx, nodeIDA)
defer ctx.Check(rt.Close)
kadIDA := firstBucketID
var kadIDB bucketID
@ -371,9 +372,9 @@ func TestKadBucketContainsLocalNode(t *testing.T) {
now := time.Now()
err := rt.createOrUpdateKBucket(ctx, kadIDB, now)
assert.NoError(t, err)
resultTrue, err := rt.kadBucketContainsLocalNode(kadIDA)
resultTrue, err := rt.kadBucketContainsLocalNode(ctx, kadIDA)
assert.NoError(t, err)
resultFalse, err := rt.kadBucketContainsLocalNode(kadIDB)
resultFalse, err := rt.kadBucketContainsLocalNode(ctx, kadIDB)
assert.NoError(t, err)
assert.True(t, resultTrue)
assert.False(t, resultFalse)
@ -383,7 +384,7 @@ func TestKadBucketHasRoom(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
node1 := storj.NodeID{255, 255}
rt := createRoutingTable(node1)
rt := createRoutingTable(ctx, node1)
defer ctx.Check(rt.Close)
kadIDA := firstBucketID
node2 := storj.NodeID{191, 255}
@ -391,7 +392,7 @@ func TestKadBucketHasRoom(t *testing.T) {
node4 := storj.NodeID{63, 255}
node5 := storj.NodeID{159, 255}
node6 := storj.NodeID{0, 127}
resultA, err := rt.kadBucketHasRoom(kadIDA)
resultA, err := rt.kadBucketHasRoom(ctx, kadIDA)
assert.NoError(t, err)
assert.True(t, resultA)
assert.NoError(t, rt.nodeBucketDB.Put(ctx, node2.Bytes(), []byte("")))
@ -399,7 +400,7 @@ func TestKadBucketHasRoom(t *testing.T) {
assert.NoError(t, rt.nodeBucketDB.Put(ctx, node4.Bytes(), []byte("")))
assert.NoError(t, rt.nodeBucketDB.Put(ctx, node5.Bytes(), []byte("")))
assert.NoError(t, rt.nodeBucketDB.Put(ctx, node6.Bytes(), []byte("")))
resultB, err := rt.kadBucketHasRoom(kadIDA)
resultB, err := rt.kadBucketHasRoom(ctx, kadIDA)
assert.NoError(t, err)
assert.False(t, resultB)
}
@ -408,7 +409,7 @@ func TestGetNodeIDsWithinKBucket(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
nodeIDA := storj.NodeID{183, 255} //[10110111, 1111111]
rt := createRoutingTable(nodeIDA)
rt := createRoutingTable(ctx, nodeIDA)
defer ctx.Check(rt.Close)
kadIDA := firstBucketID
var kadIDB bucketID
@ -440,7 +441,7 @@ func TestGetNodeIDsWithinKBucket(t *testing.T) {
for _, c := range cases {
testCase := c
t.Run(testCase.testID, func(t *testing.T) {
n, err := rt.getNodeIDsWithinKBucket(testCase.kadID)
n, err := rt.getNodeIDsWithinKBucket(ctx, testCase.kadID)
assert.NoError(t, err)
for i, id := range testCase.expected {
assert.True(t, id.Equal(n[i].Bytes()))
@ -461,7 +462,7 @@ func TestGetNodesFromIDs(t *testing.T) {
assert.NoError(t, err)
c, err := proto.Marshal(nodeC)
assert.NoError(t, err)
rt := createRoutingTable(nodeA.Id)
rt := createRoutingTable(ctx, nodeA.Id)
defer ctx.Check(rt.Close)
assert.NoError(t, rt.nodeBucketDB.Put(ctx, nodeA.Id.Bytes(), a))
@ -471,7 +472,7 @@ func TestGetNodesFromIDs(t *testing.T) {
nodeKeys, err := rt.nodeBucketDB.List(ctx, nil, 0)
assert.NoError(t, err)
values, err := rt.getNodesFromIDsBytes(teststorj.NodeIDsFromBytes(nodeKeys.ByteSlices()...))
values, err := rt.getNodesFromIDsBytes(ctx, teststorj.NodeIDsFromBytes(nodeKeys.ByteSlices()...))
assert.NoError(t, err)
for i, n := range expected {
assert.True(t, bytes.Equal(n.Id.Bytes(), values[i].Id.Bytes()))
@ -491,14 +492,14 @@ func TestUnmarshalNodes(t *testing.T) {
assert.NoError(t, err)
c, err := proto.Marshal(nodeC)
assert.NoError(t, err)
rt := createRoutingTable(nodeA.Id)
rt := createRoutingTable(ctx, nodeA.Id)
defer ctx.Check(rt.Close)
assert.NoError(t, rt.nodeBucketDB.Put(ctx, nodeA.Id.Bytes(), a))
assert.NoError(t, rt.nodeBucketDB.Put(ctx, nodeB.Id.Bytes(), b))
assert.NoError(t, rt.nodeBucketDB.Put(ctx, nodeC.Id.Bytes(), c))
nodeKeys, err := rt.nodeBucketDB.List(ctx, nil, 0)
assert.NoError(t, err)
nodes, err := rt.getNodesFromIDsBytes(teststorj.NodeIDsFromBytes(nodeKeys.ByteSlices()...))
nodes, err := rt.getNodesFromIDsBytes(ctx, teststorj.NodeIDsFromBytes(nodeKeys.ByteSlices()...))
assert.NoError(t, err)
expected := []*pb.Node{nodeA, nodeB, nodeC}
for i, v := range expected {
@ -510,17 +511,17 @@ func TestGetUnmarshaledNodesFromBucket(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
nodeA := teststorj.MockNode("AA")
rt := createRoutingTable(nodeA.Id)
rt := createRoutingTable(ctx, nodeA.Id)
defer ctx.Check(rt.Close)
bucketID := firstBucketID
nodeB := teststorj.MockNode("BB")
nodeC := teststorj.MockNode("CC")
var err error
_, err = rt.addNode(nodeB)
_, err = rt.addNode(ctx, nodeB)
assert.NoError(t, err)
_, err = rt.addNode(nodeC)
_, err = rt.addNode(ctx, nodeC)
assert.NoError(t, err)
nodes, err := rt.getUnmarshaledNodesFromBucket(bucketID)
nodes, err := rt.getUnmarshaledNodesFromBucket(ctx, bucketID)
expected := []*pb.Node{nodeA, nodeB, nodeC}
assert.NoError(t, err)
for i, v := range expected {
@ -531,7 +532,7 @@ func TestGetUnmarshaledNodesFromBucket(t *testing.T) {
func TestGetKBucketRange(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(teststorj.NodeIDFromString("AA"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA"))
defer ctx.Check(rt.Close)
idA := storj.NodeID{255, 255}
idB := storj.NodeID{127, 255}
@ -560,7 +561,7 @@ func TestGetKBucketRange(t *testing.T) {
for _, c := range cases {
testCase := c
t.Run(testCase.testID, func(t *testing.T) {
ep, err := rt.getKBucketRange(keyToBucketID(testCase.id.Bytes()))
ep, err := rt.getKBucketRange(ctx, keyToBucketID(testCase.id.Bytes()))
assert.NoError(t, err)
for i, k := range testCase.expected {
assert.True(t, k.Equal(ep[i][:]))
@ -578,7 +579,7 @@ func TestBucketIDZeroValue(t *testing.T) {
func TestDetermineLeafDepth(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(teststorj.NodeIDFromString("AA"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA"))
defer ctx.Check(rt.Close)
idA, idB, idC := firstBucketID, firstBucketID, firstBucketID
idA[0] = 255
@ -630,7 +631,7 @@ func TestDetermineLeafDepth(t *testing.T) {
testCase := c
t.Run(testCase.testID, func(t *testing.T) {
testCase.addNode()
d, err := rt.determineLeafDepth(testCase.id)
d, err := rt.determineLeafDepth(ctx, testCase.id)
assert.NoError(t, err)
assert.Equal(t, testCase.depth, d)
})
@ -640,7 +641,7 @@ func TestDetermineLeafDepth(t *testing.T) {
func TestSplitBucket(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(teststorj.NodeIDFromString("AA"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA"))
defer ctx.Check(rt.Close)
cases := []struct {
testID string


@ -4,6 +4,7 @@
package kademlia
import (
"context"
"testing"
"github.com/stretchr/testify/require"
@ -15,19 +16,19 @@ import (
"storj.io/storj/pkg/storj"
)
type routingCtor func(storj.NodeID, int, int, int) dht.RoutingTable
type routingCtor func(context.Context, storj.NodeID, int, int, int) dht.RoutingTable
func newRouting(self storj.NodeID, bucketSize, cacheSize, allowedFailures int) dht.RoutingTable {
func newRouting(ctx context.Context, self storj.NodeID, bucketSize, cacheSize, allowedFailures int) dht.RoutingTable {
if allowedFailures != 0 {
panic("failure counting currently unsupported")
}
return createRoutingTableWith(self, routingTableOpts{
return createRoutingTableWith(ctx, self, routingTableOpts{
bucketSize: bucketSize,
cacheSize: cacheSize,
})
}
func newTestRouting(self storj.NodeID, bucketSize, cacheSize, allowedFailures int) dht.RoutingTable {
func newTestRouting(ctx context.Context, self storj.NodeID, bucketSize, cacheSize, allowedFailures int) dht.RoutingTable {
return testrouting.New(self, bucketSize, cacheSize, allowedFailures)
}
@ -39,12 +40,12 @@ func testTableInit(t *testing.T, routingCtor routingCtor) {
bucketSize := 5
cacheSize := 3
table := routingCtor(PadID("55", "5"), bucketSize, cacheSize, 0)
table := routingCtor(ctx, PadID("55", "5"), bucketSize, cacheSize, 0)
defer ctx.Check(table.Close)
require.Equal(t, bucketSize, table.K())
require.Equal(t, cacheSize, table.CacheSize())
nodes, err := table.FindNear(PadID("21", "0"), 3)
nodes, err := table.FindNear(ctx, PadID("21", "0"), 3)
require.NoError(t, err)
require.Equal(t, 0, len(nodes))
}
@ -55,13 +56,13 @@ func testTableBasic(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("5555", "5"), 5, 3, 0)
table := routingCtor(ctx, PadID("5555", "5"), 5, 3, 0)
defer ctx.Check(table.Close)
err := table.ConnectionSuccess(Node(PadID("5556", "5"), "address:1"))
err := table.ConnectionSuccess(ctx, Node(PadID("5556", "5"), "address:1"))
require.NoError(t, err)
nodes, err := table.FindNear(PadID("21", "0"), 3)
nodes, err := table.FindNear(ctx, PadID("21", "0"), 3)
require.NoError(t, err)
require.Equal(t, 1, len(nodes))
require.Equal(t, PadID("5556", "5"), nodes[0].Id)
@ -74,12 +75,12 @@ func testNoSelf(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("55", "5"), 5, 3, 0)
table := routingCtor(ctx, PadID("55", "5"), 5, 3, 0)
defer ctx.Check(table.Close)
err := table.ConnectionSuccess(Node(PadID("55", "5"), "address:2"))
err := table.ConnectionSuccess(ctx, Node(PadID("55", "5"), "address:2"))
require.NoError(t, err)
nodes, err := table.FindNear(PadID("21", "0"), 3)
nodes, err := table.FindNear(ctx, PadID("21", "0"), 3)
require.NoError(t, err)
require.Equal(t, 0, len(nodes))
}
@ -90,12 +91,12 @@ func testSplits(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("55", "5"), 5, 2, 0)
table := routingCtor(ctx, PadID("55", "5"), 5, 2, 0)
defer ctx.Check(table.Close)
for _, prefix2 := range "18" {
for _, prefix1 := range "a69c23f1d7eb5408" {
require.NoError(t, table.ConnectionSuccess(
require.NoError(t, table.ConnectionSuccess(ctx,
NodeFromPrefix(string([]rune{prefix1, prefix2}), "0")))
}
}
@ -108,7 +109,7 @@ func testSplits(t *testing.T, routingCtor routingCtor) {
// three bits should also not be full and have 4 nodes
// (40..., 48..., 50..., 58...). So we should be able to get no more than
// 18 nodes back
nodes, err := table.FindNear(PadID("55", "5"), 19)
nodes, err := table.FindNear(ctx, PadID("55", "5"), 19)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
// bucket 010 (same first three bits)
@ -140,23 +141,23 @@ func testSplits(t *testing.T, routingCtor routingCtor) {
// the gaps
// bucket 010 shouldn't have anything in its replacement cache
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("41", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("41", "0")))
// bucket 011 shouldn't have anything in its replacement cache
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("68", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("68", "0")))
// bucket 00 should have two things in its replacement cache, 18... is one of them
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("18", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("18", "0")))
// now just one thing in its replacement cache
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("31", "0")))
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("28", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("31", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("28", "0")))
// bucket 1 should have two things in its replacement cache
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("a1", "0")))
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("d1", "0")))
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("91", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("a1", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("d1", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("91", "0")))
nodes, err = table.FindNear(PadID("55", "5"), 19)
nodes, err = table.FindNear(ctx, PadID("55", "5"), 19)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
// bucket 010
@ -187,12 +188,12 @@ func testUnbalanced(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("ff", "f"), 5, 2, 0)
table := routingCtor(ctx, PadID("ff", "f"), 5, 2, 0)
defer ctx.Check(table.Close)
for _, prefix1 := range "0123456789abcdef" {
for _, prefix2 := range "18" {
require.NoError(t, table.ConnectionSuccess(
require.NoError(t, table.ConnectionSuccess(ctx,
NodeFromPrefix(string([]rune{prefix1, prefix2}), "0")))
}
}
@ -202,7 +203,7 @@ func testUnbalanced(t *testing.T, routingCtor routingCtor) {
// would have forced every bucket to split, and we should have stored all
// possible nodes.
nodes, err := table.FindNear(PadID("ff", "f"), 33)
nodes, err := table.FindNear(ctx, PadID("ff", "f"), 33)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("f8", "0"), NodeFromPrefix("f1", "0"),
@ -230,24 +231,24 @@ func testQuery(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("a3", "3"), 5, 2, 0)
table := routingCtor(ctx, PadID("a3", "3"), 5, 2, 0)
defer ctx.Check(table.Close)
for _, prefix2 := range "18" {
for _, prefix1 := range "b4f25c896de03a71" {
require.NoError(t, table.ConnectionSuccess(
require.NoError(t, table.ConnectionSuccess(ctx,
NodeFromPrefix(string([]rune{prefix1, prefix2}), "f")))
}
}
nodes, err := table.FindNear(PadID("c7139", "1"), 2)
nodes, err := table.FindNear(ctx, PadID("c7139", "1"), 2)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("c1", "f"),
NodeFromPrefix("d1", "f"),
}, nodes)
nodes, err = table.FindNear(PadID("c7139", "1"), 7)
nodes, err = table.FindNear(ctx, PadID("c7139", "1"), 7)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("c1", "f"),
@ -259,7 +260,7 @@ func testQuery(t *testing.T, routingCtor routingCtor) {
NodeFromPrefix("88", "f"),
}, nodes)
nodes, err = table.FindNear(PadID("c7139", "1"), 10)
nodes, err = table.FindNear(ctx, PadID("c7139", "1"), 10)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("c1", "f"),
@ -281,18 +282,18 @@ func testFailureCounting(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("a3", "3"), 5, 2, 2)
table := routingCtor(ctx, PadID("a3", "3"), 5, 2, 2)
defer ctx.Check(table.Close)
for _, prefix2 := range "18" {
for _, prefix1 := range "b4f25c896de03a71" {
require.NoError(t, table.ConnectionSuccess(
require.NoError(t, table.ConnectionSuccess(ctx,
NodeFromPrefix(string([]rune{prefix1, prefix2}), "f")))
}
}
nochange := func() {
nodes, err := table.FindNear(PadID("c7139", "1"), 7)
nodes, err := table.FindNear(ctx, PadID("c7139", "1"), 7)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("c1", "f"),
@ -306,13 +307,13 @@ func testFailureCounting(t *testing.T, routingCtor routingCtor) {
}
nochange()
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("d1", "f")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("d1", "f")))
nochange()
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("d1", "f")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("d1", "f")))
nochange()
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("d1", "f")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("d1", "f")))
nodes, err := table.FindNear(PadID("c7139", "1"), 7)
nodes, err := table.FindNear(ctx, PadID("c7139", "1"), 7)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("c1", "f"),
@ -331,26 +332,26 @@ func testUpdateBucket(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("a3", "3"), 5, 2, 0)
table := routingCtor(ctx, PadID("a3", "3"), 5, 2, 0)
defer ctx.Check(table.Close)
for _, prefix2 := range "18" {
for _, prefix1 := range "b4f25c896de03a71" {
require.NoError(t, table.ConnectionSuccess(
require.NoError(t, table.ConnectionSuccess(ctx,
NodeFromPrefix(string([]rune{prefix1, prefix2}), "f")))
}
}
nodes, err := table.FindNear(PadID("c7139", "1"), 1)
nodes, err := table.FindNear(ctx, PadID("c7139", "1"), 1)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("c1", "f"),
}, nodes)
require.NoError(t, table.ConnectionSuccess(
require.NoError(t, table.ConnectionSuccess(ctx,
Node(PadID("c1", "f"), "new-address:3")))
nodes, err = table.FindNear(PadID("c7139", "1"), 1)
nodes, err = table.FindNear(ctx, PadID("c7139", "1"), 1)
require.NoError(t, err)
require.Equal(t, 1, len(nodes))
require.Equal(t, PadID("c1", "f"), nodes[0].Id)
@ -363,18 +364,18 @@ func testUpdateCache(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("a3", "3"), 1, 1, 0)
table := routingCtor(ctx, PadID("a3", "3"), 1, 1, 0)
defer ctx.Check(table.Close)
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("81", "0")))
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("c1", "0")))
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("41", "0")))
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("01", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("81", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("c1", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("41", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("01", "0")))
require.NoError(t, table.ConnectionSuccess(Node(PadID("01", "0"), "new-address:6")))
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("41", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, Node(PadID("01", "0"), "new-address:6")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("41", "0")))
nodes, err := table.FindNear(PadID("01", "0"), 4)
nodes, err := table.FindNear(ctx, PadID("01", "0"), 4)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
@ -390,16 +391,16 @@ func testFailureUnknownAddress(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("a3", "3"), 1, 1, 0)
table := routingCtor(ctx, PadID("a3", "3"), 1, 1, 0)
defer ctx.Check(table.Close)
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("81", "0")))
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("c1", "0")))
require.NoError(t, table.ConnectionSuccess(Node(PadID("41", "0"), "address:2")))
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("01", "0")))
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("41", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("81", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("c1", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, Node(PadID("41", "0"), "address:2")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("01", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("41", "0")))
nodes, err := table.FindNear(PadID("01", "0"), 4)
nodes, err := table.FindNear(ctx, PadID("01", "0"), 4)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
@ -415,13 +416,13 @@ func testShrink(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("ff", "f"), 2, 2, 0)
table := routingCtor(ctx, PadID("ff", "f"), 2, 2, 0)
defer ctx.Check(table.Close)
// blow out the routing table
for _, prefix1 := range "0123456789abcdef" {
for _, prefix2 := range "18" {
require.NoError(t, table.ConnectionSuccess(
require.NoError(t, table.ConnectionSuccess(ctx,
NodeFromPrefix(string([]rune{prefix1, prefix2}), "0")))
}
}
@ -429,7 +430,7 @@ func testShrink(t *testing.T, routingCtor routingCtor) {
// delete some of the bad ones
for _, prefix1 := range "0123456789abcd" {
for _, prefix2 := range "18" {
require.NoError(t, table.ConnectionFailed(
require.NoError(t, table.ConnectionFailed(ctx,
NodeFromPrefix(string([]rune{prefix1, prefix2}), "0")))
}
}
@ -437,13 +438,13 @@ func testShrink(t *testing.T, routingCtor routingCtor) {
// add back some nodes more balanced
for _, prefix1 := range "3a50" {
for _, prefix2 := range "19" {
require.NoError(t, table.ConnectionSuccess(
require.NoError(t, table.ConnectionSuccess(ctx,
NodeFromPrefix(string([]rune{prefix1, prefix2}), "0")))
}
}
// make sure table filled in alright
nodes, err := table.FindNear(PadID("ff", "f"), 13)
nodes, err := table.FindNear(ctx, PadID("ff", "f"), 13)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("f8", "0"),
@ -467,17 +468,17 @@ func testReplacementCacheOrder(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("a3", "3"), 1, 2, 0)
table := routingCtor(ctx, PadID("a3", "3"), 1, 2, 0)
defer ctx.Check(table.Close)
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("81", "0")))
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("21", "0")))
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("c1", "0")))
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("41", "0")))
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix("01", "0")))
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("21", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("81", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("21", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("c1", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("41", "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix("01", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("21", "0")))
nodes, err := table.FindNear(PadID("55", "5"), 4)
nodes, err := table.FindNear(ctx, PadID("55", "5"), 4)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
@ -493,16 +494,16 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("55", "55"), 2, 2, 0)
table := routingCtor(ctx, PadID("55", "55"), 2, 2, 0)
defer ctx.Check(table.Close)
for _, pad := range []string{"0", "1"} {
for _, prefix := range []string{"ff", "e1", "c1", "54", "56", "57"} {
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix(prefix, pad)))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(prefix, pad)))
}
}
nodes, err := table.FindNear(PadID("55", "55"), 9)
nodes, err := table.FindNear(ctx, PadID("55", "55"), 9)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("54", "1"),
@ -515,9 +516,9 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) {
NodeFromPrefix("e1", "0"),
}, nodes)
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("c1", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("c1", "0")))
nodes, err = table.FindNear(PadID("55", "55"), 9)
nodes, err = table.FindNear(ctx, PadID("55", "55"), 9)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("54", "1"),
@ -529,8 +530,8 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) {
NodeFromPrefix("e1", "0"),
}, nodes)
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("ff", "0")))
nodes, err = table.FindNear(PadID("55", "55"), 9)
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("ff", "0")))
nodes, err = table.FindNear(ctx, PadID("55", "55"), 9)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("54", "1"),
@ -542,8 +543,8 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) {
NodeFromPrefix("e1", "0"),
}, nodes)
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("e1", "0")))
nodes, err = table.FindNear(PadID("55", "55"), 9)
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("e1", "0")))
nodes, err = table.FindNear(ctx, PadID("55", "55"), 9)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("54", "1"),
@ -555,8 +556,8 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) {
NodeFromPrefix("e1", "1"),
}, nodes)
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("e1", "1")))
nodes, err = table.FindNear(PadID("55", "55"), 9)
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("e1", "1")))
nodes, err = table.FindNear(ctx, PadID("55", "55"), 9)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("54", "1"),
@ -568,10 +569,10 @@ func testHealSplit(t *testing.T, routingCtor routingCtor) {
}, nodes)
for _, prefix := range []string{"ff", "e1", "c1", "54", "56", "57"} {
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix(prefix, "2")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(prefix, "2")))
}
nodes, err = table.FindNear(PadID("55", "55"), 9)
nodes, err = table.FindNear(ctx, PadID("55", "55"), 9)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("54", "1"),
@ -591,23 +592,23 @@ func testFullDissimilarBucket(t *testing.T, routingCtor routingCtor) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
table := routingCtor(PadID("55", "55"), 2, 2, 0)
table := routingCtor(ctx, PadID("55", "55"), 2, 2, 0)
defer ctx.Check(table.Close)
for _, prefix := range []string{"d1", "c1", "f1", "e1"} {
require.NoError(t, table.ConnectionSuccess(NodeFromPrefix(prefix, "0")))
require.NoError(t, table.ConnectionSuccess(ctx, NodeFromPrefix(prefix, "0")))
}
nodes, err := table.FindNear(PadID("55", "55"), 9)
nodes, err := table.FindNear(ctx, PadID("55", "55"), 9)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("d1", "0"),
NodeFromPrefix("c1", "0"),
}, nodes)
require.NoError(t, table.ConnectionFailed(NodeFromPrefix("c1", "0")))
require.NoError(t, table.ConnectionFailed(ctx, NodeFromPrefix("c1", "0")))
nodes, err = table.FindNear(PadID("55", "55"), 9)
nodes, err = table.FindNear(ctx, PadID("55", "55"), 9)
require.NoError(t, err)
requireNodesEqual(t, []*pb.Node{
NodeFromPrefix("d1", "0"),

View File

@ -5,6 +5,7 @@ package kademlia
import (
"bytes"
"context"
"fmt"
"math/rand"
"sort"
@ -25,7 +26,7 @@ func TestLocal(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(teststorj.NodeIDFromString("AA"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA"))
defer ctx.Check(rt.Close)
assert.Equal(t, rt.Local().Id.Bytes()[:2], []byte("AA"))
}
@ -34,7 +35,7 @@ func TestK(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(teststorj.NodeIDFromString("AA"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA"))
defer ctx.Check(rt.Close)
k := rt.K()
assert.Equal(t, rt.bucketSize, k)
@ -45,7 +46,7 @@ func TestCacheSize(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(teststorj.NodeIDFromString("AA"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA"))
defer ctx.Check(rt.Close)
expected := rt.rcBucketSize
result := rt.CacheSize()
@ -56,11 +57,11 @@ func TestGetBucket(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rt := createRoutingTable(teststorj.NodeIDFromString("AA"))
rt := createRoutingTable(ctx, teststorj.NodeIDFromString("AA"))
defer ctx.Check(rt.Close)
node := teststorj.MockNode("AA")
node2 := teststorj.MockNode("BB")
ok, err := rt.addNode(node2)
ok, err := rt.addNode(ctx, node2)
assert.True(t, ok)
assert.NoError(t, err)
@ -79,7 +80,7 @@ func TestGetBucket(t *testing.T) {
},
}
for i, v := range cases {
b, e := rt.GetNodes(node2.Id)
b, e := rt.GetNodes(ctx, node2.Id)
for j, w := range v.expected {
if !assert.True(t, bytes.Equal(w.Id.Bytes(), b[j].Id.Bytes())) {
t.Logf("case %v failed expected: ", i)
@ -97,14 +98,15 @@ func RandomNode() pb.Node {
return node
}
func TestKademliaFindNear(t *testing.T) {
ctx := context.Background()
testFunc := func(t *testing.T, testNodeCount, limit int) {
selfNode := RandomNode()
rt := createRoutingTable(selfNode.Id)
rt := createRoutingTable(ctx, selfNode.Id)
expectedIDs := make([]storj.NodeID, 0)
for x := 0; x < testNodeCount; x++ {
n := RandomNode()
ok, err := rt.addNode(&n)
ok, err := rt.addNode(ctx, &n)
require.NoError(t, err)
if ok { // node was added to the routing table
expectedIDs = append(expectedIDs, n.Id)
@ -118,7 +120,7 @@ func TestKademliaFindNear(t *testing.T) {
targetNode.Id[storj.NodeIDSize-1] ^= 1 //flip lowest bit
sortByXOR(expectedIDs, targetNode.Id)
results, err := rt.FindNear(targetNode.Id, limit)
results, err := rt.FindNear(ctx, targetNode.Id, limit)
require.NoError(t, err)
counts := []int{len(expectedIDs), limit}
sort.Ints(counts)
@ -142,7 +144,7 @@ func TestConnectionSuccess(t *testing.T) {
defer ctx.Cleanup()
id := teststorj.NodeIDFromString("AA")
rt := createRoutingTable(id)
rt := createRoutingTable(ctx, id)
defer ctx.Check(rt.Close)
id2 := teststorj.NodeIDFromString("BB")
address1 := &pb.NodeAddress{Address: "a"}
@ -169,7 +171,7 @@ func TestConnectionSuccess(t *testing.T) {
for _, c := range cases {
testCase := c
t.Run(testCase.testID, func(t *testing.T) {
err := rt.ConnectionSuccess(testCase.node)
err := rt.ConnectionSuccess(ctx, testCase.node)
assert.NoError(t, err)
v, err := rt.nodeBucketDB.Get(ctx, testCase.id.Bytes())
assert.NoError(t, err)
@ -186,9 +188,9 @@ func TestConnectionFailed(t *testing.T) {
id := teststorj.NodeIDFromString("AA")
node := &pb.Node{Id: id}
rt := createRoutingTable(id)
rt := createRoutingTable(ctx, id)
defer ctx.Check(rt.Close)
err := rt.ConnectionFailed(node)
err := rt.ConnectionFailed(ctx, node)
assert.NoError(t, err)
v, err := rt.nodeBucketDB.Get(ctx, id.Bytes())
assert.Error(t, err)
@ -200,19 +202,19 @@ func TestSetBucketTimestamp(t *testing.T) {
defer ctx.Cleanup()
id := teststorj.NodeIDFromString("AA")
rt := createRoutingTable(id)
rt := createRoutingTable(ctx, id)
defer ctx.Check(rt.Close)
now := time.Now().UTC()
err := rt.createOrUpdateKBucket(ctx, keyToBucketID(id.Bytes()), now)
assert.NoError(t, err)
ti, err := rt.GetBucketTimestamp(id.Bytes())
ti, err := rt.GetBucketTimestamp(ctx, id.Bytes())
assert.Equal(t, now, ti)
assert.NoError(t, err)
now = time.Now().UTC()
err = rt.SetBucketTimestamp(id.Bytes(), now)
err = rt.SetBucketTimestamp(ctx, id.Bytes(), now)
assert.NoError(t, err)
ti, err = rt.GetBucketTimestamp(id.Bytes())
ti, err = rt.GetBucketTimestamp(ctx, id.Bytes())
assert.Equal(t, now, ti)
assert.NoError(t, err)
}
@ -222,12 +224,12 @@ func TestGetBucketTimestamp(t *testing.T) {
defer ctx.Cleanup()
id := teststorj.NodeIDFromString("AA")
rt := createRoutingTable(id)
rt := createRoutingTable(ctx, id)
defer ctx.Check(rt.Close)
now := time.Now().UTC()
err := rt.createOrUpdateKBucket(ctx, keyToBucketID(id.Bytes()), now)
assert.NoError(t, err)
ti, err := rt.GetBucketTimestamp(id.Bytes())
ti, err := rt.GetBucketTimestamp(ctx, id.Bytes())
assert.Equal(t, now, ti)
assert.NoError(t, err)
}
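
Editor's note: TestKademliaFindNear above builds its expected ordering with sortByXOR before comparing against FindNear(ctx, ...). The metric is the standard Kademlia one: two IDs are XORed and the result is compared as a big-endian integer, so a smaller XOR value means a closer node. The following is a minimal self-contained sketch of that ordering only; it illustrates the metric and is not the package's actual sortByXOR implementation.

package main

import (
	"fmt"
	"sort"
)

// xorDistanceLess reports whether a is closer to target than b under the
// Kademlia XOR metric: each ID is XORed with the target and the results are
// compared as big-endian integers, most significant byte first.
func xorDistanceLess(a, b, target []byte) bool {
	for i := range target {
		da, db := a[i]^target[i], b[i]^target[i]
		if da != db {
			return da < db
		}
	}
	return false
}

func main() {
	// Two-byte IDs keep the demo short; real node IDs are just longer slices.
	target := []byte{0x55, 0x55}
	ids := [][]byte{{0xff, 0x00}, {0x54, 0x10}, {0xc1, 0x00}, {0x57, 0x00}}

	// Ascending XOR distance from target: 5410, 5700, c100, ff00.
	sort.Slice(ids, func(i, j int) bool {
		return xorDistanceLess(ids[i], ids[j], target)
	})
	for _, id := range ids {
		fmt.Printf("%x\n", id)
	}
}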

View File

@ -4,10 +4,13 @@
package testrouting
import (
"context"
"sort"
"sync"
"time"
monkit "gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/pkg/dht"
"storj.io/storj/pkg/overlay"
"storj.io/storj/pkg/pb"
@ -15,6 +18,10 @@ import (
"storj.io/storj/storage"
)
var (
mon = monkit.Package()
)
type nodeData struct {
node *pb.Node
ordering int64
@ -63,7 +70,8 @@ func (t *Table) CacheSize() int { return t.cacheSize }
// ConnectionSuccess should be called whenever a node is successfully connected
// to. It will add or update the node's entry in the routing table.
func (t *Table) ConnectionSuccess(node *pb.Node) error {
func (t *Table) ConnectionSuccess(ctx context.Context, node *pb.Node) (err error) {
defer mon.Task()(&ctx)(&err)
t.mu.Lock()
defer t.mu.Unlock()
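
Editor's note: the defer mon.Task()(&ctx)(&err) line added above is the pattern this commit applies throughout the routing code: mon.Task() opens a monkit task against the caller's context, and the returned function, invoked with the address of the named err result, records timing and success or failure when the method returns. Below is a small stand-alone sketch of the same shape; the Pinger type and doWork helper are hypothetical and exist only for illustration.

package example

import (
	"context"

	monkit "gopkg.in/spacemonkeygo/monkit.v2"
)

// mon is the package-level monkit scope, the same shape as the var block
// added to testrouting in this commit.
var mon = monkit.Package()

// Pinger is a hypothetical type standing in for any struct whose methods
// gained a context parameter in this commit.
type Pinger struct{}

// Ping shows the instrumentation shape: ctx is passed by address so the task
// can attach a trace span to it, and err is a named result so the deferred
// call observes whatever error the method ultimately returns.
func (p *Pinger) Ping(ctx context.Context, addr string) (err error) {
	defer mon.Task()(&ctx)(&err)
	return doWork(ctx, addr) // doWork is a hypothetical stand-in for the real call
}

func doWork(ctx context.Context, addr string) error { return nil }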
@ -100,7 +108,8 @@ func (t *Table) ConnectionSuccess(node *pb.Node) error {
// ConnectionFailed should be called whenever a node can't be contacted.
// If a node fails more than allowedFailures times, it will be removed from
// the routing table. The failure count is reset after every successful connection.
func (t *Table) ConnectionFailed(node *pb.Node) error {
func (t *Table) ConnectionFailed(ctx context.Context, node *pb.Node) (err error) {
defer mon.Task()(&ctx)(&err)
t.mu.Lock()
defer t.mu.Unlock()
@ -122,7 +131,8 @@ func (t *Table) ConnectionFailed(node *pb.Node) error {
// FindNear will return up to limit nodes in the routing table ordered by
// kademlia xor distance from the given id.
func (t *Table) FindNear(id storj.NodeID, limit int) ([]*pb.Node, error) {
func (t *Table) FindNear(ctx context.Context, id storj.NodeID, limit int) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
t.mu.Lock()
defer t.mu.Unlock()
@ -180,17 +190,17 @@ func (t *Table) GetNodes(id storj.NodeID) (nodes []*pb.Node, ok bool) {
}
// GetBucketIds returns a storage.Keys type of bucket IDs in the Kademlia instance
func (t *Table) GetBucketIds() (storage.Keys, error) {
func (t *Table) GetBucketIds(context.Context) (storage.Keys, error) {
panic("TODO")
}
// SetBucketTimestamp records the time of the last node lookup for a bucket
func (t *Table) SetBucketTimestamp(id []byte, now time.Time) error {
func (t *Table) SetBucketTimestamp(context.Context, []byte, time.Time) error {
panic("TODO")
}
// GetBucketTimestamp retrieves time of the last node lookup for a bucket
func (t *Table) GetBucketTimestamp(id []byte) (time.Time, error) {
func (t *Table) GetBucketTimestamp(context.Context, []byte) (time.Time, error) {
panic("TODO")
}
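
Editor's note: SetBucketTimestamp and GetBucketTimestamp are stubbed with panic("TODO") here because testrouting only needs to satisfy the interface; per the doc comments above they record and retrieve the time of the last node lookup for a bucket, presumably so stale buckets can be refreshed. The following is a hypothetical sketch of how a caller might combine the three context-threaded methods, with storage.Keys simplified to [][]byte and the refresh callback standing in for a real Kademlia lookup; none of this is code from the commit.

package example

import (
	"context"
	"time"
)

// bucketTimestamps is a narrow local interface mirroring the context-threaded
// signatures above, with storage.Keys simplified to [][]byte so the sketch
// stays self-contained.
type bucketTimestamps interface {
	GetBucketIds(ctx context.Context) ([][]byte, error)
	GetBucketTimestamp(ctx context.Context, id []byte) (time.Time, error)
	SetBucketTimestamp(ctx context.Context, id []byte, now time.Time) error
}

// refreshStaleBuckets re-runs a lookup for every bucket whose last recorded
// lookup is older than interval, then stores the new lookup time. The refresh
// callback is a hypothetical stand-in for an actual lookup routine.
func refreshStaleBuckets(ctx context.Context, table bucketTimestamps, interval time.Duration, refresh func(context.Context, []byte) error) error {
	ids, err := table.GetBucketIds(ctx)
	if err != nil {
		return err
	}
	now := time.Now().UTC()
	for _, id := range ids {
		ts, err := table.GetBucketTimestamp(ctx, id)
		if err != nil {
			return err
		}
		if now.Sub(ts) < interval {
			continue // looked up recently enough; nothing to do
		}
		if err := refresh(ctx, id); err != nil {
			return err
		}
		if err := table.SetBucketTimestamp(ctx, id, now); err != nil {
			return err
		}
	}
	return nil
}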

View File

@ -3,9 +3,7 @@
package testrouting
import (
"storj.io/storj/pkg/storj"
)
import "storj.io/storj/pkg/storj"
type nodeDataDistanceSorter struct {
self storj.NodeID