parent
b3c1ea1852
commit
cb454638d9
@ -7,10 +7,8 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
"storj.io/storj/pkg/dht"
|
||||
"storj.io/storj/pkg/pb"
|
||||
"storj.io/storj/pkg/pool"
|
||||
"storj.io/storj/pkg/provider"
|
||||
"storj.io/storj/pkg/transport"
|
||||
)
|
||||
@ -25,7 +23,7 @@ func NewNodeClient(identity *provider.FullIdentity, self pb.Node, dht dht.DHT) (
|
||||
dht: dht,
|
||||
self: self,
|
||||
tc: client,
|
||||
cache: pool.NewConnectionPool(),
|
||||
cache: NewConnectionPool(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -1,10 +1,9 @@
|
||||
// Copyright (C) 2018 Storj Labs, Inc.
|
||||
// See LICENSE for copying information
|
||||
|
||||
package pool
|
||||
package node
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -23,28 +22,25 @@ func NewConnectionPool() *ConnectionPool {
|
||||
}
|
||||
|
||||
// Add takes a node ID as the key and a node client as the value to store
|
||||
func (mp *ConnectionPool) Add(ctx context.Context, key string, value interface{}) error {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
mp.cache[key] = value
|
||||
|
||||
func (pool *ConnectionPool) Add(key string, value interface{}) error {
|
||||
pool.mu.Lock()
|
||||
defer pool.mu.Unlock()
|
||||
pool.cache[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a node connection with the provided nodeID
|
||||
// nil is returned if the NodeID is not in the connection pool
|
||||
func (mp *ConnectionPool) Get(ctx context.Context, key string) (interface{}, error) {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
return mp.cache[key], nil
|
||||
func (pool *ConnectionPool) Get(key string) (interface{}, error) {
|
||||
pool.mu.Lock()
|
||||
defer pool.mu.Unlock()
|
||||
return pool.cache[key], nil
|
||||
}
|
||||
|
||||
// Remove deletes a connection associated with the provided NodeID
|
||||
func (mp *ConnectionPool) Remove(ctx context.Context, key string) error {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
mp.cache[key] = nil
|
||||
|
||||
func (pool *ConnectionPool) Remove(key string) error {
|
||||
pool.mu.Lock()
|
||||
defer pool.mu.Unlock()
|
||||
pool.cache[key] = nil
|
||||
return nil
|
||||
}
|
@ -1,10 +1,9 @@
|
||||
// Copyright (C) 2018 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package pool
|
||||
package node
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@ -16,7 +15,6 @@ type TestFoo struct {
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cases := []struct {
|
||||
pool *ConnectionPool
|
||||
key string
|
||||
@ -26,7 +24,7 @@ func TestGet(t *testing.T) {
|
||||
{
|
||||
pool: func() *ConnectionPool {
|
||||
p := NewConnectionPool()
|
||||
assert.NoError(t, p.Add(ctx, "foo", TestFoo{called: "hoot"}))
|
||||
assert.NoError(t, p.Add("foo", TestFoo{called: "hoot"}))
|
||||
return p
|
||||
}(),
|
||||
key: "foo",
|
||||
@ -37,7 +35,7 @@ func TestGet(t *testing.T) {
|
||||
|
||||
for i := range cases {
|
||||
v := &cases[i]
|
||||
test, err := v.pool.Get(ctx, v.key)
|
||||
test, err := v.pool.Get(v.key)
|
||||
assert.Equal(t, v.expectedError, err)
|
||||
assert.Equal(t, v.expected, test)
|
||||
}
|
||||
@ -64,10 +62,10 @@ func TestAdd(t *testing.T) {
|
||||
|
||||
for i := range cases {
|
||||
v := &cases[i]
|
||||
err := v.pool.Add(context.Background(), v.key, v.value)
|
||||
err := v.pool.Add(v.key, v.value)
|
||||
assert.Equal(t, v.expectedError, err)
|
||||
|
||||
test, err := v.pool.Get(context.Background(), v.key)
|
||||
test, err := v.pool.Get(v.key)
|
||||
assert.Equal(t, v.expectedError, err)
|
||||
|
||||
assert.Equal(t, v.expected, test)
|
||||
@ -93,10 +91,10 @@ func TestRemove(t *testing.T) {
|
||||
|
||||
for i := range cases {
|
||||
v := &cases[i]
|
||||
err := v.pool.Remove(context.Background(), v.key)
|
||||
err := v.pool.Remove(v.key)
|
||||
assert.Equal(t, v.expectedError, err)
|
||||
|
||||
test, err := v.pool.Get(context.Background(), v.key)
|
||||
test, err := v.pool.Get(v.key)
|
||||
assert.Equal(t, v.expectedError, err)
|
||||
|
||||
assert.Equal(t, v.expected, test)
|
@ -8,10 +8,8 @@ import (
|
||||
"log"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"storj.io/storj/pkg/dht"
|
||||
"storj.io/storj/pkg/pb"
|
||||
"storj.io/storj/pkg/pool"
|
||||
"storj.io/storj/pkg/transport"
|
||||
)
|
||||
|
||||
@ -20,12 +18,12 @@ type Node struct {
|
||||
dht dht.DHT
|
||||
self pb.Node
|
||||
tc transport.Client
|
||||
cache pool.Pool
|
||||
cache *ConnectionPool
|
||||
}
|
||||
|
||||
// Lookup queries nodes looking for a particular node in the network
|
||||
func (n *Node) Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error) {
|
||||
v, err := n.cache.Get(ctx, to.GetId())
|
||||
v, err := n.cache.Get(to.GetId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -39,7 +37,7 @@ func (n *Node) Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.cache.Add(ctx, to.GetId(), c); err != nil {
|
||||
if err := n.cache.Add(to.GetId(), c); err != nil {
|
||||
log.Printf("Error %s occurred adding %s to cache", err, to.GetId())
|
||||
}
|
||||
conn = c
|
||||
|
@ -1,15 +0,0 @@
|
||||
// Copyright (C) 2018 Storj Labs, Inc.
|
||||
// See LICENSE for copying information
|
||||
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Pool is a set of actions for maintaining a node connection pool
|
||||
type Pool interface {
|
||||
Add(ctx context.Context, key string, value interface{}) error
|
||||
Get(ctx context.Context, key string) (interface{}, error)
|
||||
Remove(ctx context.Context, key string) error
|
||||
}
|
Loading…
Reference in New Issue
Block a user