pkg/kademlia: clean up peer discovery (#2252)

JT Olio 2019-06-26 07:16:46 -06:00 committed by Egon Elbre
parent 3925e84580
commit fbe9696e92
9 changed files with 201 additions and 151 deletions

View File

@ -14,6 +14,7 @@ import (
"storj.io/storj/internal/sync2"
"storj.io/storj/pkg/identity"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
"storj.io/storj/pkg/transport"
)
@ -47,29 +48,37 @@ func (dialer *Dialer) Close() error {
}
// Lookup queries ask about find, and also sends information about self.
func (dialer *Dialer) Lookup(ctx context.Context, self pb.Node, ask pb.Node, find pb.Node) (_ []*pb.Node, err error) {
// If self is nil, pingback will be false.
func (dialer *Dialer) Lookup(ctx context.Context, self *pb.Node, ask pb.Node, find storj.NodeID, limit int) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
if !dialer.limit.Lock() {
return nil, context.Canceled
}
defer dialer.limit.Unlock()
req := pb.QueryRequest{
Limit: int64(limit),
Target: &pb.Node{Id: find}, // TODO: should not be a Node protobuf!
}
if self != nil {
req.Pingback = true
req.Sender = self
}
conn, err := dialer.dialNode(ctx, ask)
if err != nil {
return nil, err
}
defer func() {
err = errs.Combine(err, conn.disconnect())
}()
resp, err := conn.client.Query(ctx, &pb.QueryRequest{
Limit: 20, // TODO: should not be hardcoded, but instead kademlia k value, routing table depth, etc
Sender: &self,
Target: &find,
Pingback: true, // should only be true during bucket refreshing
})
resp, err := conn.client.Query(ctx, &req)
if err != nil {
return nil, errs.Combine(err, conn.disconnect())
return nil, err
}
return resp.Response, conn.disconnect()
return resp.Response, nil
}
// PingNode pings target.
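
For illustration, here is a minimal caller-side sketch of the two modes of the new Lookup signature. Only Dialer.Lookup's signature comes from this commit; the wrapper function, its package, and its parameter names are assumptions.

package kademliautil // hypothetical package, for the sketch only

import (
	"context"

	"storj.io/storj/pkg/kademlia"
	"storj.io/storj/pkg/pb"
	"storj.io/storj/pkg/storj"
)

// refreshThenQuery exercises both Lookup modes against the same peer.
func refreshThenQuery(ctx context.Context, dialer *kademlia.Dialer, self pb.Node, ask pb.Node, target storj.NodeID, k int) ([]*pb.Node, error) {
	// Non-nil self: the request carries Sender and Pingback=true, so the
	// queried node pings us back (used while refreshing buckets).
	if _, err := dialer.Lookup(ctx, &self, ask, target, k); err != nil {
		return nil, err
	}
	// Nil self: a plain query, Pingback stays false.
	return dialer.Lookup(ctx, nil, ask, target, k)
}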

View File

@ -105,7 +105,8 @@ func TestDialer(t *testing.T) {
for _, target := range peers {
errTag := fmt.Errorf("lookup peer:%s target:%s", peer.ID(), target.ID())
results, err := dialer.Lookup(ctx, self.Local().Node, peer.Local().Node, target.Local().Node)
selfnode := self.Local().Node
results, err := dialer.Lookup(ctx, &selfnode, peer.Local().Node, target.Local().Node.Id, self.Kademlia.RoutingTable.K())
if err != nil {
return errs.Combine(errTag, err)
}
@ -145,7 +146,8 @@ func TestDialer(t *testing.T) {
group.Go(func() error {
errTag := fmt.Errorf("invalid lookup peer:%s target:%s", peer.ID(), target)
results, err := dialer.Lookup(ctx, self.Local().Node, peer.Local().Node, pb.Node{Id: target})
selfnode := self.Local().Node
results, err := dialer.Lookup(ctx, &selfnode, peer.Local().Node, target, self.Kademlia.RoutingTable.K())
if err != nil {
return errs.Combine(errTag, err)
}
@ -275,7 +277,8 @@ func TestSlowDialerHasTimeout(t *testing.T) {
peer := peer
group.Go(func() error {
for _, target := range peers {
_, err := dialer.Lookup(ctx, self.Local().Node, peer.Local().Node, target.Local().Node)
selfnode := self.Local().Node
_, err := dialer.Lookup(ctx, &selfnode, peer.Local().Node, target.Local().Node.Id, self.Kademlia.RoutingTable.K())
if !transport.Error.Has(err) || errs.Unwrap(err) != context.DeadlineExceeded {
return errs.New("invalid error: %v (peer:%s target:%s)", err, peer.ID(), target.ID())
}

View File

@ -32,17 +32,9 @@ var (
NodeNotFound = errs.Class("node not found")
// TODO: shouldn't default to TCP but not sure what to do yet
defaultTransport = pb.NodeTransport_TCP_TLS_GRPC
defaultRetries = 3
mon = monkit.Package()
)
type discoveryOptions struct {
concurrency int
retries int
bootstrap bool
bootstrapNodes []pb.Node
}
// Kademlia is an implementation of kademlia adhering to the DHT interface.
type Kademlia struct {
log *zap.Logger
@ -185,6 +177,17 @@ func (k *Kademlia) Bootstrap(ctx context.Context) (err error) {
continue
}
// FetchPeerIdentityUnverified uses transport.DialAddress, which should be
// enough to have the TransportObservers find out about this node. Unfortunately,
// getting DialAddress to be able to grab the node id seems challenging with gRPC.
// The way FetchPeerIdentityUnverified does this is to do a basic ping request, which
// we have now done. Let's tell all the transport observers now.
// TODO: remove the explicit transport observer notification
k.dialer.transport.AlertSuccess(ctx, &pb.Node{
Id: ident.ID,
Address: node.Address,
})
k.routingTable.mutex.Lock()
node.Id = ident.ID
k.bootstrapNodes[i] = node
@ -201,7 +204,7 @@ func (k *Kademlia) Bootstrap(ctx context.Context) (err error) {
k.routingTable.mutex.Lock()
id := k.routingTable.self.Id
k.routingTable.mutex.Unlock()
_, err := k.lookup(ctx, id, true)
_, err := k.lookup(ctx, id)
if err != nil {
errGroup.Add(err)
continue
@ -280,37 +283,33 @@ func (k *Kademlia) FindNode(ctx context.Context, nodeID storj.NodeID) (_ pb.Node
}
defer k.lookups.Done()
return k.lookup(ctx, nodeID, false)
results, err := k.lookup(ctx, nodeID)
if err != nil {
return pb.Node{}, err
}
if len(results) < 1 {
return pb.Node{}, NodeNotFound.New("")
}
return *results[0], nil
}
// lookup initiates a kademlia node lookup
func (k *Kademlia) lookup(ctx context.Context, nodeID storj.NodeID, isBootstrap bool) (_ pb.Node, err error) {
func (k *Kademlia) lookup(ctx context.Context, nodeID storj.NodeID) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
if !k.lookups.Start() {
return pb.Node{}, context.Canceled
return nil, context.Canceled
}
defer k.lookups.Done()
kb := k.routingTable.K()
var nodes []*pb.Node
if isBootstrap {
for _, bn := range k.bootstrapNodes {
bn := bn
nodes = append(nodes, &bn)
}
} else {
var err error
nodes, err = k.routingTable.FindNear(ctx, nodeID, kb)
if err != nil {
return pb.Node{}, err
}
}
lookup := newPeerDiscovery(k.log, k.routingTable.Local().Node, nodes, k.dialer, nodeID, discoveryOptions{
concurrency: k.alpha, retries: defaultRetries, bootstrap: isBootstrap, bootstrapNodes: k.bootstrapNodes,
})
target, err := lookup.Run(ctx)
nodes, err := k.routingTable.FindNear(ctx, nodeID, k.routingTable.K())
if err != nil {
return pb.Node{}, err
return nil, err
}
self := k.routingTable.Local().Node
lookup := newPeerDiscovery(k.log, k.dialer, nodeID, nodes, k.routingTable.K(), k.alpha, &self)
results, err := lookup.Run(ctx)
if err != nil {
return nil, err
}
bucket, err := k.routingTable.getKBucketID(ctx, nodeID)
if err != nil {
@ -321,13 +320,7 @@ func (k *Kademlia) lookup(ctx context.Context, nodeID storj.NodeID, isBootstrap
k.log.Warn("Error updating bucket timestamp in kad lookup")
}
}
if target == nil {
if isBootstrap {
return pb.Node{}, nil
}
return pb.Node{}, NodeNotFound.New("")
}
return *target, nil
return results, nil
}
// GetNodesWithinKBucket returns all the routing nodes in the specified k-bucket
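
For context, a hypothetical caller-side sketch of the reworked FindNode contract: lookup now returns the list of successfully queried nodes, and FindNode surfaces the closest one or NodeNotFound when that list is empty. The helper and its names below are illustrative only.

package kademliautil // hypothetical package, for the sketch only

import (
	"context"
	"fmt"

	"storj.io/storj/pkg/kademlia"
	"storj.io/storj/pkg/storj"
)

// resolveAddress looks up a node ID via the DHT and returns its last known address.
func resolveAddress(ctx context.Context, k *kademlia.Kademlia, id storj.NodeID) (string, error) {
	node, err := k.FindNode(ctx, id)
	if err != nil {
		if kademlia.NodeNotFound.Has(err) {
			return "", fmt.Errorf("node %s was not found by the lookup", id)
		}
		return "", err
	}
	return node.Address.GetAddress(), nil
}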

View File

@ -118,7 +118,7 @@ func TestPeerDiscovery(t *testing.T) {
},
}
for _, v := range cases {
_, err := k.lookup(ctx, v.target, true)
_, err := k.lookup(ctx, v.target)
assert.Equal(t, v.expectedErr, err)
}
}

View File

@ -18,10 +18,11 @@ import (
type peerDiscovery struct {
log *zap.Logger
dialer *Dialer
self pb.Node
target storj.NodeID
opts discoveryOptions
dialer *Dialer
self *pb.Node
target storj.NodeID
k int
concurrency int
cond sync.Cond
queue discoveryQueue
@ -30,36 +31,35 @@ type peerDiscovery struct {
// ErrMaxRetries is used when a lookup has been retried the max number of times
var ErrMaxRetries = errs.Class("max retries exceeded for id:")
func newPeerDiscovery(log *zap.Logger, self pb.Node, nodes []*pb.Node, dialer *Dialer, target storj.NodeID, opts discoveryOptions) *peerDiscovery {
func newPeerDiscovery(log *zap.Logger, dialer *Dialer, target storj.NodeID, startingNodes []*pb.Node, k, alpha int, self *pb.Node) *peerDiscovery {
discovery := &peerDiscovery{
log: log,
dialer: dialer,
self: self,
target: target,
opts: opts,
cond: sync.Cond{L: &sync.Mutex{}},
queue: *newDiscoveryQueue(opts.concurrency),
log: log,
dialer: dialer,
self: self,
target: target,
k: k,
concurrency: alpha,
cond: sync.Cond{L: &sync.Mutex{}},
queue: *newDiscoveryQueue(target, k),
}
discovery.queue.Insert(target, nodes...)
discovery.queue.Insert(startingNodes...)
return discovery
}
func (lookup *peerDiscovery) Run(ctx context.Context) (target *pb.Node, err error) {
func (lookup *peerDiscovery) Run(ctx context.Context) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
if lookup.queue.Len() == 0 {
return nil, nil // TODO: should we return an error here?
if lookup.queue.Unqueried() == 0 {
return nil, nil
}
// protected by `lookup.cond.L`
working := 0
allDone := false
target = nil
wg := sync.WaitGroup{}
wg.Add(lookup.opts.concurrency)
defer wg.Wait()
wg.Add(lookup.concurrency)
for i := 0; i < lookup.opts.concurrency; i++ {
for i := 0; i < lookup.concurrency; i++ {
go func() {
defer wg.Done()
for {
@ -73,13 +73,7 @@ func (lookup *peerDiscovery) Run(ctx context.Context) (target *pb.Node, err erro
return
}
next = lookup.queue.Closest()
if !lookup.opts.bootstrap && next != nil && next.Id == lookup.target {
allDone = true
target = next
break // closest node is the target and is already in routing table (i.e. no lookup required)
}
next = lookup.queue.ClosestUnqueried()
if next != nil {
working++
@ -89,13 +83,11 @@ func (lookup *peerDiscovery) Run(ctx context.Context) (target *pb.Node, err erro
lookup.cond.Wait()
}
lookup.cond.L.Unlock()
neighbors, err := lookup.dialer.Lookup(ctx, lookup.self, *next, pb.Node{Id: lookup.target})
if err != nil && !isDone(ctx) {
// TODO: reenable retry after fixing logic
// ok := lookup.queue.Reinsert(lookup.target, next, lookup.opts.retries)
ok := false
if !ok {
neighbors, err := lookup.dialer.Lookup(ctx, lookup.self, *next, lookup.target, lookup.k)
if err != nil {
lookup.queue.QueryFailure(next)
if !isDone(ctx) {
lookup.log.Debug("connecting to node failed",
zap.Any("target", lookup.target),
zap.Any("dial-node", next.Id),
@ -103,24 +95,26 @@ func (lookup *peerDiscovery) Run(ctx context.Context) (target *pb.Node, err erro
zap.Error(err),
)
}
} else {
lookup.queue.QuerySuccess(next, neighbors...)
}
lookup.queue.Insert(lookup.target, neighbors...)
lookup.cond.L.Lock()
working--
allDone = allDone || isDone(ctx) || working == 0 && lookup.queue.Len() == 0
allDone = allDone || isDone(ctx) || (working == 0 && lookup.queue.Unqueried() == 0)
lookup.cond.L.Unlock()
lookup.cond.Broadcast()
}
}()
}
wg.Wait()
err = ctx.Err()
if err == context.Canceled {
err = nil
}
return target, err
return lookup.queue.ClosestQueried(), err
}
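
As a standalone illustration of the coordination Run relies on, here is a simplified sketch (not the storj implementation): alpha workers share one sync.Cond, a count of in-flight queries, and an allDone flag, and the lookup finishes only when no worker is busy and nothing unqueried remains. The query callback stands in for the Dialer round-trip.

package lookupsketch // hypothetical package, for the sketch only

import "sync"

// runWorkers is a simplified model of peerDiscovery.Run: the real code pulls
// from a discoveryQueue and dials peers; here query is any function that
// returns new identifiers to visit.
func runWorkers(alpha int, start []string, query func(string) []string) {
	if len(start) == 0 {
		return // mirrors the early return when nothing is unqueried
	}

	cond := sync.Cond{L: &sync.Mutex{}}
	pending := append([]string(nil), start...)
	seen := map[string]bool{}
	for _, s := range start {
		seen[s] = true
	}
	working := 0
	allDone := false

	var wg sync.WaitGroup
	wg.Add(alpha)
	for i := 0; i < alpha; i++ {
		go func() {
			defer wg.Done()
			for {
				var next string
				cond.L.Lock()
				for {
					if allDone {
						cond.L.Unlock()
						return
					}
					if len(pending) > 0 {
						next, pending = pending[0], pending[1:]
						working++
						break
					}
					// Nothing to query yet: wait for another worker to add
					// neighbors or to declare the lookup finished.
					cond.Wait()
				}
				cond.L.Unlock()

				neighbors := query(next) // stands in for the network round-trip

				cond.L.Lock()
				for _, n := range neighbors {
					if !seen[n] {
						seen[n] = true
						pending = append(pending, n)
					}
				}
				working--
				// Done only when nobody is working and the queue is drained.
				allDone = allDone || (working == 0 && len(pending) == 0)
				cond.L.Unlock()
				cond.Broadcast()
			}
		}()
	}
	wg.Wait()
}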
func isDone(ctx context.Context) bool {
@ -132,11 +126,21 @@ func isDone(ctx context.Context) bool {
}
}
type queueState int
const (
stateUnqueried queueState = iota
stateQuerying
stateSuccess
stateFailure
)
// discoveryQueue is a size-limited priority queue of nodes ordered by XOR distance from the target
type discoveryQueue struct {
target storj.NodeID
maxLen int
mu sync.Mutex
added map[storj.NodeID]int
state map[storj.NodeID]queueState
items []queueItem
}
@ -147,57 +151,37 @@ type queueItem struct {
}
// newDiscoveryQueue returns a queue that prioritizes items by XOR distance from target
func newDiscoveryQueue(size int) *discoveryQueue {
func newDiscoveryQueue(target storj.NodeID, size int) *discoveryQueue {
return &discoveryQueue{
added: make(map[storj.NodeID]int),
target: target,
state: make(map[storj.NodeID]queueState),
maxLen: size,
}
}
// Insert adds nodes into the queue.
func (queue *discoveryQueue) Insert(target storj.NodeID, nodes ...*pb.Node) {
func (queue *discoveryQueue) Insert(nodes ...*pb.Node) {
queue.mu.Lock()
defer queue.mu.Unlock()
queue.insert(nodes...)
}
unique := nodes[:0]
// insert requires the mutex to be locked
func (queue *discoveryQueue) insert(nodes ...*pb.Node) {
for _, node := range nodes {
if _, added := queue.added[node.Id]; added {
// TODO: empty node ids should be semantically different from the
// technically valid node id that is all zeros
if node.Id == (storj.NodeID{}) {
continue
}
unique = append(unique, node)
}
queue.insert(target, unique...)
// update counts for the new items that are in the queue
for _, item := range queue.items {
if _, added := queue.added[item.node.Id]; !added {
queue.added[item.node.Id] = 1
if _, added := queue.state[node.Id]; added {
continue
}
}
}
queue.state[node.Id] = stateUnqueried
// Reinsert adds a node into the queue, but only if it has been added fewer than limit times.
func (queue *discoveryQueue) Reinsert(target storj.NodeID, node *pb.Node, limit int) bool {
queue.mu.Lock()
defer queue.mu.Unlock()
nodeID := node.Id
if queue.added[nodeID] >= limit {
return false
}
queue.added[nodeID]++
queue.insert(target, node)
return true
}
// insert must hold lock while adding
func (queue *discoveryQueue) insert(target storj.NodeID, nodes ...*pb.Node) {
for _, node := range nodes {
queue.items = append(queue.items, queueItem{
node: node,
priority: xorNodeID(target, node.Id),
priority: xorNodeID(queue.target, node.Id),
})
}
@ -210,24 +194,62 @@ func (queue *discoveryQueue) insert(target storj.NodeID, nodes ...*pb.Node) {
}
}
// Closest returns the closest item in the queue
func (queue *discoveryQueue) Closest() *pb.Node {
// ClosestUnqueried returns the closest unqueried item in the queue
func (queue *discoveryQueue) ClosestUnqueried() *pb.Node {
queue.mu.Lock()
defer queue.mu.Unlock()
if len(queue.items) == 0 {
return nil
for _, item := range queue.items {
if queue.state[item.node.Id] == stateUnqueried {
queue.state[item.node.Id] = stateQuerying
return item.node
}
}
var item queueItem
item, queue.items = queue.items[0], queue.items[1:]
return item.node
return nil
}
// Len returns the number of items in the queue
func (queue *discoveryQueue) Len() int {
// ClosestQueried returns the closest queried items in the queue
func (queue *discoveryQueue) ClosestQueried() []*pb.Node {
queue.mu.Lock()
defer queue.mu.Unlock()
return len(queue.items)
rv := make([]*pb.Node, 0, len(queue.items))
for _, item := range queue.items {
if queue.state[item.node.Id] == stateSuccess {
rv = append(rv, item.node)
}
}
return rv
}
// QuerySuccess marks the node as successfully queried and adds the results to the queue.
// Incoming nodes with a zero node ID are ignored.
func (queue *discoveryQueue) QuerySuccess(node *pb.Node, nodes ...*pb.Node) {
queue.mu.Lock()
defer queue.mu.Unlock()
queue.state[node.Id] = stateSuccess
queue.insert(nodes...)
}
// QueryFailure marks the node as failing query
func (queue *discoveryQueue) QueryFailure(node *pb.Node) {
queue.mu.Lock()
queue.state[node.Id] = stateFailure
queue.mu.Unlock()
}
// Unqueried returns the number of unqueried items in the queue
func (queue *discoveryQueue) Unqueried() (amount int) {
queue.mu.Lock()
defer queue.mu.Unlock()
for _, item := range queue.items {
if queue.state[item.node.Id] == stateUnqueried {
amount++
}
}
return amount
}
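
To summarize the new queue contract, here is a hypothetical single-threaded restatement of what Run does concurrently, written against the in-package API introduced above; the helper itself is illustrative and not part of this commit.

package kademlia

import (
	"context"

	"storj.io/storj/pkg/pb"
	"storj.io/storj/pkg/storj"
)

// sequentialLookupSketch walks every node through the
// unqueried -> querying -> success/failure state machine, one at a time.
func sequentialLookupSketch(ctx context.Context, dialer *Dialer, target storj.NodeID, start []*pb.Node, k int) []*pb.Node {
	queue := newDiscoveryQueue(target, k)
	queue.Insert(start...)

	for queue.Unqueried() > 0 {
		next := queue.ClosestUnqueried() // moves next to stateQuerying
		neighbors, err := dialer.Lookup(ctx, nil, *next, target, k)
		if err != nil {
			queue.QueryFailure(next) // moves next to stateFailure
			continue
		}
		// Moves next to stateSuccess and enqueues neighbors not seen before.
		queue.QuerySuccess(next, neighbors...)
	}

	// Only nodes that were queried successfully are returned, closest first.
	return queue.ClosestQueried()
}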

View File

@ -46,17 +46,17 @@ func TestDiscoveryQueue(t *testing.T) {
// t.Logf("%08b,%08b -> %08b,%08b", node.Id[0], node.Id[1], xor[0], xor[1])
// }
queue := newDiscoveryQueue(6)
queue.Insert(target, nodes...)
queue := newDiscoveryQueue(target, 6)
queue.Insert(nodes...)
assert.Equal(t, queue.Len(), 6)
assert.Equal(t, queue.Unqueried(), 6)
for i, expect := range expected {
node := queue.Closest()
node := queue.ClosestUnqueried()
assert.Equal(t, node.Id, expect.Id, strconv.Itoa(i))
}
assert.Nil(t, queue.Closest())
assert.Nil(t, queue.ClosestUnqueried())
}
func TestDiscoveryQueueRandom(t *testing.T) {
@ -78,20 +78,20 @@ func TestDiscoveryQueueRandom(t *testing.T) {
initial = append(initial, &pb.Node{Id: nodeID})
}
queue := newDiscoveryQueue(maxLen)
queue.Insert(target, initial...)
queue := newDiscoveryQueue(target, maxLen)
queue.Insert(initial...)
for k := 0; k < 10; k++ {
var nodeID storj.NodeID
_, _ = r.Read(nodeID[:])
queue.Insert(target, &pb.Node{Id: nodeID})
queue.Insert(&pb.Node{Id: nodeID})
}
assert.Equal(t, queue.Len(), maxLen)
assert.Equal(t, queue.Unqueried(), maxLen)
previousPriority := storj.NodeID{}
for queue.Len() > 0 {
next := queue.Closest()
for queue.Unqueried() > 0 {
next := queue.ClosestUnqueried()
priority := xorNodeID(target, next.Id)
// ensure that priority is monotonically increasing
assert.False(t, priority.Less(previousPriority))

View File

@ -57,6 +57,18 @@ func (client *slowTransport) WithObservers(obs ...Observer) Client {
return &slowTransport{client.client.WithObservers(obs...), client.network}
}
// AlertSuccess implements the transport.Client interface
func (client *slowTransport) AlertSuccess(ctx context.Context, node *pb.Node) {
defer mon.Task()(&ctx)(nil)
client.client.AlertSuccess(ctx, node)
}
// AlertFail implements the transport.Client interface
func (client *slowTransport) AlertFail(ctx context.Context, node *pb.Node, err error) {
defer mon.Task()(&ctx)(nil)
client.client.AlertFail(ctx, node, err)
}
// DialOptions returns options such that it will use simulated network parameters
func (network *SimulatedNetwork) DialOptions() []grpc.DialOption {
return []grpc.DialOption{grpc.WithContextDialer(network.GRPCDialContext)}

View File

@ -28,6 +28,8 @@ type Client interface {
DialAddress(ctx context.Context, address string, opts ...grpc.DialOption) (*grpc.ClientConn, error)
Identity() *identity.FullIdentity
WithObservers(obs ...Observer) Client
AlertSuccess(ctx context.Context, node *pb.Node)
AlertFail(ctx context.Context, node *pb.Node, err error)
}
// Timeouts contains all of the timeouts configurable for a transport
@ -101,11 +103,11 @@ func (transport *Transport) DialNode(ctx context.Context, node *pb.Node, opts ..
if err == context.Canceled {
return nil, err
}
alertFail(timedCtx, transport.observers, node, err)
transport.AlertFail(timedCtx, node, err)
return nil, Error.Wrap(err)
}
alertSuccess(timedCtx, transport.observers, node)
transport.AlertSuccess(timedCtx, node)
return conn, nil
}
@ -134,6 +136,8 @@ func (transport *Transport) DialAddress(ctx context.Context, address string, opt
timedCtx, cancel := context.WithTimeout(ctx, transport.timeouts.Dial)
defer cancel()
// TODO: this should also call alertFail or alertSuccess with the node id. We should be able
// to get gRPC to give us the node id after dialing?
conn, err = grpc.DialContext(timedCtx, address, options...)
if err == context.Canceled {
return nil, err
@ -154,14 +158,18 @@ func (transport *Transport) WithObservers(obs ...Observer) Client {
return tr
}
func alertFail(ctx context.Context, obs []Observer, node *pb.Node, err error) {
for _, o := range obs {
// AlertFail alerts any subscribed observers of the failure 'err' for 'node'
func (transport *Transport) AlertFail(ctx context.Context, node *pb.Node, err error) {
defer mon.Task()(&ctx)(nil)
for _, o := range transport.observers {
o.ConnFailure(ctx, node, err)
}
}
func alertSuccess(ctx context.Context, obs []Observer, node *pb.Node) {
for _, o := range obs {
// AlertSuccess alerts any subscribed observers of success for 'node'
func (transport *Transport) AlertSuccess(ctx context.Context, node *pb.Node) {
defer mon.Task()(&ctx)(nil)
for _, o := range transport.observers {
o.ConnSuccess(ctx, node)
}
}
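
A hedged sketch of how the newly exported alert methods might be consumed: an Observer needs only ConnSuccess and ConnFailure, and callers that dial outside DialNode (as Kademlia.Bootstrap now does) can notify observers explicitly once the peer is verified. The package and helper names are assumptions.

package transportutil // hypothetical package, for the sketch only

import (
	"context"
	"log"

	"storj.io/storj/pkg/pb"
	"storj.io/storj/pkg/transport"
)

// loggingObserver is an illustrative Observer implementation.
type loggingObserver struct{}

func (loggingObserver) ConnSuccess(ctx context.Context, node *pb.Node) {
	log.Printf("connection to node %s succeeded", node.Id)
}

func (loggingObserver) ConnFailure(ctx context.Context, node *pb.Node, err error) {
	log.Printf("connection to node %s failed: %v", node.Id, err)
}

// notifyVerifiedPeer mirrors the Bootstrap pattern: when a connection is made
// through a path that cannot report the node ID on its own (e.g. DialAddress),
// the caller alerts observers once the peer identity is known.
func notifyVerifiedPeer(ctx context.Context, client transport.Client, node *pb.Node) {
	client.AlertSuccess(ctx, node)
}

// Registration would look like: client = client.WithObservers(loggingObserver{})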

View File

@ -26,6 +26,9 @@ func TestGetSignee(t *testing.T) {
planet.Start(ctx)
// make sure nodes are refreshed in db
planet.Satellites[0].Discovery.Service.Refresh.TriggerWait()
trust := planet.StorageNodes[0].Storage2.Trust
canceledContext, cancel := context.WithCancel(ctx)