Disconnects nodeclient, routing table dbs when done with kademlia (#507)

* disconnect from nodeclient

* cleanup connections in tests

* kademlia disconnects from nodeclient

* updating disconnect method for mocks

* creates separate disconnect and removeAll methods for tests

* adds init to connection pool

* fix folder cleanup and disconnect

* creates and cleans up test db files and disconnects kad

* removes db/.keep

* includes disconnect within cleanup methods

* creates public init method on connection pool to handle mutex copy issues

* remove all after disconnect

* pair creation and destruction

* checks disconnect error

* remove ctx

* fixes mock kad
This commit is contained in:
Jennifer Li Johnson 2018-10-26 10:07:02 -04:00 committed by Dennis Coyle
parent df1f7a6214
commit 8d779d3d3e
10 changed files with 133 additions and 65 deletions

3
.gitignore vendored
View File

@ -7,8 +7,6 @@
*.so *.so
*.dylib *.dylib
*.db *.db
pkg/kademlia/db/*.db
db
# Test binary, build with `go test -c` # Test binary, build with `go test -c`
*.test *.test
@ -35,7 +33,6 @@ protos/google/*
# Test redis log and snapshot files # Test redis log and snapshot files
*test_redis-server.log *test_redis-server.log
*dump.rdb *dump.rdb
*test_bolt.db
*.coverprofile *.coverprofile
*.log *.log

View File

View File

@ -6,13 +6,11 @@ package mock_dht
import ( import (
context "context" context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
reflect "reflect"
dht "storj.io/storj/pkg/dht" dht "storj.io/storj/pkg/dht"
"storj.io/storj/pkg/pb" pb "storj.io/storj/pkg/pb"
time "time"
) )
// MockDHT is a mock of DHT interface // MockDHT is a mock of DHT interface

View File

@ -6,6 +6,8 @@ package kademlia
import ( import (
"context" "context"
"storj.io/storj/pkg/utils"
"github.com/zeebo/errs" "github.com/zeebo/errs"
monkit "gopkg.in/spacemonkeygo/monkit.v2" monkit "gopkg.in/spacemonkeygo/monkit.v2"
@ -76,7 +78,7 @@ func (c Config) Run(ctx context.Context, server *provider.Provider) (
if err != nil { if err != nil {
return err return err
} }
defer func() { _ = kad.Disconnect() }() defer func() { err = utils.CombineErrors(err, kad.Disconnect()) }()
mn := node.NewServer(kad) mn := node.NewServer(kad)
pb.RegisterNodesServer(server.GRPC(), mn) pb.RegisterNodesServer(server.GRPC(), mn)

View File

@ -111,8 +111,8 @@ func NewKademlia(id dht.NodeID, bootstrapNodes []pb.Node, address string, identi
// Disconnect safely closes connections to the Kademlia network // Disconnect safely closes connections to the Kademlia network
func (k *Kademlia) Disconnect() error { func (k *Kademlia) Disconnect() error {
return utils.CombineErrors( return utils.CombineErrors(
k.nodeClient.Disconnect(),
k.routingTable.Close(), k.routingTable.Close(),
// TODO: close connections
) )
} }

View File

@ -6,8 +6,11 @@ package kademlia
import ( import (
"context" "context"
"io/ioutil"
"net" "net"
"os" "os"
"path/filepath"
"strconv"
"testing" "testing"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -38,12 +41,13 @@ func newTestIdentity() (*provider.FullIdentity, error) {
} }
func TestNewKademlia(t *testing.T) { func TestNewKademlia(t *testing.T) {
rootdir, cleanup := mktempdir(t, "kademlia")
defer cleanup()
cases := []struct { cases := []struct {
id dht.NodeID id dht.NodeID
bn []pb.Node bn []pb.Node
addr string addr string
expectedErr error expectedErr error
setup func() error
}{ }{
{ {
id: func() *node.ID { id: func() *node.ID {
@ -52,9 +56,8 @@ func TestNewKademlia(t *testing.T) {
n := node.ID(id.ID) n := node.ID(id.ID)
return &n return &n
}(), }(),
bn: []pb.Node{pb.Node{Id: "foo"}}, bn: []pb.Node{pb.Node{Id: "foo"}},
addr: "127.0.0.1:8080", addr: "127.0.0.1:8080",
setup: func() error { return nil },
}, },
{ {
id: func() *node.ID { id: func() *node.ID {
@ -63,25 +66,30 @@ func TestNewKademlia(t *testing.T) {
n := node.ID(id.ID) n := node.ID(id.ID)
return &n return &n
}(), }(),
bn: []pb.Node{pb.Node{Id: "foo"}}, bn: []pb.Node{pb.Node{Id: "foo"}},
addr: "127.0.0.1:8080", addr: "127.0.0.1:8080",
setup: func() error { return os.RemoveAll("db") },
}, },
} }
for _, v := range cases { for i, v := range cases {
assert.NoError(t, v.setup()) dir := filepath.Join(rootdir, strconv.Itoa(i))
kc := kadconfig() kc := kadconfig()
ca, err := provider.NewCA(context.Background(), 12, 4) ca, err := provider.NewCA(context.Background(), 12, 4)
assert.NoError(t, err) assert.NoError(t, err)
identity, err := ca.NewIdentity() identity, err := ca.NewIdentity()
assert.NoError(t, err) assert.NoError(t, err)
actual, err := NewKademlia(v.id, v.bn, v.addr, identity, "db", kc)
kad, err := NewKademlia(v.id, v.bn, v.addr, identity, dir, kc)
assert.NoError(t, err)
assert.Equal(t, v.expectedErr, err) assert.Equal(t, v.expectedErr, err)
assert.Equal(t, actual.bootstrapNodes, v.bn) assert.Equal(t, kad.bootstrapNodes, v.bn)
assert.NotNil(t, actual.nodeClient) assert.NotNil(t, kad.nodeClient)
assert.NotNil(t, actual.routingTable) assert.NotNil(t, kad.routingTable)
assert.NoError(t, kad.Disconnect())
} }
} }
func TestLookup(t *testing.T) { func TestLookup(t *testing.T) {
@ -92,9 +100,11 @@ func TestLookup(t *testing.T) {
kc := kadconfig() kc := kadconfig()
srv, mns := newTestServer([]*pb.Node{&pb.Node{Id: "foo"}}) srv, mns := newTestServer([]*pb.Node{&pb.Node{Id: "foo"}})
go func() { _ = srv.Serve(lis) }() go func() { assert.NoError(t, srv.Serve(lis)) }()
defer srv.Stop() defer srv.Stop()
dir, cleanup := mktempdir(t, "kademlia")
defer cleanup()
k := func() *Kademlia { k := func() *Kademlia {
// make new identity // make new identity
fid, err := newTestIdentity() fid, err := newTestIdentity()
@ -108,41 +118,39 @@ func TestLookup(t *testing.T) {
assert.NotEqual(t, id, id2) assert.NotEqual(t, id, id2)
kid := dht.NodeID(fid.ID) kid := dht.NodeID(fid.ID)
k, err := NewKademlia(kid, []pb.Node{pb.Node{Id: id2.String(), Address: &pb.NodeAddress{Address: lis.Addr().String()}}}, lis.Addr().String(), fid, "db", kc) k, err := NewKademlia(kid, []pb.Node{pb.Node{Id: id2.String(), Address: &pb.NodeAddress{Address: lis.Addr().String()}}}, lis.Addr().String(), fid, dir, kc)
assert.NoError(t, err) assert.NoError(t, err)
return k return k
}() }()
defer func() {
assert.NoError(t, k.Disconnect())
}()
cases := []struct { cases := []struct {
k *Kademlia
target dht.NodeID target dht.NodeID
opts lookupOpts opts lookupOpts
expected *pb.Node expected *pb.Node
expectedErr error expectedErr error
}{ }{
{ {target: func() *node.ID {
k: k, fid, err := newTestIdentity()
target: func() *node.ID { id := dht.NodeID(fid.ID)
fid, err := newTestIdentity() nid := node.ID(fid.ID)
id := dht.NodeID(fid.ID) assert.NoError(t, err)
nid := node.ID(fid.ID) mns.returnValue = []*pb.Node{&pb.Node{Id: id.String(), Address: &pb.NodeAddress{Address: addr}}}
assert.NoError(t, err) return &nid
mns.returnValue = []*pb.Node{&pb.Node{Id: id.String(), Address: &pb.NodeAddress{Address: addr}}} }(),
return &nid
}(),
opts: lookupOpts{amount: 5}, opts: lookupOpts{amount: 5},
expected: &pb.Node{}, expected: &pb.Node{},
expectedErr: nil, expectedErr: nil,
}, },
{ {target: func() *node.ID {
k: k, id, err := newTestIdentity()
target: func() *node.ID { assert.NoError(t, err)
id, err := newTestIdentity() n := node.ID(id.ID)
assert.NoError(t, err) return &n
n := node.ID(id.ID) }(),
return &n
}(),
opts: lookupOpts{amount: 5}, opts: lookupOpts{amount: 5},
expected: nil, expected: nil,
expectedErr: nil, expectedErr: nil,
@ -150,23 +158,25 @@ func TestLookup(t *testing.T) {
} }
for _, v := range cases { for _, v := range cases {
err := v.k.lookup(context.Background(), v.target, v.opts) err := k.lookup(context.Background(), v.target, v.opts)
assert.Equal(t, v.expectedErr, err) assert.Equal(t, v.expectedErr, err)
} }
} }
func TestBootstrap(t *testing.T) { func TestBootstrap(t *testing.T) {
bn, s := testNode(t, []pb.Node{}) bn, s, clean := testNode(t, []pb.Node{})
defer clean()
defer s.Stop() defer s.Stop()
n1, s1 := testNode(t, []pb.Node{*bn.routingTable.self}) n1, s1, clean1 := testNode(t, []pb.Node{*bn.routingTable.self})
defer clean1()
defer s1.Stop() defer s1.Stop()
err := n1.Bootstrap(context.Background()) err := n1.Bootstrap(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
n2, s2 := testNode(t, []pb.Node{*bn.routingTable.self}) n2, s2, clean2 := testNode(t, []pb.Node{*bn.routingTable.self})
defer clean2()
defer s2.Stop() defer s2.Stop()
err = n2.Bootstrap(context.Background()) err = n2.Bootstrap(context.Background())
@ -175,10 +185,9 @@ func TestBootstrap(t *testing.T) {
nodeIDs, err := n2.routingTable.nodeBucketDB.List(nil, 0) nodeIDs, err := n2.routingTable.nodeBucketDB.List(nil, 0)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, nodeIDs, 3) assert.Len(t, nodeIDs, 3)
} }
func testNode(t *testing.T, bn []pb.Node) (*Kademlia, *grpc.Server) { func testNode(t *testing.T, bn []pb.Node) (*Kademlia, *grpc.Server, func()) {
// new address // new address
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err) assert.NoError(t, err)
@ -189,7 +198,9 @@ func testNode(t *testing.T, bn []pb.Node) (*Kademlia, *grpc.Server) {
id := dht.NodeID(fid.ID) id := dht.NodeID(fid.ID)
assert.NoError(t, err) assert.NoError(t, err)
// new kademlia // new kademlia
k, err := NewKademlia(id, bn, lis.Addr().String(), fid, "db", kc) dir, cleanup := mktempdir(t, "kademlia")
k, err := NewKademlia(id, bn, lis.Addr().String(), fid, dir, kc)
assert.NoError(t, err) assert.NoError(t, err)
s := node.NewServer(k) s := node.NewServer(k)
// new ident opts // new ident opts
@ -199,9 +210,12 @@ func testNode(t *testing.T, bn []pb.Node) (*Kademlia, *grpc.Server) {
grpcServer := grpc.NewServer(identOpt) grpcServer := grpc.NewServer(identOpt)
pb.RegisterNodesServer(grpcServer, s) pb.RegisterNodesServer(grpcServer, s)
go func() { _ = grpcServer.Serve(lis) }() go func() { assert.NoError(t, grpcServer.Serve(lis)) }()
return k, grpcServer return k, grpcServer, func() {
defer cleanup()
assert.NoError(t, k.Disconnect())
}
} }
@ -212,7 +226,7 @@ func TestGetNodes(t *testing.T) {
kc := kadconfig() kc := kadconfig()
srv, _ := newTestServer([]*pb.Node{&pb.Node{Id: "foo"}}) srv, _ := newTestServer([]*pb.Node{&pb.Node{Id: "foo"}})
go func() { _ = srv.Serve(lis) }() go func() { assert.NoError(t, srv.Serve(lis)) }()
defer srv.Stop() defer srv.Stop()
// make new identity // make new identity
@ -227,9 +241,15 @@ func TestGetNodes(t *testing.T) {
id2 := node.ID(fid2.ID) id2 := node.ID(fid2.ID)
assert.NotEqual(t, id, id2) assert.NotEqual(t, id, id2)
kid := dht.NodeID(fid.ID) kid := dht.NodeID(fid.ID)
k, err := NewKademlia(kid, []pb.Node{pb.Node{Id: id2.String(), Address: &pb.NodeAddress{Address: lis.Addr().String()}}}, lis.Addr().String(), fid, "db", kc)
dir, cleanup := mktempdir(t, "kademlia")
defer cleanup()
k, err := NewKademlia(kid, []pb.Node{pb.Node{Id: id2.String(), Address: &pb.NodeAddress{Address: lis.Addr().String()}}}, lis.Addr().String(), fid, dir, kc)
assert.NoError(t, err) assert.NoError(t, err)
defer func() {
assert.NoError(t, k.Disconnect())
}()
// add nodes // add nodes
ids := []string{"AAAAA", "BBBBB", "CCCCC", "DDDDD"} ids := []string{"AAAAA", "BBBBB", "CCCCC", "DDDDD"}
bw := []int64{1, 2, 3, 4} bw := []int64{1, 2, 3, 4}
@ -301,7 +321,6 @@ func TestGetNodes(t *testing.T) {
} }
}) })
} }
} }
func TestMeetsRestrictions(t *testing.T) { func TestMeetsRestrictions(t *testing.T) {
@ -397,3 +416,12 @@ func TestMeetsRestrictions(t *testing.T) {
}) })
} }
} }
func mktempdir(t *testing.T, dir string) (string, func()) {
rootdir, err := ioutil.TempDir("", dir)
assert.NoError(t, err)
cleanup := func() {
assert.NoError(t, os.RemoveAll(rootdir))
}
return rootdir, cleanup
}

View File

@ -19,15 +19,18 @@ var NodeClientErr = errs.Class("node client error")
// NewNodeClient instantiates a node client // NewNodeClient instantiates a node client
func NewNodeClient(identity *provider.FullIdentity, self pb.Node, dht dht.DHT) (Client, error) { func NewNodeClient(identity *provider.FullIdentity, self pb.Node, dht dht.DHT) (Client, error) {
client := transport.NewClient(identity) client := transport.NewClient(identity)
return &Node{ node := &Node{
dht: dht, dht: dht,
self: self, self: self,
tc: client, tc: client,
cache: NewConnectionPool(), cache: NewConnectionPool(),
}, nil }
node.cache.Init()
return node, nil
} }
// Client is the Node client communication interface // Client is the Node client communication interface
type Client interface { type Client interface {
Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error) Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error)
Disconnect() error
} }

View File

@ -5,9 +5,16 @@ package node
import ( import (
"sync" "sync"
"github.com/zeebo/errs"
"storj.io/storj/pkg/utils"
) )
// ConnectionPool is the in memory implementation of a connection Pool // Error defines a connection pool error
var Error = errs.Class("connection pool error")
// ConnectionPool is the in memory pool of node connections
type ConnectionPool struct { type ConnectionPool struct {
mu sync.RWMutex mu sync.RWMutex
cache map[string]interface{} cache map[string]interface{}
@ -15,10 +22,7 @@ type ConnectionPool struct {
// NewConnectionPool initializes a new in memory pool // NewConnectionPool initializes a new in memory pool
func NewConnectionPool() *ConnectionPool { func NewConnectionPool() *ConnectionPool {
return &ConnectionPool{ return &ConnectionPool{}
cache: make(map[string]interface{}),
mu: sync.RWMutex{},
}
} }
// Add takes a node ID as the key and a node client as the value to store // Add takes a node ID as the key and a node client as the value to store
@ -44,3 +48,33 @@ func (pool *ConnectionPool) Remove(key string) error {
pool.cache[key] = nil pool.cache[key] = nil
return nil return nil
} }
// Disconnect closes the connection to the node and removes it from the connection pool
func (pool *ConnectionPool) Disconnect() error {
var err error
var errs []error
for k, v := range pool.cache {
conn, ok := v.(interface{ Close() error })
if !ok {
err = Error.New("connection pool value not a grpc client connection")
errs = append(errs, err)
continue
}
err = conn.Close()
if err != nil {
errs = append(errs, Error.Wrap(err))
continue
}
err = pool.Remove(k)
if err != nil {
errs = append(errs, Error.Wrap(err))
continue
}
}
return utils.CombineErrors(errs...)
}
// Init initializes the cache
func (pool *ConnectionPool) Init() {
pool.cache = make(map[string]interface{})
}

View File

@ -24,6 +24,7 @@ func TestGet(t *testing.T) {
{ {
pool: func() *ConnectionPool { pool: func() *ConnectionPool {
p := NewConnectionPool() p := NewConnectionPool()
p.Init()
assert.NoError(t, p.Add("foo", TestFoo{called: "hoot"})) assert.NoError(t, p.Add("foo", TestFoo{called: "hoot"}))
return p return p
}(), }(),

View File

@ -61,3 +61,8 @@ func (n *Node) Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node
return resp.Response, nil return resp.Response, nil
} }
// Disconnect closes connections within the cache
func (n *Node) Disconnect() error {
return n.cache.Disconnect()
}