Added ping support to node client (#491)
* added ping support to node client
* Added tests to Ping
* Added connection creation responsibility to Connection Pool
parent 79e1ed5d1f
commit 3b7b2afb1f
@@ -189,8 +189,16 @@ func (k *Kademlia) lookup(ctx context.Context, target dht.NodeID, opts lookupOpt
 
 // Ping checks that the provided node is still accessible on the network
 func (k *Kademlia) Ping(ctx context.Context, node pb.Node) (pb.Node, error) {
-	// TODO(coyle)
-	return pb.Node{}, nil
+	ok, err := k.nodeClient.Ping(ctx, node)
+	if err != nil {
+		return pb.Node{}, NodeErr.Wrap(err)
+	}
+
+	if !ok {
+		return pb.Node{}, NodeErr.New("Failed pinging node")
+	}
+
+	return node, nil
 }
 
 // FindNode looks up the provided NodeID first in the local Node, and if it is not found
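The new Ping distinguishes two failure modes: a transport-level error is wrapped with NodeErr.Wrap, while a node that answers but reports failure produces a fresh NodeErr.New error. As a minimal standalone sketch of that zeebo/errs pattern (the class name and messages here are illustrative, not taken from the commit):

package main

import (
	"errors"
	"fmt"

	"github.com/zeebo/errs"
)

// NodeErr stands in for the error class used by the kademlia package;
// the class name is illustrative only.
var NodeErr = errs.Class("node error")

func main() {
	// Wrap keeps the underlying transport error but tags it with the class.
	fmt.Println(NodeErr.Wrap(errors.New("connection refused")))

	// New creates a class error directly, matching the "ping returned ok=false" case.
	fmt.Println(NodeErr.New("Failed pinging node"))
}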
@@ -10,7 +10,6 @@ import (
 	"storj.io/storj/pkg/dht"
 	"storj.io/storj/pkg/pb"
 	"storj.io/storj/pkg/provider"
-	"storj.io/storj/pkg/transport"
 )
 
 //NodeClientErr is the class for all errors pertaining to node client operations
@@ -18,19 +17,20 @@ var NodeClientErr = errs.Class("node client error")
 
 // NewNodeClient instantiates a node client
 func NewNodeClient(identity *provider.FullIdentity, self pb.Node, dht dht.DHT) (Client, error) {
-	client := transport.NewClient(identity)
 	node := &Node{
 		dht:  dht,
 		self: self,
-		tc:    client,
-		cache: NewConnectionPool(),
+		pool: NewConnectionPool(identity),
 	}
-	node.cache.Init()
+
+	node.pool.Init()
 
 	return node, nil
 }
 
 // Client is the Node client communication interface
 type Client interface {
 	Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error)
+	Ping(ctx context.Context, to pb.Node) (bool, error)
 	Disconnect() error
 }
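With Ping added to the Client interface above, callers can health-check peers without touching the pool or the transport directly. A hypothetical caller-side sketch follows; the node and pinger types are local stand-ins for pb.Node and the Ping portion of the interface, and the helper names are illustrative, not part of the commit:

package main

import (
	"context"
	"fmt"
	"sync"
)

// node is a stand-in for pb.Node; pinger mirrors only the Ping method of Client.
type node struct {
	ID      string
	Address string
}

type pinger interface {
	Ping(ctx context.Context, to node) (bool, error)
}

// reachable fans out Ping calls and returns the subset of nodes that answered.
func reachable(ctx context.Context, c pinger, nodes []node) []node {
	var (
		mu    sync.Mutex
		alive []node
		wg    sync.WaitGroup
	)
	for _, n := range nodes {
		n := n
		wg.Add(1)
		go func() {
			defer wg.Done()
			ok, err := c.Ping(ctx, n)
			if err != nil || !ok {
				return
			}
			mu.Lock()
			alive = append(alive, n)
			mu.Unlock()
		}()
	}
	wg.Wait()
	return alive
}

// fakeClient reports every node as reachable so the sketch runs standalone.
type fakeClient struct{}

func (fakeClient) Ping(ctx context.Context, to node) (bool, error) { return true, nil }

func main() {
	nodes := []node{{ID: "a", Address: "127.0.0.1:7777"}, {ID: "b", Address: "127.0.0.1:7778"}}
	fmt.Println(reachable(context.Background(), fakeClient{}, nodes))
}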
@@ -4,10 +4,15 @@
 package node
 
 import (
+	"context"
 	"sync"
 
 	"github.com/zeebo/errs"
+	"google.golang.org/grpc"
 
+	"storj.io/storj/pkg/pb"
+	"storj.io/storj/pkg/provider"
+	"storj.io/storj/pkg/transport"
 	"storj.io/storj/pkg/utils"
 )
@@ -16,21 +21,31 @@ var Error = errs.Class("connection pool error")
 
 // ConnectionPool is the in memory pool of node connections
 type ConnectionPool struct {
+	tc    transport.Client
 	mu    sync.RWMutex
-	cache map[string]interface{}
+	items map[string]*Conn
 }
 
+// Conn is the connection that is stored in the connection pool
+type Conn struct {
+	addr string
+
+	dial   sync.Once
+	client pb.NodesClient
+	grpc   *grpc.ClientConn
+	err    error
+}
+
+// NewConn intitalizes a new Conn struct with the provided address, but does not iniate a connection
+func NewConn(addr string) *Conn { return &Conn{addr: addr} }
+
 // NewConnectionPool initializes a new in memory pool
-func NewConnectionPool() *ConnectionPool {
-	return &ConnectionPool{}
-}
-
-// Add takes a node ID as the key and a node client as the value to store
-func (pool *ConnectionPool) Add(key string, value interface{}) error {
-	pool.mu.Lock()
-	defer pool.mu.Unlock()
-	pool.cache[key] = value
-	return nil
+func NewConnectionPool(identity *provider.FullIdentity) *ConnectionPool {
+	return &ConnectionPool{
+		tc:    transport.NewClient(identity),
+		items: make(map[string]*Conn),
+		mu:    sync.RWMutex{},
+	}
 }
 
 // Get retrieves a node connection with the provided nodeID
@@ -38,43 +53,71 @@ func (pool *ConnectionPool) Add(key string, value interface{}) error {
 func (pool *ConnectionPool) Get(key string) (interface{}, error) {
 	pool.mu.Lock()
 	defer pool.mu.Unlock()
-	return pool.cache[key], nil
+
+	i, ok := pool.items[key]
+	if !ok {
+		return nil, nil
+	}
+
+	return i, nil
 }
 
-// Remove deletes a connection associated with the provided NodeID
-func (pool *ConnectionPool) Remove(key string) error {
+// Disconnect deletes a connection associated with the provided NodeID
+func (pool *ConnectionPool) Disconnect(key string) error {
 	pool.mu.Lock()
 	defer pool.mu.Unlock()
-	pool.cache[key] = nil
+
+	i, ok := pool.items[key]
+	if !ok {
+		return nil
+	}
+
+	delete(pool.items, key)
+
+	return i.grpc.Close()
 }
 
-// Disconnect closes the connection to the node and removes it from the connection pool
-func (pool *ConnectionPool) Disconnect() error {
-	var err error
-	var errs []error
-	for k, v := range pool.cache {
-		conn, ok := v.(interface{ Close() error })
+// Dial connects to the node with the given ID and Address returning a gRPC Node Client
+func (pool *ConnectionPool) Dial(ctx context.Context, n *pb.Node) (pb.NodesClient, error) {
+	id := n.GetId()
+	pool.mu.Lock()
+	conn, ok := pool.items[id]
 	if !ok {
-		err = Error.New("connection pool value not a grpc client connection")
-		errs = append(errs, err)
-		continue
+		conn = NewConn(n.GetAddress().Address)
+		pool.items[id] = conn
 	}
-	err = conn.Close()
-	if err != nil {
-		errs = append(errs, Error.Wrap(err))
-		continue
+	pool.mu.Unlock()
+
+	conn.dial.Do(func() {
+		conn.grpc, conn.err = pool.tc.DialNode(ctx, n)
+		if conn.err != nil {
+			return
 		}
-		err = pool.Remove(k)
-		if err != nil {
+
+		conn.client = pb.NewNodesClient(conn.grpc)
+	})
+
+	if conn.err != nil {
+		return nil, conn.err
+	}
+
+	return conn.client, nil
 }
 
+// DisconnectAll closes all connections nodes and removes them from the connection pool
+func (pool *ConnectionPool) DisconnectAll() error {
+	errs := []error{}
+	for k := range pool.items {
+		if err := pool.Disconnect(k); err != nil {
+			errs = append(errs, Error.Wrap(err))
+			continue
+		}
+	}
 
 	return utils.CombineErrors(errs...)
 }
 
 // Init initializes the cache
 func (pool *ConnectionPool) Init() {
-	pool.cache = make(map[string]interface{})
+	pool.items = make(map[string]*Conn)
 }
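Dial registers a Conn under the node ID while holding the mutex, then performs the actual dial inside conn.dial.Do, so concurrent callers asking for the same node share a single underlying connection; this is the behaviour TestDial exercises below with four goroutines. The following standalone sketch shows the same dial-once pattern with the gRPC dial replaced by a counter so the de-duplication is observable; all names and the fake dial are illustrative, not part of the commit:

package main

import (
	"fmt"
	"sync"
)

// conn and pool mirror the shape of Conn and ConnectionPool above, with the
// real dial swapped for a counter.
type conn struct {
	addr string
	dial sync.Once
	val  int
	err  error
}

type pool struct {
	mu    sync.Mutex
	items map[string]*conn
	dials int
}

func (p *pool) Dial(id, addr string) (int, error) {
	// Register (or find) the entry under the lock, but do the expensive
	// dial outside of it, exactly once per entry, via sync.Once.
	p.mu.Lock()
	c, ok := p.items[id]
	if !ok {
		c = &conn{addr: addr}
		p.items[id] = c
	}
	p.mu.Unlock()

	c.dial.Do(func() {
		p.dials++ // stands in for the transport dial
		c.val = p.dials
	})
	return c.val, c.err
}

func main() {
	p := &pool{items: map[string]*conn{}}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			v, _ := p.Dial("foo", "127.0.0.1:0")
			fmt.Println("got connection", v)
		}()
	}
	wg.Wait()
	fmt.Println("dials performed:", p.dials) // 1, despite 4 concurrent callers
}

Note that the pattern also caches a failed dial: once the Once has run, the stored error is returned to every later caller until the entry is removed (via Disconnect in the pool above), so a retry requires evicting the entry first.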
@@ -4,32 +4,31 @@
 package node
 
 import (
+	"context"
 	"sync"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 	"google.golang.org/grpc"
+	"storj.io/storj/pkg/pb"
 )
 
-type TestFoo struct {
-	called string
-}
-
 func TestGet(t *testing.T) {
 	cases := []struct {
 		pool          *ConnectionPool
 		key           string
-		expected      TestFoo
+		expected      Conn
 		expectedError error
 	}{
 		{
 			pool: func() *ConnectionPool {
-				p := NewConnectionPool()
+				p := NewConnectionPool(newTestIdentity(t))
 				p.Init()
-				assert.NoError(t, p.Add("foo", TestFoo{called: "hoot"}))
+				p.items["foo"] = &Conn{addr: "foo"}
 				return p
 			}(),
 			key:           "foo",
-			expected:      TestFoo{called: "hoot"},
+			expected:      Conn{addr: "foo"},
 			expectedError: nil,
 		},
 	}
@@ -38,42 +37,16 @@ func TestGet(t *testing.T) {
 		v := &cases[i]
 		test, err := v.pool.Get(v.key)
 		assert.Equal(t, v.expectedError, err)
-		assert.Equal(t, v.expected, test)
+
+		assert.Equal(t, v.expected.addr, test.(*Conn).addr)
 	}
 }
 
-func TestAdd(t *testing.T) {
-	cases := []struct {
-		pool          ConnectionPool
-		key           string
-		value         TestFoo
-		expected      TestFoo
-		expectedError error
-	}{
-		{
-			pool: ConnectionPool{
-				mu:    sync.RWMutex{},
-				cache: map[string]interface{}{}},
-			key:           "foo",
-			value:         TestFoo{called: "hoot"},
-			expected:      TestFoo{called: "hoot"},
-			expectedError: nil,
-		},
-	}
+func TestDisconnect(t *testing.T) {
 
-	for i := range cases {
-		v := &cases[i]
-		err := v.pool.Add(v.key, v.value)
-		assert.Equal(t, v.expectedError, err)
-
-		test, err := v.pool.Get(v.key)
-		assert.Equal(t, v.expectedError, err)
-
-		assert.Equal(t, v.expected, test)
-	}
-}
-
-func TestRemove(t *testing.T) {
+	conn, err := grpc.Dial("127.0.0.1:0", grpc.WithInsecure())
+	assert.NoError(t, err)
+	// gc.Close = func() error { return nil }
 	cases := []struct {
 		pool ConnectionPool
 		key  string
@@ -83,7 +56,8 @@ func TestRemove(t *testing.T) {
 		{
 			pool: ConnectionPool{
 				mu:    sync.RWMutex{},
-				cache: map[string]interface{}{"foo": TestFoo{called: "hoot"}}},
+				items: map[string]*Conn{"foo": &Conn{grpc: conn}},
+			},
 			key:           "foo",
 			expected:      nil,
 			expectedError: nil,
@@ -92,7 +66,7 @@ func TestRemove(t *testing.T) {
 
 	for i := range cases {
 		v := &cases[i]
-		err := v.pool.Remove(v.key)
+		err := v.pool.Disconnect(v.key)
 		assert.Equal(t, v.expectedError, err)
 
 		test, err := v.pool.Get(v.key)
@@ -101,3 +75,38 @@ func TestRemove(t *testing.T) {
 		assert.Equal(t, v.expected, test)
 	}
 }
+
+func TestDial(t *testing.T) {
+	cases := []struct {
+		pool          *ConnectionPool
+		node          *pb.Node
+		expectedError error
+		expected      *Conn
+	}{
+		{
+			pool:          NewConnectionPool(newTestIdentity(t)),
+			node:          &pb.Node{Id: "foo", Address: &pb.NodeAddress{Address: "127.0.0.1:0"}},
+			expected:      nil,
+			expectedError: nil,
+		},
+	}
+
+	for _, v := range cases {
+		wg := sync.WaitGroup{}
+		wg.Add(4)
+		go testDial(t, &wg, v.pool, v.node, v.expectedError)
+		go testDial(t, &wg, v.pool, v.node, v.expectedError)
+		go testDial(t, &wg, v.pool, v.node, v.expectedError)
+		go testDial(t, &wg, v.pool, v.node, v.expectedError)
+		wg.Wait()
+	}
+
+}
+
+func testDial(t *testing.T, wg *sync.WaitGroup, p *ConnectionPool, n *pb.Node, eerr error) {
+	defer wg.Done()
+	ctx := context.Background()
+	actual, err := p.Dial(ctx, n)
+	assert.Equal(t, eerr, err)
+	assert.NotNil(t, actual)
+}
@@ -5,64 +5,54 @@ package node
 
 import (
 	"context"
-	"log"
 
-	"google.golang.org/grpc"
 	"storj.io/storj/pkg/dht"
 	"storj.io/storj/pkg/pb"
-	"storj.io/storj/pkg/transport"
 )
 
 // Node is the storj definition for a node in the network
 type Node struct {
 	dht   dht.DHT
 	self  pb.Node
-	tc    transport.Client
-	cache *ConnectionPool
+	pool  *ConnectionPool
 }
 
 // Lookup queries nodes looking for a particular node in the network
 func (n *Node) Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error) {
-	v, err := n.cache.Get(to.GetId())
+	c, err := n.pool.Dial(ctx, &to)
 	if err != nil {
-		return nil, err
+		return nil, NodeClientErr.Wrap(err)
 	}
 
-	var conn *grpc.ClientConn
-	if c, ok := v.(*grpc.ClientConn); ok {
-		conn = c
-	} else {
-		c, err := n.tc.DialNode(ctx, &to)
-		if err != nil {
-			return nil, err
-		}
-
-		if err := n.cache.Add(to.GetId(), c); err != nil {
-			log.Printf("Error %s occurred adding %s to cache", err, to.GetId())
-		}
-		conn = c
-	}
-
-	c := pb.NewNodesClient(conn)
 	resp, err := c.Query(ctx, &pb.QueryRequest{Limit: 20, Sender: &n.self, Target: &find, Pingback: true})
 	if err != nil {
-		return nil, err
+		return nil, NodeClientErr.Wrap(err)
 	}
 
 	rt, err := n.dht.GetRoutingTable(ctx)
 	if err != nil {
-		return nil, err
+		return nil, NodeClientErr.Wrap(err)
 	}
 
 	if err := rt.ConnectionSuccess(&to); err != nil {
-		return nil, err
+		return nil, NodeClientErr.Wrap(err)
+
 	}
 
 	return resp.Response, nil
 }
 
-// Disconnect closes connections within the cache
-func (n *Node) Disconnect() error {
-	return n.cache.Disconnect()
+// Ping attempts to establish a connection with a node to verify it is alive
+func (n *Node) Ping(ctx context.Context, to pb.Node) (bool, error) {
+	_, err := n.pool.Dial(ctx, &to)
+	if err != nil {
+		return false, NodeClientErr.Wrap(err)
+	}
+
+	return true, nil
 }
 
+// Disconnect closes all connections within the pool
+func (n *Node) Disconnect() error {
+	return n.pool.DisconnectAll()
 }
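Node.Ping reports success whenever pool.Dial returns without error, so how strict the liveness check is depends on what the underlying DialNode does. A hypothetical caller-side sketch of bounding that dial with a timeout through the context passed to Ping; the pingFunc type and the fake slow ping are illustrative, not part of the commit:

package main

import (
	"context"
	"fmt"
	"time"
)

// pingFunc is a local stand-in for Client.Ping; illustrative only.
type pingFunc func(ctx context.Context, id, addr string) (bool, error)

// pingWithTimeout bounds whatever dialing Ping does by a caller-supplied deadline.
func pingWithTimeout(ping pingFunc, id, addr string, d time.Duration) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), d)
	defer cancel()
	return ping(ctx, id, addr)
}

func main() {
	// A fake Ping that respects cancellation and never answers in time,
	// standing in for a dial to an unreachable node.
	slow := func(ctx context.Context, id, addr string) (bool, error) {
		select {
		case <-ctx.Done():
			return false, ctx.Err()
		case <-time.After(10 * time.Second):
			return true, nil
		}
	}

	ok, err := pingWithTimeout(slow, "foo", "127.0.0.1:0", 100*time.Millisecond)
	fmt.Println(ok, err) // false, context deadline exceeded
}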
@@ -40,7 +40,8 @@ func TestLookup(t *testing.T) {
 
 		v.to = pb.Node{Id: NewNodeID(t), Address: &pb.NodeAddress{Address: lis.Addr().String()}}
 
-		srv, mock, err := newTestServer(ctx)
+		id := newTestIdentity(t)
+		srv, mock, err := newTestServer(ctx, &mockNodeServer{queryCalled: 0}, id)
 		assert.NoError(t, err)
 		go func() { assert.NoError(t, srv.Serve(lis)) }()
 		defer srv.Stop()
@@ -62,30 +63,63 @@ func TestLookup(t *testing.T) {
 
 		_, err = nc.Lookup(ctx, v.to, v.find)
 		assert.Equal(t, v.expectedErr, err)
-		assert.Equal(t, 1, mock.queryCalled)
+		assert.Equal(t, 1, mock.(*mockNodeServer).queryCalled)
 	}
 }
 
-func newTestServer(ctx context.Context) (*grpc.Server, *mockNodeServer, error) {
-	ca, err := provider.NewCA(ctx, 12, 4)
-	if err != nil {
-		return nil, nil, err
+func TestPing(t *testing.T) {
+	ctx := context.Background()
+	cases := []struct {
+		self        pb.Node
+		toID        string
+		toIdentity  *provider.FullIdentity
+		expectedErr error
+	}{
+		{
+			self:        pb.Node{Id: "hello", Address: &pb.NodeAddress{Address: ":7070"}},
+			toID:        "",
+			toIdentity:  newTestIdentity(t),
+			expectedErr: nil,
+		},
 	}
-	identity, err := ca.NewIdentity()
-	if err != nil {
-		return nil, nil, err
+
+	for _, v := range cases {
+		lis, err := net.Listen("tcp", "127.0.0.1:0")
+		assert.NoError(t, err)
+		// new mock DHT for node client
+		ctrl := gomock.NewController(t)
+		mdht := mock_dht.NewMockDHT(ctrl)
+		// set up a node server
+		srv := NewServer(mdht)
+
+		msrv, _, err := newTestServer(ctx, srv, v.toIdentity)
+		assert.NoError(t, err)
+		// start gRPC server
+
+		go func() { assert.NoError(t, msrv.Serve(lis)) }()
+		defer msrv.Stop()
+
+		nc, err := NewNodeClient(v.toIdentity, v.self, mdht)
+		assert.NoError(t, err)
+
+		id := ID(v.toIdentity.ID)
+		ok, err := nc.Ping(ctx, pb.Node{Id: id.String(), Address: &pb.NodeAddress{Address: lis.Addr().String()}})
+		assert.Equal(t, v.expectedErr, err)
+		assert.Equal(t, ok, true)
 	}
+}
 
+func newTestServer(ctx context.Context, ns pb.NodesServer, identity *provider.FullIdentity) (*grpc.Server, pb.NodesServer, error) {
 	identOpt, err := identity.ServerOption()
 	if err != nil {
 		return nil, nil, err
 	}
 
 	grpcServer := grpc.NewServer(identOpt)
-	mn := &mockNodeServer{queryCalled: 0}
 
-	pb.RegisterNodesServer(grpcServer, mn)
+	pb.RegisterNodesServer(grpcServer, ns)
 
-	return grpcServer, mn, nil
+	return grpcServer, ns, nil
 
 }
@@ -106,3 +140,12 @@ func NewNodeID(t *testing.T) string {
 
 	return id.String()
 }
+
+func newTestIdentity(t *testing.T) *provider.FullIdentity {
+	ca, err := provider.NewCA(ctx, 12, 4)
+	assert.NoError(t, err)
+	identity, err := ca.NewIdentity()
+	assert.NoError(t, err)
+
+	return identity
+}
@@ -301,6 +301,7 @@ func (fi *FullIdentity) ServerOption(pcvFuncs ...peertls.PeerCertVerificationFun
 // DialOption returns a grpc `DialOption` for making outgoing connections
 // to the node with this peer identity
 func (fi *FullIdentity) DialOption() (grpc.DialOption, error) {
+	// TODO(coyle): add ID
 	ch := [][]byte{fi.Leaf.Raw, fi.CA.Raw}
 	c, err := peertls.TLSCert(ch, fi.Leaf, fi.Key)
 	if err != nil {
@@ -312,6 +313,10 @@ func (fi *FullIdentity) DialOption() (grpc.DialOption, error) {
 		InsecureSkipVerify: true,
 		VerifyPeerCertificate: peertls.VerifyPeerFunc(
 			peertls.VerifyPeerCertChains,
+			func(_ [][]byte, parsedChains [][]*x509.Certificate) error {
+				return nil
+			},
+			// TODO(coyle): Check that the ID of the node we are dialing is the owner of the certificate.
 		),
 	}
@@ -29,7 +29,7 @@ func (o *Transport) DialNode(ctx context.Context, node *pb.Node) (conn *grpc.Cli
 	if node.Address == nil || node.Address.Address == "" {
 		return nil, Error.New("no address")
 	}
 
 	// TODO(coyle): pass ID
 	dialOpt, err := o.identity.DialOption()
 	if err != nil {
 		return nil, err