pkg/kademlia: simplify code (#958)

This commit is contained in:
Egon Elbre 2019-01-02 20:57:11 +02:00 committed by GitHub
parent a2fa5c4c5a
commit 26c2564bd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 308 additions and 380 deletions

View File

@ -201,7 +201,7 @@ func TestRefresh(t *testing.T) {
//turn back time for only bucket //turn back time for only bucket
rt := k.routingTable rt := k.routingTable
now := time.Now().UTC() now := time.Now().UTC()
bID := rt.createFirstBucketID() //always exists bID := firstBucketID //always exists
err := rt.SetBucketTimestamp(bID[:], now.Add(-2*time.Hour)) err := rt.SetBucketTimestamp(bID[:], now.Add(-2*time.Hour))
assert.NoError(t, err) assert.NoError(t, err)
//refresh should call FindNode, updating the time //refresh should call FindNode, updating the time

View File

@ -231,12 +231,3 @@ func (queue *discoveryQueue) Len() int {
return len(queue.items) return len(queue.items)
} }
// xorNodeID returns the xor of each byte in NodeID
func xorNodeID(a, b storj.NodeID) storj.NodeID {
r := storj.NodeID{}
for i, av := range a {
r[i] = av ^ b[i]
}
return r
}

View File

@ -10,16 +10,18 @@ import (
"storj.io/storj/internal/teststorj" "storj.io/storj/internal/teststorj"
"storj.io/storj/pkg/pb" "storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
) )
func TestAddToReplacementCache(t *testing.T) { func TestAddToReplacementCache(t *testing.T) {
rt, cleanup := createRoutingTable(t, teststorj.NodeIDFromBytes([]byte{244, 255})) rt, cleanup := createRoutingTable(t, storj.NodeID{244, 255})
defer cleanup() defer cleanup()
kadBucketID := keyToBucketID(teststorj.NodeIDFromBytes([]byte{255, 255}).Bytes())
kadBucketID := bucketID{255, 255}
node1 := teststorj.MockNode(string([]byte{233, 255})) node1 := teststorj.MockNode(string([]byte{233, 255}))
rt.addToReplacementCache(kadBucketID, node1) rt.addToReplacementCache(kadBucketID, node1)
assert.Equal(t, []*pb.Node{node1}, rt.replacementCache[kadBucketID]) assert.Equal(t, []*pb.Node{node1}, rt.replacementCache[kadBucketID])
kadBucketID2 := keyToBucketID(teststorj.NodeIDFromBytes([]byte{127, 255}).Bytes()) kadBucketID2 := bucketID{127, 255}
node2 := teststorj.MockNode(string([]byte{100, 255})) node2 := teststorj.MockNode(string([]byte{100, 255}))
node3 := teststorj.MockNode(string([]byte{90, 255})) node3 := teststorj.MockNode(string([]byte{90, 255}))
node4 := teststorj.MockNode(string([]byte{80, 255})) node4 := teststorj.MockNode(string([]byte{80, 255}))

View File

@ -32,6 +32,20 @@ var RoutingErr = errs.Class("routing table error")
// Bucket IDs exist in the same address space as node IDs // Bucket IDs exist in the same address space as node IDs
type bucketID [len(storj.NodeID{})]byte type bucketID [len(storj.NodeID{})]byte
var firstBucketID = bucketID{
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
}
var emptyBucketID = bucketID{}
// RoutingTable implements the RoutingTable interface // RoutingTable implements the RoutingTable interface
type RoutingTable struct { type RoutingTable struct {
log *zap.Logger log *zap.Logger
@ -130,14 +144,14 @@ func (rt *RoutingTable) FindNear(id storj.NodeID, limit int) (nodes []*pb.Node,
if err != nil { if err != nil {
return nodes, RoutingErr.New("could not get node ids %s", err) return nodes, RoutingErr.New("could not get node ids %s", err)
} }
sortByXOR(nodeIDsKeys, id.Bytes())
if len(nodeIDsKeys) >= limit {
nodeIDsKeys = nodeIDsKeys[:limit]
}
nodeIDs, err := storj.NodeIDsFromBytes(nodeIDsKeys.ByteSlices()) nodeIDs, err := storj.NodeIDsFromBytes(nodeIDsKeys.ByteSlices())
if err != nil { if err != nil {
return nodes, RoutingErr.Wrap(err) return nodes, RoutingErr.Wrap(err)
} }
sortByXOR(nodeIDs, id)
if len(nodeIDs) >= limit {
nodeIDs = nodeIDs[:limit]
}
nodes, err = rt.getNodesFromIDsBytes(nodeIDs) nodes, err = rt.getNodesFromIDsBytes(nodeIDs)
if err != nil { if err != nil {

View File

@ -6,11 +6,9 @@ package kademlia
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"sort"
"time" "time"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"github.com/zeebo/errs"
"storj.io/storj/pkg/pb" "storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj" "storj.io/storj/pkg/storj"
@ -23,10 +21,9 @@ import (
func (rt *RoutingTable) addNode(node *pb.Node) (bool, error) { func (rt *RoutingTable) addNode(node *pb.Node) (bool, error) {
rt.mutex.Lock() rt.mutex.Lock()
defer rt.mutex.Unlock() defer rt.mutex.Unlock()
nodeIDBytes := node.Id.Bytes()
if bytes.Equal(nodeIDBytes, rt.self.Id.Bytes()) { if node.Id == rt.self.Id {
err := rt.createOrUpdateKBucket(rt.createFirstBucketID(), time.Now()) err := rt.createOrUpdateKBucket(firstBucketID, time.Now())
if err != nil { if err != nil {
return false, RoutingErr.New("could not create initial K bucket: %s", err) return false, RoutingErr.New("could not create initial K bucket: %s", err)
} }
@ -177,89 +174,14 @@ func (rt *RoutingTable) getKBucketID(nodeID storj.NodeID) (bucketID, error) {
return bucketID{}, RoutingErr.New("could not find k bucket") return bucketID{}, RoutingErr.New("could not find k bucket")
} }
// compareByXor compares left, right xorred by reference
func compareByXor(left, right, reference storage.Key) int {
n := len(reference)
if n > len(left) {
n = len(left)
}
if n > len(right) {
n = len(right)
}
left = left[:n]
right = right[:n]
reference = reference[:n]
for i, r := range reference {
a, b := left[i]^r, right[i]^r
if a != b {
if a < b {
return -1
}
return 1
}
}
return 0
}
func sortByXOR(nodeIDs storage.Keys, ref storage.Key) {
sort.Slice(nodeIDs, func(i, k int) bool {
return compareByXor(nodeIDs[i], nodeIDs[k], ref) < 0
})
}
func nodeIDsToKeys(ids storj.NodeIDList) (nodeIDKeys storage.Keys) {
for _, n := range ids {
nodeIDKeys = append(nodeIDKeys, n.Bytes())
}
return nodeIDKeys
}
func keysToNodeIDs(keys storage.Keys) (ids storj.NodeIDList, err error) {
var idErrs []error
for _, k := range keys {
id, err := storj.NodeIDFromBytes(k[:])
if err != nil {
idErrs = append(idErrs, err)
}
ids = append(ids, id)
}
if err := errs.Combine(idErrs...); err != nil {
return nil, err
}
return ids, nil
}
func keyToBucketID(key storage.Key) (bID bucketID) {
copy(bID[:], key)
return bID
}
// determineFurthestIDWithinK: helper, determines the furthest node within the k closest to local node // determineFurthestIDWithinK: helper, determines the furthest node within the k closest to local node
func (rt *RoutingTable) determineFurthestIDWithinK(nodeIDs storj.NodeIDList) (storj.NodeID, error) { func (rt *RoutingTable) determineFurthestIDWithinK(nodeIDs storj.NodeIDList) storj.NodeID {
nodeIDKeys := nodeIDsToKeys(nodeIDs) nodeIDs = cloneNodeIDs(nodeIDs)
sortByXOR(nodeIDKeys, rt.self.Id.Bytes()) sortByXOR(nodeIDs, rt.self.Id)
if len(nodeIDs) < rt.bucketSize+1 { //adding 1 since we're not including local node in closest k if len(nodeIDs) < rt.bucketSize+1 { //adding 1 since we're not including local node in closest k
return storj.NodeIDFromBytes(nodeIDKeys[len(nodeIDKeys)-1]) return nodeIDs[len(nodeIDs)-1]
} }
return storj.NodeIDFromBytes(nodeIDKeys[rt.bucketSize]) return nodeIDs[rt.bucketSize]
}
// xorTwoIds: helper, finds the xor distance between two byte slices
func xorTwoIds(id, comparisonID []byte) []byte {
var xorArr []byte
s := len(id)
if s > len(comparisonID) {
s = len(comparisonID)
}
for i := 0; i < s; i++ {
xor := id[i] ^ comparisonID[i]
xorArr = append(xorArr, xor)
}
return xorArr
} }
// nodeIsWithinNearestK: helper, returns true if the node in question is within the nearest k from local node // nodeIsWithinNearestK: helper, returns true if the node in question is within the nearest k from local node
@ -276,16 +198,11 @@ func (rt *RoutingTable) nodeIsWithinNearestK(nodeID storj.NodeID) (bool, error)
if err != nil { if err != nil {
return false, RoutingErr.Wrap(err) return false, RoutingErr.Wrap(err)
} }
furthestIDWithinK, err := rt.determineFurthestIDWithinK(nodeIDs)
if err != nil { furthestIDWithinK := rt.determineFurthestIDWithinK(nodeIDs)
return false, RoutingErr.New("could not determine furthest id within k: %s", err) existingXor := xorNodeID(furthestIDWithinK, rt.self.Id)
} newXor := xorNodeID(nodeID, rt.self.Id)
existingXor := xorTwoIds(furthestIDWithinK.Bytes(), rt.self.Id.Bytes()) return newXor.Less(existingXor), nil
newXor := xorTwoIds(nodeID.Bytes(), rt.self.Id.Bytes())
if bytes.Compare(newXor, existingXor) < 0 {
return true, nil
}
return false, nil
} }
// kadBucketContainsLocalNode returns true if the kbucket in question contains the local node // kadBucketContainsLocalNode returns true if the kbucket in question contains the local node
@ -294,7 +211,7 @@ func (rt *RoutingTable) kadBucketContainsLocalNode(queryID bucketID) (bool, erro
if err != nil { if err != nil {
return false, err return false, err
} }
return bytes.Equal(queryID[:], bID[:]), nil return queryID == bID, nil
} }
// kadBucketHasRoom: helper, returns true if it has fewer than k nodes // kadBucketHasRoom: helper, returns true if it has fewer than k nodes
@ -396,15 +313,6 @@ func (rt *RoutingTable) getKBucketRange(bID bucketID) ([]bucketID, error) {
return coords, nil return coords, nil
} }
// createFirstBucketID creates byte slice representing 11..11
// bucket IDs are the highest address which that bucket contains
func (rt *RoutingTable) createFirstBucketID() (id bucketID) {
for i := 0; i < len(id); i++ {
id[i] = 255
}
return id
}
// determineLeafDepth determines the level of the bucket id in question. // determineLeafDepth determines the level of the bucket id in question.
// Eg level 0 means there is only 1 bucket, level 1 means the bucket has been split once, and so on // Eg level 0 means there is only 1 bucket, level 1 means the bucket has been split once, and so on
func (rt *RoutingTable) determineLeafDepth(bID bucketID) (int, error) { func (rt *RoutingTable) determineLeafDepth(bID bucketID) (int, error) {
@ -413,57 +321,13 @@ func (rt *RoutingTable) determineLeafDepth(bID bucketID) (int, error) {
return -1, RoutingErr.New("could not get k bucket range %s", err) return -1, RoutingErr.New("could not get k bucket range %s", err)
} }
smaller := bucketRange[0] smaller := bucketRange[0]
diffBit, err := rt.determineDifferingBitIndex(bID, smaller) diffBit, err := determineDifferingBitIndex(bID, smaller)
if err != nil { if err != nil {
return diffBit + 1, RoutingErr.New("could not determine differing bit %s", err) return diffBit + 1, RoutingErr.New("could not determine differing bit %s", err)
} }
return diffBit + 1, nil return diffBit + 1, nil
} }
// determineDifferingBitIndex: helper, returns the last bit differs starting from prefix to suffix
func (rt *RoutingTable) determineDifferingBitIndex(bID, comparisonID bucketID) (int, error) {
if bytes.Equal(bID[:], comparisonID[:]) {
return -2, RoutingErr.New("compared two equivalent k bucket ids")
}
emptyBID := bucketID{}
if bytes.Equal(comparisonID[:], emptyBID[:]) {
comparisonID = rt.createFirstBucketID()
}
var differingByteIndex int
var differingByteXor byte
xorArr := xorTwoIds(bID[:], comparisonID[:])
firstBID := rt.createFirstBucketID()
if bytes.Equal(xorArr, firstBID[:]) {
return -1, nil
}
for j, v := range xorArr {
if v != byte(0) {
differingByteIndex = j
differingByteXor = v
break
}
}
h := 0
for ; h < 8; h++ {
toggle := byte(1 << uint(h))
tempXor := differingByteXor
tempXor ^= toggle
if tempXor < differingByteXor {
break
}
}
bitInByteIndex := 7 - h
byteIndex := differingByteIndex
bitIndex := byteIndex*8 + bitInByteIndex
return bitIndex, nil
}
// splitBucket: helper, returns the smaller of the two new bucket ids // splitBucket: helper, returns the smaller of the two new bucket ids
// the original bucket id becomes the greater of the 2 new // the original bucket id becomes the greater of the 2 new
func (rt *RoutingTable) splitBucket(bID bucketID, depth int) bucketID { func (rt *RoutingTable) splitBucket(bID bucketID, depth int) bucketID {

View File

@ -6,14 +6,12 @@ package kademlia
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"math/rand"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeebo/errs"
"go.uber.org/zap" "go.uber.org/zap"
"storj.io/storj/internal/teststorj" "storj.io/storj/internal/teststorj"
@ -246,7 +244,7 @@ func TestUpdateNode(t *testing.T) {
func TestRemoveNode(t *testing.T) { func TestRemoveNode(t *testing.T) {
rt, cleanup := createRoutingTable(t, teststorj.NodeIDFromString("AA")) rt, cleanup := createRoutingTable(t, teststorj.NodeIDFromString("AA"))
defer cleanup() defer cleanup()
kadBucketID := rt.createFirstBucketID() kadBucketID := firstBucketID
node := teststorj.MockNode("BB") node := teststorj.MockNode("BB")
ok, err := rt.addNode(node) ok, err := rt.addNode(node)
assert.True(t, ok) assert.True(t, ok)
@ -272,19 +270,19 @@ func TestRemoveNode(t *testing.T) {
} }
func TestCreateOrUpdateKBucket(t *testing.T) { func TestCreateOrUpdateKBucket(t *testing.T) {
id := teststorj.NodeIDFromBytes([]byte{255, 255}) id := bucketID{255, 255}
rt, cleanup := createRoutingTable(t, storj.NodeID{}) rt, cleanup := createRoutingTable(t, storj.NodeID{})
defer cleanup() defer cleanup()
err := rt.createOrUpdateKBucket(keyToBucketID(id.Bytes()), time.Now()) err := rt.createOrUpdateKBucket(id, time.Now())
assert.NoError(t, err) assert.NoError(t, err)
val, e := rt.kadBucketDB.Get(id.Bytes()) val, e := rt.kadBucketDB.Get(id[:])
assert.NotNil(t, val) assert.NotNil(t, val)
assert.NoError(t, e) assert.NoError(t, e)
} }
func TestGetKBucketID(t *testing.T) { func TestGetKBucketID(t *testing.T) {
kadIDA := keyToBucketID([]byte{255, 255}) kadIDA := bucketID{255, 255}
nodeIDA := teststorj.NodeIDFromString("AA") nodeIDA := teststorj.NodeIDFromString("AA")
rt, cleanup := createRoutingTable(t, nodeIDA) rt, cleanup := createRoutingTable(t, nodeIDA)
defer cleanup() defer cleanup()
@ -293,60 +291,8 @@ func TestGetKBucketID(t *testing.T) {
assert.Equal(t, kadIDA[:2], keyA[:2]) assert.Equal(t, kadIDA[:2], keyA[:2])
} }
func TestXorTwoIds(t *testing.T) {
x := xorTwoIds([]byte{191}, []byte{159})
assert.Equal(t, []byte{32}, x) //00100000
}
func TestSortByXOR(t *testing.T) {
node1 := teststorj.NodeIDFromBytes([]byte{127, 255}) //xor 0
rt, cleanup := createRoutingTable(t, node1)
defer cleanup()
node2 := teststorj.NodeIDFromBytes([]byte{143, 255}) //xor 240
assert.NoError(t, rt.nodeBucketDB.Put(node2.Bytes(), []byte("")))
node3 := teststorj.NodeIDFromBytes([]byte{255, 255}) //xor 128
assert.NoError(t, rt.nodeBucketDB.Put(node3.Bytes(), []byte("")))
node4 := teststorj.NodeIDFromBytes([]byte{191, 255}) //xor 192
assert.NoError(t, rt.nodeBucketDB.Put(node4.Bytes(), []byte("")))
node5 := teststorj.NodeIDFromBytes([]byte{133, 255}) //xor 250
assert.NoError(t, rt.nodeBucketDB.Put(node5.Bytes(), []byte("")))
nodes, err := rt.nodeBucketDB.List(nil, 0)
assert.NoError(t, err)
expectedNodes := storage.Keys{node1.Bytes(), node5.Bytes(), node2.Bytes(), node4.Bytes(), node3.Bytes()}
assert.Equal(t, expectedNodes, nodes)
sortByXOR(nodes, node1.Bytes())
expectedSorted := storage.Keys{node1.Bytes(), node3.Bytes(), node4.Bytes(), node2.Bytes(), node5.Bytes()}
assert.Equal(t, expectedSorted, nodes)
nodes, err = rt.nodeBucketDB.List(nil, 0)
assert.NoError(t, err)
assert.Equal(t, expectedNodes, nodes)
}
func BenchmarkSortByXOR(b *testing.B) {
nodes := []storage.Key{}
newNodeID := func() storage.Key {
id := make(storage.Key, 32)
rand.Read(id[:])
return id
}
for k := 0; k < 1000; k++ {
nodes = append(nodes, newNodeID())
}
b.ResetTimer()
for m := 0; m < b.N; m++ {
rand.Shuffle(len(nodes), func(i, k int) {
nodes[i], nodes[k] = nodes[k], nodes[i]
})
sortByXOR(nodes, newNodeID())
}
}
func TestDetermineFurthestIDWithinK(t *testing.T) { func TestDetermineFurthestIDWithinK(t *testing.T) {
rt, cleanup := createRoutingTable(t, teststorj.NodeIDFromBytes([]byte{127, 255})) rt, cleanup := createRoutingTable(t, storj.NodeID{127, 255})
defer cleanup() defer cleanup()
cases := []struct { cases := []struct {
testID string testID string
@ -379,8 +325,7 @@ func TestDetermineFurthestIDWithinK(t *testing.T) {
assert.NoError(t, rt.nodeBucketDB.Put(teststorj.NodeIDFromBytes(c.nodeID).Bytes(), []byte(""))) assert.NoError(t, rt.nodeBucketDB.Put(teststorj.NodeIDFromBytes(c.nodeID).Bytes(), []byte("")))
nodes, err := rt.nodeBucketDB.List(nil, 0) nodes, err := rt.nodeBucketDB.List(nil, 0)
assert.NoError(t, err) assert.NoError(t, err)
furthest, err := rt.determineFurthestIDWithinK(teststorj.NodeIDsFromBytes(nodes.ByteSlices()...)) furthest := rt.determineFurthestIDWithinK(teststorj.NodeIDsFromBytes(nodes.ByteSlices()...))
assert.NoError(t, err)
fmt.Println(furthest.Bytes()) fmt.Println(furthest.Bytes())
assert.Equal(t, c.expectedFurthest, furthest[:2]) assert.Equal(t, c.expectedFurthest, furthest[:2])
}) })
@ -388,50 +333,52 @@ func TestDetermineFurthestIDWithinK(t *testing.T) {
} }
func TestNodeIsWithinNearestK(t *testing.T) { func TestNodeIsWithinNearestK(t *testing.T) {
rt, cleanup := createRoutingTable(t, teststorj.NodeIDFromBytes([]byte{127, 255})) rt, cleanup := createRoutingTable(t, storj.NodeID{127, 255})
defer cleanup() defer cleanup()
rt.bucketSize = 2 rt.bucketSize = 2
cases := []struct { cases := []struct {
testID string testID string
nodeID []byte nodeID storj.NodeID
closest bool closest bool
}{ }{
{testID: "A", {testID: "A",
nodeID: []byte{127, 255}, nodeID: storj.NodeID{127, 255},
closest: true, closest: true,
}, },
{testID: "B", {testID: "B",
nodeID: []byte{143, 255}, nodeID: storj.NodeID{143, 255},
closest: true, closest: true,
}, },
{testID: "C", {testID: "C",
nodeID: []byte{255, 255}, nodeID: storj.NodeID{255, 255},
closest: true, closest: true,
}, },
{testID: "D", {testID: "D",
nodeID: []byte{191, 255}, nodeID: storj.NodeID{191, 255},
closest: true, closest: true,
}, },
{testID: "E", {testID: "E",
nodeID: []byte{133, 255}, nodeID: storj.NodeID{133, 255},
closest: false, closest: false,
}, },
} }
for _, c := range cases { for _, c := range cases {
t.Run(c.testID, func(t *testing.T) { t.Run(c.testID, func(t *testing.T) {
result, err := rt.nodeIsWithinNearestK(teststorj.NodeIDFromBytes(c.nodeID)) result, err := rt.nodeIsWithinNearestK(c.nodeID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, c.closest, result) assert.Equal(t, c.closest, result)
assert.NoError(t, rt.nodeBucketDB.Put(teststorj.NodeIDFromBytes(c.nodeID).Bytes(), []byte(""))) assert.NoError(t, rt.nodeBucketDB.Put(c.nodeID.Bytes(), []byte("")))
}) })
} }
} }
func TestKadBucketContainsLocalNode(t *testing.T) { func TestKadBucketContainsLocalNode(t *testing.T) {
nodeIDA := teststorj.NodeIDFromBytes([]byte{183, 255}) //[10110111, 1111111] nodeIDA := storj.NodeID{183, 255} //[10110111, 1111111]
rt, cleanup := createRoutingTable(t, nodeIDA) rt, cleanup := createRoutingTable(t, nodeIDA)
defer cleanup() defer cleanup()
kadIDA := rt.createFirstBucketID() kadIDA := firstBucketID
var kadIDB bucketID var kadIDB bucketID
copy(kadIDB[:], kadIDA[:]) copy(kadIDB[:], kadIDA[:])
kadIDB[0] = 127 kadIDB[0] = 127
@ -447,15 +394,15 @@ func TestKadBucketContainsLocalNode(t *testing.T) {
} }
func TestKadBucketHasRoom(t *testing.T) { func TestKadBucketHasRoom(t *testing.T) {
node1 := teststorj.NodeIDFromBytes([]byte{255, 255}) node1 := storj.NodeID{255, 255}
rt, cleanup := createRoutingTable(t, node1) rt, cleanup := createRoutingTable(t, node1)
defer cleanup() defer cleanup()
kadIDA := rt.createFirstBucketID() kadIDA := firstBucketID
node2 := teststorj.NodeIDFromBytes([]byte{191, 255}) node2 := storj.NodeID{191, 255}
node3 := teststorj.NodeIDFromBytes([]byte{127, 255}) node3 := storj.NodeID{127, 255}
node4 := teststorj.NodeIDFromBytes([]byte{63, 255}) node4 := storj.NodeID{63, 255}
node5 := teststorj.NodeIDFromBytes([]byte{159, 255}) node5 := storj.NodeID{159, 255}
node6 := teststorj.NodeIDFromBytes([]byte{0, 127}) node6 := storj.NodeID{0, 127}
resultA, err := rt.kadBucketHasRoom(kadIDA) resultA, err := rt.kadBucketHasRoom(kadIDA)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, resultA) assert.True(t, resultA)
@ -470,18 +417,18 @@ func TestKadBucketHasRoom(t *testing.T) {
} }
func TestGetNodeIDsWithinKBucket(t *testing.T) { func TestGetNodeIDsWithinKBucket(t *testing.T) {
nodeIDA := teststorj.NodeIDFromBytes([]byte{183, 255}) //[10110111, 1111111] nodeIDA := storj.NodeID{183, 255} //[10110111, 1111111]
rt, cleanup := createRoutingTable(t, nodeIDA) rt, cleanup := createRoutingTable(t, nodeIDA)
defer cleanup() defer cleanup()
kadIDA := rt.createFirstBucketID() kadIDA := firstBucketID
var kadIDB bucketID var kadIDB bucketID
copy(kadIDB[:], kadIDA[:]) copy(kadIDB[:], kadIDA[:])
kadIDB[0] = 127 kadIDB[0] = 127
now := time.Now() now := time.Now()
assert.NoError(t, rt.createOrUpdateKBucket(kadIDB, now)) assert.NoError(t, rt.createOrUpdateKBucket(kadIDB, now))
nodeIDB := teststorj.NodeIDFromBytes([]byte{111, 255}) //[01101111, 1111111] nodeIDB := storj.NodeID{111, 255} //[01101111, 1111111]
nodeIDC := teststorj.NodeIDFromBytes([]byte{47, 255}) //[00101111, 1111111] nodeIDC := storj.NodeID{47, 255} //[00101111, 1111111]
assert.NoError(t, rt.nodeBucketDB.Put(nodeIDB.Bytes(), []byte(""))) assert.NoError(t, rt.nodeBucketDB.Put(nodeIDB.Bytes(), []byte("")))
assert.NoError(t, rt.nodeBucketDB.Put(nodeIDC.Bytes(), []byte(""))) assert.NoError(t, rt.nodeBucketDB.Put(nodeIDC.Bytes(), []byte("")))
@ -567,7 +514,7 @@ func TestUnmarshalNodes(t *testing.T) {
func TestGetUnmarshaledNodesFromBucket(t *testing.T) { func TestGetUnmarshaledNodesFromBucket(t *testing.T) {
nodeA := teststorj.MockNode("AA") nodeA := teststorj.MockNode("AA")
rt, cleanup := createRoutingTable(t, nodeA.Id) rt, cleanup := createRoutingTable(t, nodeA.Id)
bucketID := rt.createFirstBucketID() bucketID := firstBucketID
defer cleanup() defer cleanup()
nodeB := teststorj.MockNode("BB") nodeB := teststorj.MockNode("BB")
nodeC := teststorj.MockNode("CC") nodeC := teststorj.MockNode("CC")
@ -587,9 +534,9 @@ func TestGetUnmarshaledNodesFromBucket(t *testing.T) {
func TestGetKBucketRange(t *testing.T) { func TestGetKBucketRange(t *testing.T) {
rt, cleanup := createRoutingTable(t, storj.NodeID{}) rt, cleanup := createRoutingTable(t, storj.NodeID{})
defer cleanup() defer cleanup()
idA := teststorj.NodeIDFromBytes([]byte{255, 255}) idA := storj.NodeID{255, 255}
idB := teststorj.NodeIDFromBytes([]byte{127, 255}) idB := storj.NodeID{127, 255}
idC := teststorj.NodeIDFromBytes([]byte{63, 255}) idC := storj.NodeID{63, 255}
assert.NoError(t, rt.kadBucketDB.Put(idA.Bytes(), []byte(""))) assert.NoError(t, rt.kadBucketDB.Put(idA.Bytes(), []byte("")))
assert.NoError(t, rt.kadBucketDB.Put(idB.Bytes(), []byte(""))) assert.NoError(t, rt.kadBucketDB.Put(idB.Bytes(), []byte("")))
assert.NoError(t, rt.kadBucketDB.Put(idC.Bytes(), []byte(""))) assert.NoError(t, rt.kadBucketDB.Put(idC.Bytes(), []byte("")))
@ -622,14 +569,6 @@ func TestGetKBucketRange(t *testing.T) {
} }
} }
func TestCreateFirstBucketID(t *testing.T) {
rt, cleanup := createRoutingTable(t, storj.NodeID{})
defer cleanup()
x := rt.createFirstBucketID()
expected := []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}
assert.Equal(t, x[:], expected)
}
func TestBucketIDZeroValue(t *testing.T) { func TestBucketIDZeroValue(t *testing.T) {
// rt, cleanup := createRoutingTable(t, storj.NodeID{}) // rt, cleanup := createRoutingTable(t, storj.NodeID{})
// defer cleanup() // defer cleanup()
@ -641,9 +580,10 @@ func TestBucketIDZeroValue(t *testing.T) {
func TestDetermineLeafDepth(t *testing.T) { func TestDetermineLeafDepth(t *testing.T) {
rt, cleanup := createRoutingTable(t, storj.NodeID{}) rt, cleanup := createRoutingTable(t, storj.NodeID{})
defer cleanup() defer cleanup()
idA := teststorj.NodeIDFromBytes([]byte{255, 255}) idA, idB, idC := storj.NodeID(firstBucketID), storj.NodeID(firstBucketID), storj.NodeID(firstBucketID)
idB := teststorj.NodeIDFromBytes([]byte{127, 255}) idA[0] = 255
idC := teststorj.NodeIDFromBytes([]byte{63, 255}) idB[0] = 127
idC[0] = 63
cases := []struct { cases := []struct {
testID string testID string
@ -689,117 +629,13 @@ func TestDetermineLeafDepth(t *testing.T) {
for _, c := range cases { for _, c := range cases {
t.Run(c.testID, func(t *testing.T) { t.Run(c.testID, func(t *testing.T) {
c.addNode() c.addNode()
d, err := rt.determineLeafDepth(keyToBucketID(c.id.Bytes())) d, err := rt.determineLeafDepth(bucketID(c.id))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, c.depth, d) assert.Equal(t, c.depth, d)
}) })
} }
} }
func padBucketID(b []byte, p byte) (bID bucketID) {
for i := range bID {
if len(b) > i {
bID[i] = b[i]
continue
}
bID[i] = p
}
return bID
}
func TestDetermineDifferingBitIndex(t *testing.T) {
rt, cleanup := createRoutingTable(t, storj.NodeID{})
defer cleanup()
cases := []struct {
testID string
bucketID bucketID
key bucketID
expected int
err *errs.Class
}{
{testID: "A",
bucketID: padBucketID([]byte{191, 255}, 255),
key: padBucketID([]byte{255, 255}, 255),
expected: 1,
err: nil,
},
{testID: "B",
bucketID: padBucketID([]byte{255, 255}, 255),
key: padBucketID([]byte{191, 255}, 255),
expected: 1,
err: nil,
},
{testID: "C",
bucketID: padBucketID([]byte{95, 255}, 255),
key: padBucketID([]byte{127, 255}, 255),
expected: 2,
err: nil,
},
{testID: "D",
bucketID: padBucketID([]byte{95, 255}, 255),
key: padBucketID([]byte{79, 255}, 255),
expected: 3,
err: nil,
},
{testID: "E",
bucketID: padBucketID([]byte{95, 255}, 255),
key: padBucketID([]byte{63, 255}, 255),
expected: 2,
err: nil,
},
{testID: "F",
bucketID: padBucketID([]byte{95, 255}, 255),
key: padBucketID([]byte{79, 255}, 255),
expected: 3,
err: nil,
},
{testID: "G",
bucketID: padBucketID([]byte{255, 255}, 255),
key: padBucketID([]byte{255, 255}, 255),
expected: -2,
err: &RoutingErr,
},
{testID: "H",
bucketID: padBucketID([]byte{255, 255}, 255),
key: padBucketID([]byte{0, 0}, 0),
expected: -1,
err: nil,
},
{testID: "I",
bucketID: padBucketID([]byte{127, 255}, 255),
key: padBucketID([]byte{0, 0}, 0),
expected: 0,
err: nil,
},
{testID: "J",
bucketID: padBucketID([]byte{63, 255}, 255),
key: padBucketID([]byte{0, 0}, 0),
expected: 1,
err: nil,
},
{testID: "K",
bucketID: padBucketID([]byte{31, 255}, 255),
key: padBucketID([]byte{0, 0}, 0),
expected: 2,
err: nil,
},
{testID: "L",
bucketID: padBucketID([]byte{95, 255}, 255),
key: padBucketID([]byte{63, 255}, 255),
expected: 2,
err: nil,
},
}
for _, c := range cases {
t.Run(c.testID, func(t *testing.T) {
diff, err := rt.determineDifferingBitIndex(c.bucketID, c.key)
assertErrClass(t, c.err, err)
assert.Equal(t, c.expected, diff)
})
}
}
func TestSplitBucket(t *testing.T) { func TestSplitBucket(t *testing.T) {
rt, cleanup := createRoutingTable(t, storj.NodeID{}) rt, cleanup := createRoutingTable(t, storj.NodeID{})
defer cleanup() defer cleanup()
@ -847,12 +683,3 @@ func TestSplitBucket(t *testing.T) {
}) })
} }
} }
func assertErrClass(t *testing.T, class *errs.Class, err error) {
t.Helper()
if class != nil {
assert.True(t, class.Has(err))
} else {
assert.NoError(t, err)
}
}

103
pkg/kademlia/utils.go Normal file
View File

@ -0,0 +1,103 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package kademlia
import (
"math/bits"
"sort"
"github.com/zeebo/errs"
"storj.io/storj/pkg/storj"
"storj.io/storj/storage"
)
func cloneNodeIDs(ids storj.NodeIDList) storj.NodeIDList {
clone := make(storj.NodeIDList, len(ids))
copy(clone, ids)
return clone
}
// compareByXor compares left, right xorred by reference
func compareByXor(left, right, reference storj.NodeID) int {
for i, r := range reference {
a, b := left[i]^r, right[i]^r
if a != b {
if a < b {
return -1
}
return 1
}
}
return 0
}
func sortByXOR(nodeIDs storj.NodeIDList, ref storj.NodeID) {
sort.Slice(nodeIDs, func(i, k int) bool {
return compareByXor(nodeIDs[i], nodeIDs[k], ref) < 0
})
}
func keysToNodeIDs(keys storage.Keys) (ids storj.NodeIDList, err error) {
var idErrs []error
for _, k := range keys {
id, err := storj.NodeIDFromBytes(k[:])
if err != nil {
idErrs = append(idErrs, err)
}
ids = append(ids, id)
}
if err := errs.Combine(idErrs...); err != nil {
return nil, err
}
return ids, nil
}
func keyToBucketID(key storage.Key) (bID bucketID) {
copy(bID[:], key)
return bID
}
// xorNodeID returns the xor of each byte in NodeID
func xorNodeID(a, b storj.NodeID) storj.NodeID {
r := storj.NodeID{}
for i, av := range a {
r[i] = av ^ b[i]
}
return r
}
// xorBucketID returns the xor of each byte in bucketID
func xorBucketID(a, b bucketID) bucketID {
r := bucketID{}
for i, av := range a {
r[i] = av ^ b[i]
}
return r
}
// determineDifferingBitIndex: helper, returns the last bit differs starting from prefix to suffix
func determineDifferingBitIndex(bID, comparisonID bucketID) (int, error) {
if bID == comparisonID {
return -2, RoutingErr.New("compared two equivalent k bucket ids")
}
if comparisonID == emptyBucketID {
comparisonID = firstBucketID
}
xorID := xorBucketID(bID, comparisonID)
if xorID == firstBucketID {
return -1, nil
}
for i, v := range xorID {
if v != 0 {
return i*8 + 7 - bits.TrailingZeros8(v), nil
}
}
return -1, nil
}

127
pkg/kademlia/utils_test.go Normal file
View File

@ -0,0 +1,127 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package kademlia
import (
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
"storj.io/storj/pkg/storj"
)
func TestSortByXOR(t *testing.T) {
n1 := storj.NodeID{127, 255} //xor 0
n2 := storj.NodeID{143, 255} //xor 240
n3 := storj.NodeID{255, 255} //xor 128
n4 := storj.NodeID{191, 255} //xor 192
n5 := storj.NodeID{133, 255} //xor 250
unsorted := storj.NodeIDList{n1, n5, n2, n4, n3}
sortByXOR(unsorted, n1)
sorted := storj.NodeIDList{n1, n3, n4, n2, n5}
assert.Equal(t, sorted, unsorted)
}
func BenchmarkSortByXOR(b *testing.B) {
newNodeID := func() storj.NodeID {
var id storj.NodeID
rand.Read(id[:])
return id
}
nodes := []storj.NodeID{}
for k := 0; k < 1000; k++ {
nodes = append(nodes, newNodeID())
}
b.ResetTimer()
for m := 0; m < b.N; m++ {
rand.Shuffle(len(nodes), func(i, k int) {
nodes[i], nodes[k] = nodes[k], nodes[i]
})
sortByXOR(nodes, newNodeID())
}
}
func TestDetermineDifferingBitIndex(t *testing.T) {
filledID := func(a byte) bucketID {
id := firstBucketID
id[0] = a
return id
}
cases := []struct {
bucketID bucketID
key bucketID
expected int
}{
{
bucketID: filledID(191),
key: filledID(255),
expected: 1,
},
{
bucketID: filledID(255),
key: filledID(191),
expected: 1,
},
{
bucketID: filledID(95),
key: filledID(127),
expected: 2,
},
{
bucketID: filledID(95),
key: filledID(79),
expected: 3,
},
{
bucketID: filledID(95),
key: filledID(63),
expected: 2,
},
{
bucketID: filledID(95),
key: filledID(79),
expected: 3,
},
{
bucketID: filledID(255),
key: bucketID{},
expected: -1,
},
{
bucketID: filledID(127),
key: bucketID{},
expected: 0,
},
{
bucketID: filledID(63),
key: bucketID{},
expected: 1,
},
{
bucketID: filledID(31),
key: bucketID{},
expected: 2,
},
{
bucketID: filledID(95),
key: filledID(63),
expected: 2,
},
}
for i, c := range cases {
t.Logf("#%d. bucketID:%v key:%v\n", i, c.bucketID, c.key)
diff, err := determineDifferingBitIndex(c.bucketID, c.key)
assert.NoError(t, err)
assert.Equal(t, c.expected, diff)
}
diff, err := determineDifferingBitIndex(filledID(255), filledID(255))
assert.True(t, RoutingErr.Has(err))
assert.Equal(t, diff, -2)
}