diff --git a/pkg/kademlia/kademlia.go b/pkg/kademlia/kademlia.go index 16a1455a3..9ebb0ba3e 100644 --- a/pkg/kademlia/kademlia.go +++ b/pkg/kademlia/kademlia.go @@ -189,8 +189,16 @@ func (k *Kademlia) lookup(ctx context.Context, target dht.NodeID, opts lookupOpt // Ping checks that the provided node is still accessible on the network func (k *Kademlia) Ping(ctx context.Context, node pb.Node) (pb.Node, error) { - // TODO(coyle) - return pb.Node{}, nil + ok, err := k.nodeClient.Ping(ctx, node) + if err != nil { + return pb.Node{}, NodeErr.Wrap(err) + } + + if !ok { + return pb.Node{}, NodeErr.New("Failed pinging node") + } + + return node, nil } // FindNode looks up the provided NodeID first in the local Node, and if it is not found diff --git a/pkg/node/client.go b/pkg/node/client.go index c1fbb447f..92dc7652a 100644 --- a/pkg/node/client.go +++ b/pkg/node/client.go @@ -10,7 +10,6 @@ import ( "storj.io/storj/pkg/dht" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/provider" - "storj.io/storj/pkg/transport" ) //NodeClientErr is the class for all errors pertaining to node client operations @@ -18,19 +17,20 @@ var NodeClientErr = errs.Class("node client error") // NewNodeClient instantiates a node client func NewNodeClient(identity *provider.FullIdentity, self pb.Node, dht dht.DHT) (Client, error) { - client := transport.NewClient(identity) node := &Node{ - dht: dht, - self: self, - tc: client, - cache: NewConnectionPool(), + dht: dht, + self: self, + pool: NewConnectionPool(identity), } - node.cache.Init() + + node.pool.Init() + return node, nil } // Client is the Node client communication interface type Client interface { Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error) + Ping(ctx context.Context, to pb.Node) (bool, error) Disconnect() error } diff --git a/pkg/node/connection_pool.go b/pkg/node/connection_pool.go index 05de82f80..0ad1dce35 100644 --- a/pkg/node/connection_pool.go +++ b/pkg/node/connection_pool.go @@ -4,10 +4,15 @@ package node 
import ( + "context" "sync" "github.com/zeebo/errs" + "google.golang.org/grpc" + "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/provider" + "storj.io/storj/pkg/transport" "storj.io/storj/pkg/utils" ) @@ -16,21 +21,31 @@ var Error = errs.Class("connection pool error") // ConnectionPool is the in memory pool of node connections type ConnectionPool struct { + tc transport.Client mu sync.RWMutex - cache map[string]interface{} + items map[string]*Conn } +// Conn is the connection that is stored in the connection pool +type Conn struct { + addr string + + dial sync.Once + client pb.NodesClient + grpc *grpc.ClientConn + err error +} + +// NewConn initializes a new Conn struct with the provided address, but does not initiate a connection +func NewConn(addr string) *Conn { return &Conn{addr: addr} } + // NewConnectionPool initializes a new in memory pool -func NewConnectionPool() *ConnectionPool { - return &ConnectionPool{} -} - -// Add takes a node ID as the key and a node client as the value to store -func (pool *ConnectionPool) Add(key string, value interface{}) error { - pool.mu.Lock() - defer pool.mu.Unlock() - pool.cache[key] = value - return nil +func NewConnectionPool(identity *provider.FullIdentity) *ConnectionPool { + return &ConnectionPool{ + tc: transport.NewClient(identity), + items: make(map[string]*Conn), + mu: sync.RWMutex{}, + } } // Get retrieves a node connection with the provided nodeID @@ -38,43 +53,71 @@ func (pool *ConnectionPool) Add(key string, value interface{}) error { func (pool *ConnectionPool) Get(key string) (interface{}, error) { pool.mu.Lock() defer pool.mu.Unlock() - return pool.cache[key], nil + + i, ok := pool.items[key] + if !ok { + return nil, nil + } + + return i, nil } -// Remove deletes a connection associated with the provided NodeID -func (pool *ConnectionPool) Remove(key string) error { +// Disconnect deletes a connection associated with the provided NodeID +func (pool *ConnectionPool) Disconnect(key string) error { pool.mu.Lock() 
defer pool.mu.Unlock() - pool.cache[key] = nil - return nil + + i, ok := pool.items[key] + if !ok { + return nil + } + + delete(pool.items, key) + + return i.grpc.Close() } -// Disconnect closes the connection to the node and removes it from the connection pool -func (pool *ConnectionPool) Disconnect() error { - var err error - var errs []error - for k, v := range pool.cache { - conn, ok := v.(interface{ Close() error }) - if !ok { - err = Error.New("connection pool value not a grpc client connection") - errs = append(errs, err) - continue +// Dial connects to the node with the given ID and Address returning a gRPC Node Client +func (pool *ConnectionPool) Dial(ctx context.Context, n *pb.Node) (pb.NodesClient, error) { + id := n.GetId() + pool.mu.Lock() + conn, ok := pool.items[id] + if !ok { + conn = NewConn(n.GetAddress().Address) + pool.items[id] = conn + } + pool.mu.Unlock() + + conn.dial.Do(func() { + conn.grpc, conn.err = pool.tc.DialNode(ctx, n) + if conn.err != nil { + return } - err = conn.Close() - if err != nil { - errs = append(errs, Error.Wrap(err)) - continue - } - err = pool.Remove(k) - if err != nil { + + conn.client = pb.NewNodesClient(conn.grpc) + }) + + if conn.err != nil { + return nil, conn.err + } + + return conn.client, nil +} + +// DisconnectAll closes all connections to nodes and removes them from the connection pool +func (pool *ConnectionPool) DisconnectAll() error { + errs := []error{} + for k := range pool.items { + if err := pool.Disconnect(k); err != nil { errs = append(errs, Error.Wrap(err)) continue } } + return utils.CombineErrors(errs...) 
} // Init initializes the cache func (pool *ConnectionPool) Init() { - pool.cache = make(map[string]interface{}) + pool.items = make(map[string]*Conn) } diff --git a/pkg/node/connection_pool_test.go b/pkg/node/connection_pool_test.go index e43507e86..eb5406888 100644 --- a/pkg/node/connection_pool_test.go +++ b/pkg/node/connection_pool_test.go @@ -4,32 +4,31 @@ package node import ( + "context" "sync" "testing" "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "storj.io/storj/pkg/pb" ) -type TestFoo struct { - called string -} - func TestGet(t *testing.T) { cases := []struct { pool *ConnectionPool key string - expected TestFoo + expected Conn expectedError error }{ { pool: func() *ConnectionPool { - p := NewConnectionPool() + p := NewConnectionPool(newTestIdentity(t)) p.Init() - assert.NoError(t, p.Add("foo", TestFoo{called: "hoot"})) + p.items["foo"] = &Conn{addr: "foo"} return p }(), key: "foo", - expected: TestFoo{called: "hoot"}, + expected: Conn{addr: "foo"}, expectedError: nil, }, } @@ -38,42 +37,16 @@ func TestGet(t *testing.T) { v := &cases[i] test, err := v.pool.Get(v.key) assert.Equal(t, v.expectedError, err) - assert.Equal(t, v.expected, test) + + assert.Equal(t, v.expected.addr, test.(*Conn).addr) } } -func TestAdd(t *testing.T) { - cases := []struct { - pool ConnectionPool - key string - value TestFoo - expected TestFoo - expectedError error - }{ - { - pool: ConnectionPool{ - mu: sync.RWMutex{}, - cache: map[string]interface{}{}}, - key: "foo", - value: TestFoo{called: "hoot"}, - expected: TestFoo{called: "hoot"}, - expectedError: nil, - }, - } +func TestDisconnect(t *testing.T) { - for i := range cases { - v := &cases[i] - err := v.pool.Add(v.key, v.value) - assert.Equal(t, v.expectedError, err) - - test, err := v.pool.Get(v.key) - assert.Equal(t, v.expectedError, err) - - assert.Equal(t, v.expected, test) - } -} - -func TestRemove(t *testing.T) { + conn, err := grpc.Dial("127.0.0.1:0", grpc.WithInsecure()) + assert.NoError(t, err) + 
// gc.Close = func() error { return nil } cases := []struct { pool ConnectionPool key string @@ -83,7 +56,8 @@ func TestRemove(t *testing.T) { { pool: ConnectionPool{ mu: sync.RWMutex{}, - cache: map[string]interface{}{"foo": TestFoo{called: "hoot"}}}, + items: map[string]*Conn{"foo": &Conn{grpc: conn}}, + }, key: "foo", expected: nil, expectedError: nil, @@ -92,7 +66,7 @@ func TestRemove(t *testing.T) { for i := range cases { v := &cases[i] - err := v.pool.Remove(v.key) + err := v.pool.Disconnect(v.key) assert.Equal(t, v.expectedError, err) test, err := v.pool.Get(v.key) @@ -101,3 +75,38 @@ func TestRemove(t *testing.T) { assert.Equal(t, v.expected, test) } } + +func TestDial(t *testing.T) { + cases := []struct { + pool *ConnectionPool + node *pb.Node + expectedError error + expected *Conn + }{ + { + pool: NewConnectionPool(newTestIdentity(t)), + node: &pb.Node{Id: "foo", Address: &pb.NodeAddress{Address: "127.0.0.1:0"}}, + expected: nil, + expectedError: nil, + }, + } + + for _, v := range cases { + wg := sync.WaitGroup{} + wg.Add(4) + go testDial(t, &wg, v.pool, v.node, v.expectedError) + go testDial(t, &wg, v.pool, v.node, v.expectedError) + go testDial(t, &wg, v.pool, v.node, v.expectedError) + go testDial(t, &wg, v.pool, v.node, v.expectedError) + wg.Wait() + } + +} + +func testDial(t *testing.T, wg *sync.WaitGroup, p *ConnectionPool, n *pb.Node, eerr error) { + defer wg.Done() + ctx := context.Background() + actual, err := p.Dial(ctx, n) + assert.Equal(t, eerr, err) + assert.NotNil(t, actual) +} diff --git a/pkg/node/node.go b/pkg/node/node.go index 3df26bece..7f8d07a36 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -5,64 +5,54 @@ package node import ( "context" - "log" - "google.golang.org/grpc" "storj.io/storj/pkg/dht" "storj.io/storj/pkg/pb" - "storj.io/storj/pkg/transport" ) // Node is the storj definition for a node in the network type Node struct { - dht dht.DHT - self pb.Node - tc transport.Client - cache *ConnectionPool + dht dht.DHT + self 
pb.Node + pool *ConnectionPool } // Lookup queries nodes looking for a particular node in the network func (n *Node) Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error) { - v, err := n.cache.Get(to.GetId()) + c, err := n.pool.Dial(ctx, &to) if err != nil { - return nil, err + return nil, NodeClientErr.Wrap(err) } - var conn *grpc.ClientConn - if c, ok := v.(*grpc.ClientConn); ok { - conn = c - } else { - c, err := n.tc.DialNode(ctx, &to) - if err != nil { - return nil, err - } - - if err := n.cache.Add(to.GetId(), c); err != nil { - log.Printf("Error %s occurred adding %s to cache", err, to.GetId()) - } - conn = c - } - - c := pb.NewNodesClient(conn) resp, err := c.Query(ctx, &pb.QueryRequest{Limit: 20, Sender: &n.self, Target: &find, Pingback: true}) if err != nil { - return nil, err + return nil, NodeClientErr.Wrap(err) } rt, err := n.dht.GetRoutingTable(ctx) if err != nil { - return nil, err + return nil, NodeClientErr.Wrap(err) } if err := rt.ConnectionSuccess(&to); err != nil { - return nil, err + return nil, NodeClientErr.Wrap(err) } return resp.Response, nil } -// Disconnect closes connections within the cache -func (n *Node) Disconnect() error { - return n.cache.Disconnect() +// Ping attempts to establish a connection with a node to verify it is alive +func (n *Node) Ping(ctx context.Context, to pb.Node) (bool, error) { + _, err := n.pool.Dial(ctx, &to) + if err != nil { + return false, NodeClientErr.Wrap(err) + } + + return true, nil +} + +// Disconnect closes all connections within the pool +func (n *Node) Disconnect() error { + return n.pool.DisconnectAll() } diff --git a/pkg/node/node_test.go b/pkg/node/node_test.go index f1f7dc229..d04ff7121 100644 --- a/pkg/node/node_test.go +++ b/pkg/node/node_test.go @@ -40,7 +40,8 @@ func TestLookup(t *testing.T) { v.to = pb.Node{Id: NewNodeID(t), Address: &pb.NodeAddress{Address: lis.Addr().String()}} - srv, mock, err := newTestServer(ctx) + id := newTestIdentity(t) + srv, mock, err := 
newTestServer(ctx, &mockNodeServer{queryCalled: 0}, id) assert.NoError(t, err) go func() { assert.NoError(t, srv.Serve(lis)) }() defer srv.Stop() @@ -62,30 +63,63 @@ func TestLookup(t *testing.T) { _, err = nc.Lookup(ctx, v.to, v.find) assert.Equal(t, v.expectedErr, err) - assert.Equal(t, 1, mock.queryCalled) + assert.Equal(t, 1, mock.(*mockNodeServer).queryCalled) } } -func newTestServer(ctx context.Context) (*grpc.Server, *mockNodeServer, error) { - ca, err := provider.NewCA(ctx, 12, 4) - if err != nil { - return nil, nil, err +func TestPing(t *testing.T) { + ctx := context.Background() + cases := []struct { + self pb.Node + toID string + toIdentity *provider.FullIdentity + expectedErr error + }{ + { + self: pb.Node{Id: "hello", Address: &pb.NodeAddress{Address: ":7070"}}, + toID: "", + toIdentity: newTestIdentity(t), + expectedErr: nil, + }, } - identity, err := ca.NewIdentity() - if err != nil { - return nil, nil, err + + for _, v := range cases { + lis, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + // new mock DHT for node client + ctrl := gomock.NewController(t) + mdht := mock_dht.NewMockDHT(ctrl) + // set up a node server + srv := NewServer(mdht) + + msrv, _, err := newTestServer(ctx, srv, v.toIdentity) + assert.NoError(t, err) + // start gRPC server + + go func() { assert.NoError(t, msrv.Serve(lis)) }() + defer msrv.Stop() + + nc, err := NewNodeClient(v.toIdentity, v.self, mdht) + assert.NoError(t, err) + + id := ID(v.toIdentity.ID) + ok, err := nc.Ping(ctx, pb.Node{Id: id.String(), Address: &pb.NodeAddress{Address: lis.Addr().String()}}) + assert.Equal(t, v.expectedErr, err) + assert.Equal(t, ok, true) } +} + +func newTestServer(ctx context.Context, ns pb.NodesServer, identity *provider.FullIdentity) (*grpc.Server, pb.NodesServer, error) { identOpt, err := identity.ServerOption() if err != nil { return nil, nil, err } grpcServer := grpc.NewServer(identOpt) - mn := &mockNodeServer{queryCalled: 0} - pb.RegisterNodesServer(grpcServer, mn) 
+ pb.RegisterNodesServer(grpcServer, ns) - return grpcServer, mn, nil + return grpcServer, ns, nil } @@ -106,3 +140,12 @@ func NewNodeID(t *testing.T) string { return id.String() } + +func newTestIdentity(t *testing.T) *provider.FullIdentity { + ca, err := provider.NewCA(ctx, 12, 4) + assert.NoError(t, err) + identity, err := ca.NewIdentity() + assert.NoError(t, err) + + return identity +} diff --git a/pkg/provider/identity.go b/pkg/provider/identity.go index 22ea45ba3..eeae37f0c 100644 --- a/pkg/provider/identity.go +++ b/pkg/provider/identity.go @@ -301,6 +301,7 @@ func (fi *FullIdentity) ServerOption(pcvFuncs ...peertls.PeerCertVerificationFun // DialOption returns a grpc `DialOption` for making outgoing connections // to the node with this peer identity func (fi *FullIdentity) DialOption() (grpc.DialOption, error) { + // TODO(coyle): add ID ch := [][]byte{fi.Leaf.Raw, fi.CA.Raw} c, err := peertls.TLSCert(ch, fi.Leaf, fi.Key) if err != nil { @@ -312,6 +313,10 @@ func (fi *FullIdentity) DialOption() (grpc.DialOption, error) { InsecureSkipVerify: true, VerifyPeerCertificate: peertls.VerifyPeerFunc( peertls.VerifyPeerCertChains, + func(_ [][]byte, parsedChains [][]*x509.Certificate) error { + return nil + }, + // TODO(coyle): Check that the ID of the node we are dialing is the owner of the certificate. ), } diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index 2b590185f..964d8b01d 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -29,7 +29,7 @@ func (o *Transport) DialNode(ctx context.Context, node *pb.Node) (conn *grpc.Cli if node.Address == nil || node.Address.Address == "" { return nil, Error.New("no address") } - + // TODO(coyle): pass ID dialOpt, err := o.identity.DialOption() if err != nil { return nil, err