pkg/kademlia: simplify code (#958)
This commit is contained in:
parent
a2fa5c4c5a
commit
26c2564bd8
@ -201,7 +201,7 @@ func TestRefresh(t *testing.T) {
|
||||
//turn back time for only bucket
|
||||
rt := k.routingTable
|
||||
now := time.Now().UTC()
|
||||
bID := rt.createFirstBucketID() //always exists
|
||||
bID := firstBucketID //always exists
|
||||
err := rt.SetBucketTimestamp(bID[:], now.Add(-2*time.Hour))
|
||||
assert.NoError(t, err)
|
||||
//refresh should call FindNode, updating the time
|
||||
|
@ -231,12 +231,3 @@ func (queue *discoveryQueue) Len() int {
|
||||
|
||||
return len(queue.items)
|
||||
}
|
||||
|
||||
// xorNodeID returns the xor of each byte in NodeID
|
||||
func xorNodeID(a, b storj.NodeID) storj.NodeID {
|
||||
r := storj.NodeID{}
|
||||
for i, av := range a {
|
||||
r[i] = av ^ b[i]
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
@ -10,16 +10,18 @@ import (
|
||||
|
||||
"storj.io/storj/internal/teststorj"
|
||||
"storj.io/storj/pkg/pb"
|
||||
"storj.io/storj/pkg/storj"
|
||||
)
|
||||
|
||||
func TestAddToReplacementCache(t *testing.T) {
|
||||
rt, cleanup := createRoutingTable(t, teststorj.NodeIDFromBytes([]byte{244, 255}))
|
||||
rt, cleanup := createRoutingTable(t, storj.NodeID{244, 255})
|
||||
defer cleanup()
|
||||
kadBucketID := keyToBucketID(teststorj.NodeIDFromBytes([]byte{255, 255}).Bytes())
|
||||
|
||||
kadBucketID := bucketID{255, 255}
|
||||
node1 := teststorj.MockNode(string([]byte{233, 255}))
|
||||
rt.addToReplacementCache(kadBucketID, node1)
|
||||
assert.Equal(t, []*pb.Node{node1}, rt.replacementCache[kadBucketID])
|
||||
kadBucketID2 := keyToBucketID(teststorj.NodeIDFromBytes([]byte{127, 255}).Bytes())
|
||||
kadBucketID2 := bucketID{127, 255}
|
||||
node2 := teststorj.MockNode(string([]byte{100, 255}))
|
||||
node3 := teststorj.MockNode(string([]byte{90, 255}))
|
||||
node4 := teststorj.MockNode(string([]byte{80, 255}))
|
||||
|
@ -32,6 +32,20 @@ var RoutingErr = errs.Class("routing table error")
|
||||
// Bucket IDs exist in the same address space as node IDs
|
||||
type bucketID [len(storj.NodeID{})]byte
|
||||
|
||||
var firstBucketID = bucketID{
|
||||
0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF,
|
||||
|
||||
0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF,
|
||||
}
|
||||
|
||||
var emptyBucketID = bucketID{}
|
||||
|
||||
// RoutingTable implements the RoutingTable interface
|
||||
type RoutingTable struct {
|
||||
log *zap.Logger
|
||||
@ -130,14 +144,14 @@ func (rt *RoutingTable) FindNear(id storj.NodeID, limit int) (nodes []*pb.Node,
|
||||
if err != nil {
|
||||
return nodes, RoutingErr.New("could not get node ids %s", err)
|
||||
}
|
||||
sortByXOR(nodeIDsKeys, id.Bytes())
|
||||
if len(nodeIDsKeys) >= limit {
|
||||
nodeIDsKeys = nodeIDsKeys[:limit]
|
||||
}
|
||||
nodeIDs, err := storj.NodeIDsFromBytes(nodeIDsKeys.ByteSlices())
|
||||
if err != nil {
|
||||
return nodes, RoutingErr.Wrap(err)
|
||||
}
|
||||
sortByXOR(nodeIDs, id)
|
||||
if len(nodeIDs) >= limit {
|
||||
nodeIDs = nodeIDs[:limit]
|
||||
}
|
||||
|
||||
nodes, err = rt.getNodesFromIDsBytes(nodeIDs)
|
||||
if err != nil {
|
||||
|
@ -6,11 +6,9 @@ package kademlia
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
"storj.io/storj/pkg/pb"
|
||||
"storj.io/storj/pkg/storj"
|
||||
@ -23,10 +21,9 @@ import (
|
||||
func (rt *RoutingTable) addNode(node *pb.Node) (bool, error) {
|
||||
rt.mutex.Lock()
|
||||
defer rt.mutex.Unlock()
|
||||
nodeIDBytes := node.Id.Bytes()
|
||||
|
||||
if bytes.Equal(nodeIDBytes, rt.self.Id.Bytes()) {
|
||||
err := rt.createOrUpdateKBucket(rt.createFirstBucketID(), time.Now())
|
||||
if node.Id == rt.self.Id {
|
||||
err := rt.createOrUpdateKBucket(firstBucketID, time.Now())
|
||||
if err != nil {
|
||||
return false, RoutingErr.New("could not create initial K bucket: %s", err)
|
||||
}
|
||||
@ -177,89 +174,14 @@ func (rt *RoutingTable) getKBucketID(nodeID storj.NodeID) (bucketID, error) {
|
||||
return bucketID{}, RoutingErr.New("could not find k bucket")
|
||||
}
|
||||
|
||||
// compareByXor compares left, right xorred by reference
|
||||
func compareByXor(left, right, reference storage.Key) int {
|
||||
n := len(reference)
|
||||
if n > len(left) {
|
||||
n = len(left)
|
||||
}
|
||||
if n > len(right) {
|
||||
n = len(right)
|
||||
}
|
||||
left = left[:n]
|
||||
right = right[:n]
|
||||
reference = reference[:n]
|
||||
|
||||
for i, r := range reference {
|
||||
a, b := left[i]^r, right[i]^r
|
||||
if a != b {
|
||||
if a < b {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func sortByXOR(nodeIDs storage.Keys, ref storage.Key) {
|
||||
sort.Slice(nodeIDs, func(i, k int) bool {
|
||||
return compareByXor(nodeIDs[i], nodeIDs[k], ref) < 0
|
||||
})
|
||||
}
|
||||
|
||||
func nodeIDsToKeys(ids storj.NodeIDList) (nodeIDKeys storage.Keys) {
|
||||
for _, n := range ids {
|
||||
nodeIDKeys = append(nodeIDKeys, n.Bytes())
|
||||
}
|
||||
return nodeIDKeys
|
||||
}
|
||||
|
||||
func keysToNodeIDs(keys storage.Keys) (ids storj.NodeIDList, err error) {
|
||||
var idErrs []error
|
||||
for _, k := range keys {
|
||||
id, err := storj.NodeIDFromBytes(k[:])
|
||||
if err != nil {
|
||||
idErrs = append(idErrs, err)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := errs.Combine(idErrs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func keyToBucketID(key storage.Key) (bID bucketID) {
|
||||
copy(bID[:], key)
|
||||
return bID
|
||||
}
|
||||
|
||||
// determineFurthestIDWithinK: helper, determines the furthest node within the k closest to local node
|
||||
func (rt *RoutingTable) determineFurthestIDWithinK(nodeIDs storj.NodeIDList) (storj.NodeID, error) {
|
||||
nodeIDKeys := nodeIDsToKeys(nodeIDs)
|
||||
sortByXOR(nodeIDKeys, rt.self.Id.Bytes())
|
||||
func (rt *RoutingTable) determineFurthestIDWithinK(nodeIDs storj.NodeIDList) storj.NodeID {
|
||||
nodeIDs = cloneNodeIDs(nodeIDs)
|
||||
sortByXOR(nodeIDs, rt.self.Id)
|
||||
if len(nodeIDs) < rt.bucketSize+1 { //adding 1 since we're not including local node in closest k
|
||||
return storj.NodeIDFromBytes(nodeIDKeys[len(nodeIDKeys)-1])
|
||||
return nodeIDs[len(nodeIDs)-1]
|
||||
}
|
||||
return storj.NodeIDFromBytes(nodeIDKeys[rt.bucketSize])
|
||||
}
|
||||
|
||||
// xorTwoIds: helper, finds the xor distance between two byte slices
|
||||
func xorTwoIds(id, comparisonID []byte) []byte {
|
||||
var xorArr []byte
|
||||
s := len(id)
|
||||
if s > len(comparisonID) {
|
||||
s = len(comparisonID)
|
||||
}
|
||||
|
||||
for i := 0; i < s; i++ {
|
||||
xor := id[i] ^ comparisonID[i]
|
||||
xorArr = append(xorArr, xor)
|
||||
}
|
||||
return xorArr
|
||||
return nodeIDs[rt.bucketSize]
|
||||
}
|
||||
|
||||
// nodeIsWithinNearestK: helper, returns true if the node in question is within the nearest k from local node
|
||||
@ -276,16 +198,11 @@ func (rt *RoutingTable) nodeIsWithinNearestK(nodeID storj.NodeID) (bool, error)
|
||||
if err != nil {
|
||||
return false, RoutingErr.Wrap(err)
|
||||
}
|
||||
furthestIDWithinK, err := rt.determineFurthestIDWithinK(nodeIDs)
|
||||
if err != nil {
|
||||
return false, RoutingErr.New("could not determine furthest id within k: %s", err)
|
||||
}
|
||||
existingXor := xorTwoIds(furthestIDWithinK.Bytes(), rt.self.Id.Bytes())
|
||||
newXor := xorTwoIds(nodeID.Bytes(), rt.self.Id.Bytes())
|
||||
if bytes.Compare(newXor, existingXor) < 0 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
|
||||
furthestIDWithinK := rt.determineFurthestIDWithinK(nodeIDs)
|
||||
existingXor := xorNodeID(furthestIDWithinK, rt.self.Id)
|
||||
newXor := xorNodeID(nodeID, rt.self.Id)
|
||||
return newXor.Less(existingXor), nil
|
||||
}
|
||||
|
||||
// kadBucketContainsLocalNode returns true if the kbucket in question contains the local node
|
||||
@ -294,7 +211,7 @@ func (rt *RoutingTable) kadBucketContainsLocalNode(queryID bucketID) (bool, erro
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return bytes.Equal(queryID[:], bID[:]), nil
|
||||
return queryID == bID, nil
|
||||
}
|
||||
|
||||
// kadBucketHasRoom: helper, returns true if it has fewer than k nodes
|
||||
@ -396,15 +313,6 @@ func (rt *RoutingTable) getKBucketRange(bID bucketID) ([]bucketID, error) {
|
||||
return coords, nil
|
||||
}
|
||||
|
||||
// createFirstBucketID creates byte slice representing 11..11
|
||||
// bucket IDs are the highest address which that bucket contains
|
||||
func (rt *RoutingTable) createFirstBucketID() (id bucketID) {
|
||||
for i := 0; i < len(id); i++ {
|
||||
id[i] = 255
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// determineLeafDepth determines the level of the bucket id in question.
|
||||
// Eg level 0 means there is only 1 bucket, level 1 means the bucket has been split once, and so on
|
||||
func (rt *RoutingTable) determineLeafDepth(bID bucketID) (int, error) {
|
||||
@ -413,57 +321,13 @@ func (rt *RoutingTable) determineLeafDepth(bID bucketID) (int, error) {
|
||||
return -1, RoutingErr.New("could not get k bucket range %s", err)
|
||||
}
|
||||
smaller := bucketRange[0]
|
||||
diffBit, err := rt.determineDifferingBitIndex(bID, smaller)
|
||||
diffBit, err := determineDifferingBitIndex(bID, smaller)
|
||||
if err != nil {
|
||||
return diffBit + 1, RoutingErr.New("could not determine differing bit %s", err)
|
||||
}
|
||||
return diffBit + 1, nil
|
||||
}
|
||||
|
||||
// determineDifferingBitIndex: helper, returns the last bit differs starting from prefix to suffix
|
||||
func (rt *RoutingTable) determineDifferingBitIndex(bID, comparisonID bucketID) (int, error) {
|
||||
if bytes.Equal(bID[:], comparisonID[:]) {
|
||||
return -2, RoutingErr.New("compared two equivalent k bucket ids")
|
||||
}
|
||||
emptyBID := bucketID{}
|
||||
if bytes.Equal(comparisonID[:], emptyBID[:]) {
|
||||
comparisonID = rt.createFirstBucketID()
|
||||
}
|
||||
|
||||
var differingByteIndex int
|
||||
var differingByteXor byte
|
||||
xorArr := xorTwoIds(bID[:], comparisonID[:])
|
||||
|
||||
firstBID := rt.createFirstBucketID()
|
||||
if bytes.Equal(xorArr, firstBID[:]) {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
for j, v := range xorArr {
|
||||
if v != byte(0) {
|
||||
differingByteIndex = j
|
||||
differingByteXor = v
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
h := 0
|
||||
for ; h < 8; h++ {
|
||||
toggle := byte(1 << uint(h))
|
||||
tempXor := differingByteXor
|
||||
tempXor ^= toggle
|
||||
if tempXor < differingByteXor {
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
bitInByteIndex := 7 - h
|
||||
byteIndex := differingByteIndex
|
||||
bitIndex := byteIndex*8 + bitInByteIndex
|
||||
|
||||
return bitIndex, nil
|
||||
}
|
||||
|
||||
// splitBucket: helper, returns the smaller of the two new bucket ids
|
||||
// the original bucket id becomes the greater of the 2 new
|
||||
func (rt *RoutingTable) splitBucket(bID bucketID, depth int) bucketID {
|
||||
|
@ -6,14 +6,12 @@ package kademlia
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeebo/errs"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"storj.io/storj/internal/teststorj"
|
||||
@ -246,7 +244,7 @@ func TestUpdateNode(t *testing.T) {
|
||||
func TestRemoveNode(t *testing.T) {
|
||||
rt, cleanup := createRoutingTable(t, teststorj.NodeIDFromString("AA"))
|
||||
defer cleanup()
|
||||
kadBucketID := rt.createFirstBucketID()
|
||||
kadBucketID := firstBucketID
|
||||
node := teststorj.MockNode("BB")
|
||||
ok, err := rt.addNode(node)
|
||||
assert.True(t, ok)
|
||||
@ -272,19 +270,19 @@ func TestRemoveNode(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateOrUpdateKBucket(t *testing.T) {
|
||||
id := teststorj.NodeIDFromBytes([]byte{255, 255})
|
||||
id := bucketID{255, 255}
|
||||
rt, cleanup := createRoutingTable(t, storj.NodeID{})
|
||||
defer cleanup()
|
||||
err := rt.createOrUpdateKBucket(keyToBucketID(id.Bytes()), time.Now())
|
||||
err := rt.createOrUpdateKBucket(id, time.Now())
|
||||
assert.NoError(t, err)
|
||||
val, e := rt.kadBucketDB.Get(id.Bytes())
|
||||
val, e := rt.kadBucketDB.Get(id[:])
|
||||
assert.NotNil(t, val)
|
||||
assert.NoError(t, e)
|
||||
|
||||
}
|
||||
|
||||
func TestGetKBucketID(t *testing.T) {
|
||||
kadIDA := keyToBucketID([]byte{255, 255})
|
||||
kadIDA := bucketID{255, 255}
|
||||
nodeIDA := teststorj.NodeIDFromString("AA")
|
||||
rt, cleanup := createRoutingTable(t, nodeIDA)
|
||||
defer cleanup()
|
||||
@ -293,60 +291,8 @@ func TestGetKBucketID(t *testing.T) {
|
||||
assert.Equal(t, kadIDA[:2], keyA[:2])
|
||||
}
|
||||
|
||||
func TestXorTwoIds(t *testing.T) {
|
||||
x := xorTwoIds([]byte{191}, []byte{159})
|
||||
assert.Equal(t, []byte{32}, x) //00100000
|
||||
}
|
||||
|
||||
func TestSortByXOR(t *testing.T) {
|
||||
node1 := teststorj.NodeIDFromBytes([]byte{127, 255}) //xor 0
|
||||
rt, cleanup := createRoutingTable(t, node1)
|
||||
defer cleanup()
|
||||
node2 := teststorj.NodeIDFromBytes([]byte{143, 255}) //xor 240
|
||||
assert.NoError(t, rt.nodeBucketDB.Put(node2.Bytes(), []byte("")))
|
||||
node3 := teststorj.NodeIDFromBytes([]byte{255, 255}) //xor 128
|
||||
assert.NoError(t, rt.nodeBucketDB.Put(node3.Bytes(), []byte("")))
|
||||
node4 := teststorj.NodeIDFromBytes([]byte{191, 255}) //xor 192
|
||||
assert.NoError(t, rt.nodeBucketDB.Put(node4.Bytes(), []byte("")))
|
||||
node5 := teststorj.NodeIDFromBytes([]byte{133, 255}) //xor 250
|
||||
assert.NoError(t, rt.nodeBucketDB.Put(node5.Bytes(), []byte("")))
|
||||
nodes, err := rt.nodeBucketDB.List(nil, 0)
|
||||
assert.NoError(t, err)
|
||||
expectedNodes := storage.Keys{node1.Bytes(), node5.Bytes(), node2.Bytes(), node4.Bytes(), node3.Bytes()}
|
||||
assert.Equal(t, expectedNodes, nodes)
|
||||
sortByXOR(nodes, node1.Bytes())
|
||||
expectedSorted := storage.Keys{node1.Bytes(), node3.Bytes(), node4.Bytes(), node2.Bytes(), node5.Bytes()}
|
||||
assert.Equal(t, expectedSorted, nodes)
|
||||
nodes, err = rt.nodeBucketDB.List(nil, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedNodes, nodes)
|
||||
}
|
||||
|
||||
func BenchmarkSortByXOR(b *testing.B) {
|
||||
nodes := []storage.Key{}
|
||||
|
||||
newNodeID := func() storage.Key {
|
||||
id := make(storage.Key, 32)
|
||||
rand.Read(id[:])
|
||||
return id
|
||||
}
|
||||
|
||||
for k := 0; k < 1000; k++ {
|
||||
nodes = append(nodes, newNodeID())
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for m := 0; m < b.N; m++ {
|
||||
rand.Shuffle(len(nodes), func(i, k int) {
|
||||
nodes[i], nodes[k] = nodes[k], nodes[i]
|
||||
})
|
||||
|
||||
sortByXOR(nodes, newNodeID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineFurthestIDWithinK(t *testing.T) {
|
||||
rt, cleanup := createRoutingTable(t, teststorj.NodeIDFromBytes([]byte{127, 255}))
|
||||
rt, cleanup := createRoutingTable(t, storj.NodeID{127, 255})
|
||||
defer cleanup()
|
||||
cases := []struct {
|
||||
testID string
|
||||
@ -379,8 +325,7 @@ func TestDetermineFurthestIDWithinK(t *testing.T) {
|
||||
assert.NoError(t, rt.nodeBucketDB.Put(teststorj.NodeIDFromBytes(c.nodeID).Bytes(), []byte("")))
|
||||
nodes, err := rt.nodeBucketDB.List(nil, 0)
|
||||
assert.NoError(t, err)
|
||||
furthest, err := rt.determineFurthestIDWithinK(teststorj.NodeIDsFromBytes(nodes.ByteSlices()...))
|
||||
assert.NoError(t, err)
|
||||
furthest := rt.determineFurthestIDWithinK(teststorj.NodeIDsFromBytes(nodes.ByteSlices()...))
|
||||
fmt.Println(furthest.Bytes())
|
||||
assert.Equal(t, c.expectedFurthest, furthest[:2])
|
||||
})
|
||||
@ -388,50 +333,52 @@ func TestDetermineFurthestIDWithinK(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNodeIsWithinNearestK(t *testing.T) {
|
||||
rt, cleanup := createRoutingTable(t, teststorj.NodeIDFromBytes([]byte{127, 255}))
|
||||
rt, cleanup := createRoutingTable(t, storj.NodeID{127, 255})
|
||||
defer cleanup()
|
||||
rt.bucketSize = 2
|
||||
|
||||
cases := []struct {
|
||||
testID string
|
||||
nodeID []byte
|
||||
nodeID storj.NodeID
|
||||
closest bool
|
||||
}{
|
||||
{testID: "A",
|
||||
nodeID: []byte{127, 255},
|
||||
nodeID: storj.NodeID{127, 255},
|
||||
closest: true,
|
||||
},
|
||||
{testID: "B",
|
||||
nodeID: []byte{143, 255},
|
||||
nodeID: storj.NodeID{143, 255},
|
||||
closest: true,
|
||||
},
|
||||
{testID: "C",
|
||||
nodeID: []byte{255, 255},
|
||||
nodeID: storj.NodeID{255, 255},
|
||||
closest: true,
|
||||
},
|
||||
{testID: "D",
|
||||
nodeID: []byte{191, 255},
|
||||
nodeID: storj.NodeID{191, 255},
|
||||
closest: true,
|
||||
},
|
||||
{testID: "E",
|
||||
nodeID: []byte{133, 255},
|
||||
nodeID: storj.NodeID{133, 255},
|
||||
closest: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.testID, func(t *testing.T) {
|
||||
result, err := rt.nodeIsWithinNearestK(teststorj.NodeIDFromBytes(c.nodeID))
|
||||
result, err := rt.nodeIsWithinNearestK(c.nodeID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.closest, result)
|
||||
assert.NoError(t, rt.nodeBucketDB.Put(teststorj.NodeIDFromBytes(c.nodeID).Bytes(), []byte("")))
|
||||
assert.NoError(t, rt.nodeBucketDB.Put(c.nodeID.Bytes(), []byte("")))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKadBucketContainsLocalNode(t *testing.T) {
|
||||
nodeIDA := teststorj.NodeIDFromBytes([]byte{183, 255}) //[10110111, 1111111]
|
||||
nodeIDA := storj.NodeID{183, 255} //[10110111, 1111111]
|
||||
rt, cleanup := createRoutingTable(t, nodeIDA)
|
||||
defer cleanup()
|
||||
kadIDA := rt.createFirstBucketID()
|
||||
kadIDA := firstBucketID
|
||||
var kadIDB bucketID
|
||||
copy(kadIDB[:], kadIDA[:])
|
||||
kadIDB[0] = 127
|
||||
@ -447,15 +394,15 @@ func TestKadBucketContainsLocalNode(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestKadBucketHasRoom(t *testing.T) {
|
||||
node1 := teststorj.NodeIDFromBytes([]byte{255, 255})
|
||||
node1 := storj.NodeID{255, 255}
|
||||
rt, cleanup := createRoutingTable(t, node1)
|
||||
defer cleanup()
|
||||
kadIDA := rt.createFirstBucketID()
|
||||
node2 := teststorj.NodeIDFromBytes([]byte{191, 255})
|
||||
node3 := teststorj.NodeIDFromBytes([]byte{127, 255})
|
||||
node4 := teststorj.NodeIDFromBytes([]byte{63, 255})
|
||||
node5 := teststorj.NodeIDFromBytes([]byte{159, 255})
|
||||
node6 := teststorj.NodeIDFromBytes([]byte{0, 127})
|
||||
kadIDA := firstBucketID
|
||||
node2 := storj.NodeID{191, 255}
|
||||
node3 := storj.NodeID{127, 255}
|
||||
node4 := storj.NodeID{63, 255}
|
||||
node5 := storj.NodeID{159, 255}
|
||||
node6 := storj.NodeID{0, 127}
|
||||
resultA, err := rt.kadBucketHasRoom(kadIDA)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, resultA)
|
||||
@ -470,18 +417,18 @@ func TestKadBucketHasRoom(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetNodeIDsWithinKBucket(t *testing.T) {
|
||||
nodeIDA := teststorj.NodeIDFromBytes([]byte{183, 255}) //[10110111, 1111111]
|
||||
nodeIDA := storj.NodeID{183, 255} //[10110111, 1111111]
|
||||
rt, cleanup := createRoutingTable(t, nodeIDA)
|
||||
defer cleanup()
|
||||
kadIDA := rt.createFirstBucketID()
|
||||
kadIDA := firstBucketID
|
||||
var kadIDB bucketID
|
||||
copy(kadIDB[:], kadIDA[:])
|
||||
kadIDB[0] = 127
|
||||
now := time.Now()
|
||||
assert.NoError(t, rt.createOrUpdateKBucket(kadIDB, now))
|
||||
|
||||
nodeIDB := teststorj.NodeIDFromBytes([]byte{111, 255}) //[01101111, 1111111]
|
||||
nodeIDC := teststorj.NodeIDFromBytes([]byte{47, 255}) //[00101111, 1111111]
|
||||
nodeIDB := storj.NodeID{111, 255} //[01101111, 1111111]
|
||||
nodeIDC := storj.NodeID{47, 255} //[00101111, 1111111]
|
||||
|
||||
assert.NoError(t, rt.nodeBucketDB.Put(nodeIDB.Bytes(), []byte("")))
|
||||
assert.NoError(t, rt.nodeBucketDB.Put(nodeIDC.Bytes(), []byte("")))
|
||||
@ -567,7 +514,7 @@ func TestUnmarshalNodes(t *testing.T) {
|
||||
func TestGetUnmarshaledNodesFromBucket(t *testing.T) {
|
||||
nodeA := teststorj.MockNode("AA")
|
||||
rt, cleanup := createRoutingTable(t, nodeA.Id)
|
||||
bucketID := rt.createFirstBucketID()
|
||||
bucketID := firstBucketID
|
||||
defer cleanup()
|
||||
nodeB := teststorj.MockNode("BB")
|
||||
nodeC := teststorj.MockNode("CC")
|
||||
@ -587,9 +534,9 @@ func TestGetUnmarshaledNodesFromBucket(t *testing.T) {
|
||||
func TestGetKBucketRange(t *testing.T) {
|
||||
rt, cleanup := createRoutingTable(t, storj.NodeID{})
|
||||
defer cleanup()
|
||||
idA := teststorj.NodeIDFromBytes([]byte{255, 255})
|
||||
idB := teststorj.NodeIDFromBytes([]byte{127, 255})
|
||||
idC := teststorj.NodeIDFromBytes([]byte{63, 255})
|
||||
idA := storj.NodeID{255, 255}
|
||||
idB := storj.NodeID{127, 255}
|
||||
idC := storj.NodeID{63, 255}
|
||||
assert.NoError(t, rt.kadBucketDB.Put(idA.Bytes(), []byte("")))
|
||||
assert.NoError(t, rt.kadBucketDB.Put(idB.Bytes(), []byte("")))
|
||||
assert.NoError(t, rt.kadBucketDB.Put(idC.Bytes(), []byte("")))
|
||||
@ -622,14 +569,6 @@ func TestGetKBucketRange(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFirstBucketID(t *testing.T) {
|
||||
rt, cleanup := createRoutingTable(t, storj.NodeID{})
|
||||
defer cleanup()
|
||||
x := rt.createFirstBucketID()
|
||||
expected := []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}
|
||||
assert.Equal(t, x[:], expected)
|
||||
}
|
||||
|
||||
func TestBucketIDZeroValue(t *testing.T) {
|
||||
// rt, cleanup := createRoutingTable(t, storj.NodeID{})
|
||||
// defer cleanup()
|
||||
@ -641,9 +580,10 @@ func TestBucketIDZeroValue(t *testing.T) {
|
||||
func TestDetermineLeafDepth(t *testing.T) {
|
||||
rt, cleanup := createRoutingTable(t, storj.NodeID{})
|
||||
defer cleanup()
|
||||
idA := teststorj.NodeIDFromBytes([]byte{255, 255})
|
||||
idB := teststorj.NodeIDFromBytes([]byte{127, 255})
|
||||
idC := teststorj.NodeIDFromBytes([]byte{63, 255})
|
||||
idA, idB, idC := storj.NodeID(firstBucketID), storj.NodeID(firstBucketID), storj.NodeID(firstBucketID)
|
||||
idA[0] = 255
|
||||
idB[0] = 127
|
||||
idC[0] = 63
|
||||
|
||||
cases := []struct {
|
||||
testID string
|
||||
@ -689,117 +629,13 @@ func TestDetermineLeafDepth(t *testing.T) {
|
||||
for _, c := range cases {
|
||||
t.Run(c.testID, func(t *testing.T) {
|
||||
c.addNode()
|
||||
d, err := rt.determineLeafDepth(keyToBucketID(c.id.Bytes()))
|
||||
d, err := rt.determineLeafDepth(bucketID(c.id))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.depth, d)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func padBucketID(b []byte, p byte) (bID bucketID) {
|
||||
for i := range bID {
|
||||
if len(b) > i {
|
||||
bID[i] = b[i]
|
||||
continue
|
||||
}
|
||||
bID[i] = p
|
||||
}
|
||||
return bID
|
||||
}
|
||||
|
||||
func TestDetermineDifferingBitIndex(t *testing.T) {
|
||||
rt, cleanup := createRoutingTable(t, storj.NodeID{})
|
||||
defer cleanup()
|
||||
cases := []struct {
|
||||
testID string
|
||||
bucketID bucketID
|
||||
key bucketID
|
||||
expected int
|
||||
err *errs.Class
|
||||
}{
|
||||
{testID: "A",
|
||||
bucketID: padBucketID([]byte{191, 255}, 255),
|
||||
key: padBucketID([]byte{255, 255}, 255),
|
||||
expected: 1,
|
||||
err: nil,
|
||||
},
|
||||
{testID: "B",
|
||||
bucketID: padBucketID([]byte{255, 255}, 255),
|
||||
key: padBucketID([]byte{191, 255}, 255),
|
||||
expected: 1,
|
||||
err: nil,
|
||||
},
|
||||
{testID: "C",
|
||||
bucketID: padBucketID([]byte{95, 255}, 255),
|
||||
key: padBucketID([]byte{127, 255}, 255),
|
||||
expected: 2,
|
||||
err: nil,
|
||||
},
|
||||
{testID: "D",
|
||||
bucketID: padBucketID([]byte{95, 255}, 255),
|
||||
key: padBucketID([]byte{79, 255}, 255),
|
||||
expected: 3,
|
||||
err: nil,
|
||||
},
|
||||
{testID: "E",
|
||||
bucketID: padBucketID([]byte{95, 255}, 255),
|
||||
key: padBucketID([]byte{63, 255}, 255),
|
||||
expected: 2,
|
||||
err: nil,
|
||||
},
|
||||
{testID: "F",
|
||||
bucketID: padBucketID([]byte{95, 255}, 255),
|
||||
key: padBucketID([]byte{79, 255}, 255),
|
||||
expected: 3,
|
||||
err: nil,
|
||||
},
|
||||
{testID: "G",
|
||||
bucketID: padBucketID([]byte{255, 255}, 255),
|
||||
key: padBucketID([]byte{255, 255}, 255),
|
||||
expected: -2,
|
||||
err: &RoutingErr,
|
||||
},
|
||||
{testID: "H",
|
||||
bucketID: padBucketID([]byte{255, 255}, 255),
|
||||
key: padBucketID([]byte{0, 0}, 0),
|
||||
expected: -1,
|
||||
err: nil,
|
||||
},
|
||||
{testID: "I",
|
||||
bucketID: padBucketID([]byte{127, 255}, 255),
|
||||
key: padBucketID([]byte{0, 0}, 0),
|
||||
expected: 0,
|
||||
err: nil,
|
||||
},
|
||||
{testID: "J",
|
||||
bucketID: padBucketID([]byte{63, 255}, 255),
|
||||
key: padBucketID([]byte{0, 0}, 0),
|
||||
expected: 1,
|
||||
err: nil,
|
||||
},
|
||||
{testID: "K",
|
||||
bucketID: padBucketID([]byte{31, 255}, 255),
|
||||
key: padBucketID([]byte{0, 0}, 0),
|
||||
expected: 2,
|
||||
err: nil,
|
||||
},
|
||||
{testID: "L",
|
||||
bucketID: padBucketID([]byte{95, 255}, 255),
|
||||
key: padBucketID([]byte{63, 255}, 255),
|
||||
expected: 2,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.testID, func(t *testing.T) {
|
||||
diff, err := rt.determineDifferingBitIndex(c.bucketID, c.key)
|
||||
assertErrClass(t, c.err, err)
|
||||
assert.Equal(t, c.expected, diff)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitBucket(t *testing.T) {
|
||||
rt, cleanup := createRoutingTable(t, storj.NodeID{})
|
||||
defer cleanup()
|
||||
@ -847,12 +683,3 @@ func TestSplitBucket(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func assertErrClass(t *testing.T, class *errs.Class, err error) {
|
||||
t.Helper()
|
||||
if class != nil {
|
||||
assert.True(t, class.Has(err))
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
103
pkg/kademlia/utils.go
Normal file
103
pkg/kademlia/utils.go
Normal file
@ -0,0 +1,103 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package kademlia
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
"sort"
|
||||
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
"storj.io/storj/pkg/storj"
|
||||
"storj.io/storj/storage"
|
||||
)
|
||||
|
||||
func cloneNodeIDs(ids storj.NodeIDList) storj.NodeIDList {
|
||||
clone := make(storj.NodeIDList, len(ids))
|
||||
copy(clone, ids)
|
||||
return clone
|
||||
}
|
||||
|
||||
// compareByXor compares left, right xorred by reference
|
||||
func compareByXor(left, right, reference storj.NodeID) int {
|
||||
for i, r := range reference {
|
||||
a, b := left[i]^r, right[i]^r
|
||||
if a != b {
|
||||
if a < b {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func sortByXOR(nodeIDs storj.NodeIDList, ref storj.NodeID) {
|
||||
sort.Slice(nodeIDs, func(i, k int) bool {
|
||||
return compareByXor(nodeIDs[i], nodeIDs[k], ref) < 0
|
||||
})
|
||||
}
|
||||
|
||||
func keysToNodeIDs(keys storage.Keys) (ids storj.NodeIDList, err error) {
|
||||
var idErrs []error
|
||||
for _, k := range keys {
|
||||
id, err := storj.NodeIDFromBytes(k[:])
|
||||
if err != nil {
|
||||
idErrs = append(idErrs, err)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := errs.Combine(idErrs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func keyToBucketID(key storage.Key) (bID bucketID) {
|
||||
copy(bID[:], key)
|
||||
return bID
|
||||
}
|
||||
|
||||
// xorNodeID returns the xor of each byte in NodeID
|
||||
func xorNodeID(a, b storj.NodeID) storj.NodeID {
|
||||
r := storj.NodeID{}
|
||||
for i, av := range a {
|
||||
r[i] = av ^ b[i]
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// xorBucketID returns the xor of each byte in bucketID
|
||||
func xorBucketID(a, b bucketID) bucketID {
|
||||
r := bucketID{}
|
||||
for i, av := range a {
|
||||
r[i] = av ^ b[i]
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// determineDifferingBitIndex: helper, returns the last bit differs starting from prefix to suffix
|
||||
func determineDifferingBitIndex(bID, comparisonID bucketID) (int, error) {
|
||||
if bID == comparisonID {
|
||||
return -2, RoutingErr.New("compared two equivalent k bucket ids")
|
||||
}
|
||||
|
||||
if comparisonID == emptyBucketID {
|
||||
comparisonID = firstBucketID
|
||||
}
|
||||
|
||||
xorID := xorBucketID(bID, comparisonID)
|
||||
if xorID == firstBucketID {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
for i, v := range xorID {
|
||||
if v != 0 {
|
||||
return i*8 + 7 - bits.TrailingZeros8(v), nil
|
||||
}
|
||||
}
|
||||
|
||||
return -1, nil
|
||||
}
|
127
pkg/kademlia/utils_test.go
Normal file
127
pkg/kademlia/utils_test.go
Normal file
@ -0,0 +1,127 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package kademlia
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"storj.io/storj/pkg/storj"
|
||||
)
|
||||
|
||||
func TestSortByXOR(t *testing.T) {
|
||||
n1 := storj.NodeID{127, 255} //xor 0
|
||||
n2 := storj.NodeID{143, 255} //xor 240
|
||||
n3 := storj.NodeID{255, 255} //xor 128
|
||||
n4 := storj.NodeID{191, 255} //xor 192
|
||||
n5 := storj.NodeID{133, 255} //xor 250
|
||||
unsorted := storj.NodeIDList{n1, n5, n2, n4, n3}
|
||||
sortByXOR(unsorted, n1)
|
||||
sorted := storj.NodeIDList{n1, n3, n4, n2, n5}
|
||||
assert.Equal(t, sorted, unsorted)
|
||||
}
|
||||
|
||||
func BenchmarkSortByXOR(b *testing.B) {
|
||||
newNodeID := func() storj.NodeID {
|
||||
var id storj.NodeID
|
||||
rand.Read(id[:])
|
||||
return id
|
||||
}
|
||||
|
||||
nodes := []storj.NodeID{}
|
||||
for k := 0; k < 1000; k++ {
|
||||
nodes = append(nodes, newNodeID())
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for m := 0; m < b.N; m++ {
|
||||
rand.Shuffle(len(nodes), func(i, k int) {
|
||||
nodes[i], nodes[k] = nodes[k], nodes[i]
|
||||
})
|
||||
sortByXOR(nodes, newNodeID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineDifferingBitIndex(t *testing.T) {
|
||||
filledID := func(a byte) bucketID {
|
||||
id := firstBucketID
|
||||
id[0] = a
|
||||
return id
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
bucketID bucketID
|
||||
key bucketID
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
bucketID: filledID(191),
|
||||
key: filledID(255),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
bucketID: filledID(255),
|
||||
key: filledID(191),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
bucketID: filledID(95),
|
||||
key: filledID(127),
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
bucketID: filledID(95),
|
||||
key: filledID(79),
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
bucketID: filledID(95),
|
||||
key: filledID(63),
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
bucketID: filledID(95),
|
||||
key: filledID(79),
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
bucketID: filledID(255),
|
||||
key: bucketID{},
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
bucketID: filledID(127),
|
||||
key: bucketID{},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
bucketID: filledID(63),
|
||||
key: bucketID{},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
bucketID: filledID(31),
|
||||
key: bucketID{},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
bucketID: filledID(95),
|
||||
key: filledID(63),
|
||||
expected: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
t.Logf("#%d. bucketID:%v key:%v\n", i, c.bucketID, c.key)
|
||||
diff, err := determineDifferingBitIndex(c.bucketID, c.key)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.expected, diff)
|
||||
}
|
||||
|
||||
diff, err := determineDifferingBitIndex(filledID(255), filledID(255))
|
||||
assert.True(t, RoutingErr.Has(err))
|
||||
assert.Equal(t, diff, -2)
|
||||
}
|
Loading…
Reference in New Issue
Block a user