Added ping support to node client (#491)

* added ping support to node client

* Added tests to Ping

* Added connection creation responsibility to Connection Pool
This commit is contained in:
Dennis Coyle 2018-10-26 12:38:22 -04:00 committed by GitHub
parent 79e1ed5d1f
commit 3b7b2afb1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 227 additions and 129 deletions

View File

@ -189,8 +189,16 @@ func (k *Kademlia) lookup(ctx context.Context, target dht.NodeID, opts lookupOpt
// Ping checks that the provided node is still accessible on the network
func (k *Kademlia) Ping(ctx context.Context, node pb.Node) (pb.Node, error) {
// TODO(coyle)
return pb.Node{}, nil
ok, err := k.nodeClient.Ping(ctx, node)
if err != nil {
return pb.Node{}, NodeErr.Wrap(err)
}
if !ok {
return pb.Node{}, NodeErr.New("Failed pinging node")
}
return node, nil
}
// FindNode looks up the provided NodeID first in the local Node, and if it is not found

View File

@ -10,7 +10,6 @@ import (
"storj.io/storj/pkg/dht"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/provider"
"storj.io/storj/pkg/transport"
)
//NodeClientErr is the class for all errors pertaining to node client operations
@ -18,19 +17,20 @@ var NodeClientErr = errs.Class("node client error")
// NewNodeClient instantiates a node client
func NewNodeClient(identity *provider.FullIdentity, self pb.Node, dht dht.DHT) (Client, error) {
client := transport.NewClient(identity)
node := &Node{
dht: dht,
self: self,
tc: client,
cache: NewConnectionPool(),
pool: NewConnectionPool(identity),
}
node.cache.Init()
node.pool.Init()
return node, nil
}
// Client is the Node client communication interface
type Client interface {
Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error)
Ping(ctx context.Context, to pb.Node) (bool, error)
Disconnect() error
}

View File

@ -4,10 +4,15 @@
package node
import (
"context"
"sync"
"github.com/zeebo/errs"
"google.golang.org/grpc"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/provider"
"storj.io/storj/pkg/transport"
"storj.io/storj/pkg/utils"
)
@ -16,21 +21,31 @@ var Error = errs.Class("connection pool error")
// ConnectionPool is the in memory pool of node connections
type ConnectionPool struct {
tc transport.Client
mu sync.RWMutex
cache map[string]interface{}
items map[string]*Conn
}
// Conn is the connection that is stored in the connection pool
type Conn struct {
addr string
dial sync.Once
client pb.NodesClient
grpc *grpc.ClientConn
err error
}
// NewConn intitalizes a new Conn struct with the provided address, but does not iniate a connection
func NewConn(addr string) *Conn { return &Conn{addr: addr} }
// NewConnectionPool initializes a new in memory pool
func NewConnectionPool() *ConnectionPool {
return &ConnectionPool{}
}
// Add takes a node ID as the key and a node client as the value to store
func (pool *ConnectionPool) Add(key string, value interface{}) error {
pool.mu.Lock()
defer pool.mu.Unlock()
pool.cache[key] = value
return nil
func NewConnectionPool(identity *provider.FullIdentity) *ConnectionPool {
return &ConnectionPool{
tc: transport.NewClient(identity),
items: make(map[string]*Conn),
mu: sync.RWMutex{},
}
}
// Get retrieves a node connection with the provided nodeID
@ -38,43 +53,71 @@ func (pool *ConnectionPool) Add(key string, value interface{}) error {
func (pool *ConnectionPool) Get(key string) (interface{}, error) {
pool.mu.Lock()
defer pool.mu.Unlock()
return pool.cache[key], nil
i, ok := pool.items[key]
if !ok {
return nil, nil
}
return i, nil
}
// Remove deletes a connection associated with the provided NodeID
func (pool *ConnectionPool) Remove(key string) error {
// Disconnect deletes a connection associated with the provided NodeID
func (pool *ConnectionPool) Disconnect(key string) error {
pool.mu.Lock()
defer pool.mu.Unlock()
pool.cache[key] = nil
i, ok := pool.items[key]
if !ok {
return nil
}
delete(pool.items, key)
return i.grpc.Close()
}
// Disconnect closes the connection to the node and removes it from the connection pool
func (pool *ConnectionPool) Disconnect() error {
var err error
var errs []error
for k, v := range pool.cache {
conn, ok := v.(interface{ Close() error })
// Dial connects to the node with the given ID and Address returning a gRPC Node Client
func (pool *ConnectionPool) Dial(ctx context.Context, n *pb.Node) (pb.NodesClient, error) {
id := n.GetId()
pool.mu.Lock()
conn, ok := pool.items[id]
if !ok {
err = Error.New("connection pool value not a grpc client connection")
errs = append(errs, err)
continue
conn = NewConn(n.GetAddress().Address)
pool.items[id] = conn
}
err = conn.Close()
if err != nil {
errs = append(errs, Error.Wrap(err))
continue
pool.mu.Unlock()
conn.dial.Do(func() {
conn.grpc, conn.err = pool.tc.DialNode(ctx, n)
if conn.err != nil {
return
}
err = pool.Remove(k)
if err != nil {
conn.client = pb.NewNodesClient(conn.grpc)
})
if conn.err != nil {
return nil, conn.err
}
return conn.client, nil
}
// DisconnectAll closes all connections nodes and removes them from the connection pool
func (pool *ConnectionPool) DisconnectAll() error {
errs := []error{}
for k := range pool.items {
if err := pool.Disconnect(k); err != nil {
errs = append(errs, Error.Wrap(err))
continue
}
}
return utils.CombineErrors(errs...)
}
// Init initializes the cache
func (pool *ConnectionPool) Init() {
pool.cache = make(map[string]interface{})
pool.items = make(map[string]*Conn)
}

View File

@ -4,32 +4,31 @@
package node
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"storj.io/storj/pkg/pb"
)
type TestFoo struct {
called string
}
func TestGet(t *testing.T) {
cases := []struct {
pool *ConnectionPool
key string
expected TestFoo
expected Conn
expectedError error
}{
{
pool: func() *ConnectionPool {
p := NewConnectionPool()
p := NewConnectionPool(newTestIdentity(t))
p.Init()
assert.NoError(t, p.Add("foo", TestFoo{called: "hoot"}))
p.items["foo"] = &Conn{addr: "foo"}
return p
}(),
key: "foo",
expected: TestFoo{called: "hoot"},
expected: Conn{addr: "foo"},
expectedError: nil,
},
}
@ -38,42 +37,16 @@ func TestGet(t *testing.T) {
v := &cases[i]
test, err := v.pool.Get(v.key)
assert.Equal(t, v.expectedError, err)
assert.Equal(t, v.expected, test)
assert.Equal(t, v.expected.addr, test.(*Conn).addr)
}
}
func TestAdd(t *testing.T) {
cases := []struct {
pool ConnectionPool
key string
value TestFoo
expected TestFoo
expectedError error
}{
{
pool: ConnectionPool{
mu: sync.RWMutex{},
cache: map[string]interface{}{}},
key: "foo",
value: TestFoo{called: "hoot"},
expected: TestFoo{called: "hoot"},
expectedError: nil,
},
}
func TestDisconnect(t *testing.T) {
for i := range cases {
v := &cases[i]
err := v.pool.Add(v.key, v.value)
assert.Equal(t, v.expectedError, err)
test, err := v.pool.Get(v.key)
assert.Equal(t, v.expectedError, err)
assert.Equal(t, v.expected, test)
}
}
func TestRemove(t *testing.T) {
conn, err := grpc.Dial("127.0.0.1:0", grpc.WithInsecure())
assert.NoError(t, err)
// gc.Close = func() error { return nil }
cases := []struct {
pool ConnectionPool
key string
@ -83,7 +56,8 @@ func TestRemove(t *testing.T) {
{
pool: ConnectionPool{
mu: sync.RWMutex{},
cache: map[string]interface{}{"foo": TestFoo{called: "hoot"}}},
items: map[string]*Conn{"foo": &Conn{grpc: conn}},
},
key: "foo",
expected: nil,
expectedError: nil,
@ -92,7 +66,7 @@ func TestRemove(t *testing.T) {
for i := range cases {
v := &cases[i]
err := v.pool.Remove(v.key)
err := v.pool.Disconnect(v.key)
assert.Equal(t, v.expectedError, err)
test, err := v.pool.Get(v.key)
@ -101,3 +75,38 @@ func TestRemove(t *testing.T) {
assert.Equal(t, v.expected, test)
}
}
func TestDial(t *testing.T) {
cases := []struct {
pool *ConnectionPool
node *pb.Node
expectedError error
expected *Conn
}{
{
pool: NewConnectionPool(newTestIdentity(t)),
node: &pb.Node{Id: "foo", Address: &pb.NodeAddress{Address: "127.0.0.1:0"}},
expected: nil,
expectedError: nil,
},
}
for _, v := range cases {
wg := sync.WaitGroup{}
wg.Add(4)
go testDial(t, &wg, v.pool, v.node, v.expectedError)
go testDial(t, &wg, v.pool, v.node, v.expectedError)
go testDial(t, &wg, v.pool, v.node, v.expectedError)
go testDial(t, &wg, v.pool, v.node, v.expectedError)
wg.Wait()
}
}
func testDial(t *testing.T, wg *sync.WaitGroup, p *ConnectionPool, n *pb.Node, eerr error) {
defer wg.Done()
ctx := context.Background()
actual, err := p.Dial(ctx, n)
assert.Equal(t, eerr, err)
assert.NotNil(t, actual)
}

View File

@ -5,64 +5,54 @@ package node
import (
"context"
"log"
"google.golang.org/grpc"
"storj.io/storj/pkg/dht"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/transport"
)
// Node is the storj definition for a node in the network
type Node struct {
dht dht.DHT
self pb.Node
tc transport.Client
cache *ConnectionPool
pool *ConnectionPool
}
// Lookup queries nodes looking for a particular node in the network
func (n *Node) Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error) {
v, err := n.cache.Get(to.GetId())
c, err := n.pool.Dial(ctx, &to)
if err != nil {
return nil, err
return nil, NodeClientErr.Wrap(err)
}
var conn *grpc.ClientConn
if c, ok := v.(*grpc.ClientConn); ok {
conn = c
} else {
c, err := n.tc.DialNode(ctx, &to)
if err != nil {
return nil, err
}
if err := n.cache.Add(to.GetId(), c); err != nil {
log.Printf("Error %s occurred adding %s to cache", err, to.GetId())
}
conn = c
}
c := pb.NewNodesClient(conn)
resp, err := c.Query(ctx, &pb.QueryRequest{Limit: 20, Sender: &n.self, Target: &find, Pingback: true})
if err != nil {
return nil, err
return nil, NodeClientErr.Wrap(err)
}
rt, err := n.dht.GetRoutingTable(ctx)
if err != nil {
return nil, err
return nil, NodeClientErr.Wrap(err)
}
if err := rt.ConnectionSuccess(&to); err != nil {
return nil, err
return nil, NodeClientErr.Wrap(err)
}
return resp.Response, nil
}
// Disconnect closes connections within the cache
func (n *Node) Disconnect() error {
return n.cache.Disconnect()
// Ping attempts to establish a connection with a node to verify it is alive
func (n *Node) Ping(ctx context.Context, to pb.Node) (bool, error) {
_, err := n.pool.Dial(ctx, &to)
if err != nil {
return false, NodeClientErr.Wrap(err)
}
return true, nil
}
// Disconnect closes all connections within the pool
func (n *Node) Disconnect() error {
return n.pool.DisconnectAll()
}

View File

@ -40,7 +40,8 @@ func TestLookup(t *testing.T) {
v.to = pb.Node{Id: NewNodeID(t), Address: &pb.NodeAddress{Address: lis.Addr().String()}}
srv, mock, err := newTestServer(ctx)
id := newTestIdentity(t)
srv, mock, err := newTestServer(ctx, &mockNodeServer{queryCalled: 0}, id)
assert.NoError(t, err)
go func() { assert.NoError(t, srv.Serve(lis)) }()
defer srv.Stop()
@ -62,30 +63,63 @@ func TestLookup(t *testing.T) {
_, err = nc.Lookup(ctx, v.to, v.find)
assert.Equal(t, v.expectedErr, err)
assert.Equal(t, 1, mock.queryCalled)
assert.Equal(t, 1, mock.(*mockNodeServer).queryCalled)
}
}
func newTestServer(ctx context.Context) (*grpc.Server, *mockNodeServer, error) {
ca, err := provider.NewCA(ctx, 12, 4)
if err != nil {
return nil, nil, err
func TestPing(t *testing.T) {
ctx := context.Background()
cases := []struct {
self pb.Node
toID string
toIdentity *provider.FullIdentity
expectedErr error
}{
{
self: pb.Node{Id: "hello", Address: &pb.NodeAddress{Address: ":7070"}},
toID: "",
toIdentity: newTestIdentity(t),
expectedErr: nil,
},
}
identity, err := ca.NewIdentity()
if err != nil {
return nil, nil, err
for _, v := range cases {
lis, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
// new mock DHT for node client
ctrl := gomock.NewController(t)
mdht := mock_dht.NewMockDHT(ctrl)
// set up a node server
srv := NewServer(mdht)
msrv, _, err := newTestServer(ctx, srv, v.toIdentity)
assert.NoError(t, err)
// start gRPC server
go func() { assert.NoError(t, msrv.Serve(lis)) }()
defer msrv.Stop()
nc, err := NewNodeClient(v.toIdentity, v.self, mdht)
assert.NoError(t, err)
id := ID(v.toIdentity.ID)
ok, err := nc.Ping(ctx, pb.Node{Id: id.String(), Address: &pb.NodeAddress{Address: lis.Addr().String()}})
assert.Equal(t, v.expectedErr, err)
assert.Equal(t, ok, true)
}
}
func newTestServer(ctx context.Context, ns pb.NodesServer, identity *provider.FullIdentity) (*grpc.Server, pb.NodesServer, error) {
identOpt, err := identity.ServerOption()
if err != nil {
return nil, nil, err
}
grpcServer := grpc.NewServer(identOpt)
mn := &mockNodeServer{queryCalled: 0}
pb.RegisterNodesServer(grpcServer, mn)
pb.RegisterNodesServer(grpcServer, ns)
return grpcServer, mn, nil
return grpcServer, ns, nil
}
@ -106,3 +140,12 @@ func NewNodeID(t *testing.T) string {
return id.String()
}
func newTestIdentity(t *testing.T) *provider.FullIdentity {
ca, err := provider.NewCA(ctx, 12, 4)
assert.NoError(t, err)
identity, err := ca.NewIdentity()
assert.NoError(t, err)
return identity
}

View File

@ -301,6 +301,7 @@ func (fi *FullIdentity) ServerOption(pcvFuncs ...peertls.PeerCertVerificationFun
// DialOption returns a grpc `DialOption` for making outgoing connections
// to the node with this peer identity
func (fi *FullIdentity) DialOption() (grpc.DialOption, error) {
// TODO(coyle): add ID
ch := [][]byte{fi.Leaf.Raw, fi.CA.Raw}
c, err := peertls.TLSCert(ch, fi.Leaf, fi.Key)
if err != nil {
@ -312,6 +313,10 @@ func (fi *FullIdentity) DialOption() (grpc.DialOption, error) {
InsecureSkipVerify: true,
VerifyPeerCertificate: peertls.VerifyPeerFunc(
peertls.VerifyPeerCertChains,
func(_ [][]byte, parsedChains [][]*x509.Certificate) error {
return nil
},
// TODO(coyle): Check that the ID of the node we are dialing is the owner of the certificate.
),
}

View File

@ -29,7 +29,7 @@ func (o *Transport) DialNode(ctx context.Context, node *pb.Node) (conn *grpc.Cli
if node.Address == nil || node.Address.Address == "" {
return nil, Error.New("no address")
}
// TODO(coyle): pass ID
dialOpt, err := o.identity.DialOption()
if err != nil {
return nil, err