parent
b3c1ea1852
commit
cb454638d9
@ -7,10 +7,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/zeebo/errs"
|
"github.com/zeebo/errs"
|
||||||
|
|
||||||
"storj.io/storj/pkg/dht"
|
"storj.io/storj/pkg/dht"
|
||||||
"storj.io/storj/pkg/pb"
|
"storj.io/storj/pkg/pb"
|
||||||
"storj.io/storj/pkg/pool"
|
|
||||||
"storj.io/storj/pkg/provider"
|
"storj.io/storj/pkg/provider"
|
||||||
"storj.io/storj/pkg/transport"
|
"storj.io/storj/pkg/transport"
|
||||||
)
|
)
|
||||||
@ -25,7 +23,7 @@ func NewNodeClient(identity *provider.FullIdentity, self pb.Node, dht dht.DHT) (
|
|||||||
dht: dht,
|
dht: dht,
|
||||||
self: self,
|
self: self,
|
||||||
tc: client,
|
tc: client,
|
||||||
cache: pool.NewConnectionPool(),
|
cache: NewConnectionPool(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
// Copyright (C) 2018 Storj Labs, Inc.
|
// Copyright (C) 2018 Storj Labs, Inc.
|
||||||
// See LICENSE for copying information
|
// See LICENSE for copying information
|
||||||
|
|
||||||
package pool
|
package node
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,28 +22,25 @@ func NewConnectionPool() *ConnectionPool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add takes a node ID as the key and a node client as the value to store
|
// Add takes a node ID as the key and a node client as the value to store
|
||||||
func (mp *ConnectionPool) Add(ctx context.Context, key string, value interface{}) error {
|
func (pool *ConnectionPool) Add(key string, value interface{}) error {
|
||||||
mp.mu.Lock()
|
pool.mu.Lock()
|
||||||
defer mp.mu.Unlock()
|
defer pool.mu.Unlock()
|
||||||
mp.cache[key] = value
|
pool.cache[key] = value
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a node connection with the provided nodeID
|
// Get retrieves a node connection with the provided nodeID
|
||||||
// nil is returned if the NodeID is not in the connection pool
|
// nil is returned if the NodeID is not in the connection pool
|
||||||
func (mp *ConnectionPool) Get(ctx context.Context, key string) (interface{}, error) {
|
func (pool *ConnectionPool) Get(key string) (interface{}, error) {
|
||||||
mp.mu.Lock()
|
pool.mu.Lock()
|
||||||
defer mp.mu.Unlock()
|
defer pool.mu.Unlock()
|
||||||
|
return pool.cache[key], nil
|
||||||
return mp.cache[key], nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove deletes a connection associated with the provided NodeID
|
// Remove deletes a connection associated with the provided NodeID
|
||||||
func (mp *ConnectionPool) Remove(ctx context.Context, key string) error {
|
func (pool *ConnectionPool) Remove(key string) error {
|
||||||
mp.mu.Lock()
|
pool.mu.Lock()
|
||||||
defer mp.mu.Unlock()
|
defer pool.mu.Unlock()
|
||||||
mp.cache[key] = nil
|
pool.cache[key] = nil
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
@ -1,10 +1,9 @@
|
|||||||
// Copyright (C) 2018 Storj Labs, Inc.
|
// Copyright (C) 2018 Storj Labs, Inc.
|
||||||
// See LICENSE for copying information.
|
// See LICENSE for copying information.
|
||||||
|
|
||||||
package pool
|
package node
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -16,7 +15,6 @@ type TestFoo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGet(t *testing.T) {
|
func TestGet(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
pool *ConnectionPool
|
pool *ConnectionPool
|
||||||
key string
|
key string
|
||||||
@ -26,7 +24,7 @@ func TestGet(t *testing.T) {
|
|||||||
{
|
{
|
||||||
pool: func() *ConnectionPool {
|
pool: func() *ConnectionPool {
|
||||||
p := NewConnectionPool()
|
p := NewConnectionPool()
|
||||||
assert.NoError(t, p.Add(ctx, "foo", TestFoo{called: "hoot"}))
|
assert.NoError(t, p.Add("foo", TestFoo{called: "hoot"}))
|
||||||
return p
|
return p
|
||||||
}(),
|
}(),
|
||||||
key: "foo",
|
key: "foo",
|
||||||
@ -37,7 +35,7 @@ func TestGet(t *testing.T) {
|
|||||||
|
|
||||||
for i := range cases {
|
for i := range cases {
|
||||||
v := &cases[i]
|
v := &cases[i]
|
||||||
test, err := v.pool.Get(ctx, v.key)
|
test, err := v.pool.Get(v.key)
|
||||||
assert.Equal(t, v.expectedError, err)
|
assert.Equal(t, v.expectedError, err)
|
||||||
assert.Equal(t, v.expected, test)
|
assert.Equal(t, v.expected, test)
|
||||||
}
|
}
|
||||||
@ -64,10 +62,10 @@ func TestAdd(t *testing.T) {
|
|||||||
|
|
||||||
for i := range cases {
|
for i := range cases {
|
||||||
v := &cases[i]
|
v := &cases[i]
|
||||||
err := v.pool.Add(context.Background(), v.key, v.value)
|
err := v.pool.Add(v.key, v.value)
|
||||||
assert.Equal(t, v.expectedError, err)
|
assert.Equal(t, v.expectedError, err)
|
||||||
|
|
||||||
test, err := v.pool.Get(context.Background(), v.key)
|
test, err := v.pool.Get(v.key)
|
||||||
assert.Equal(t, v.expectedError, err)
|
assert.Equal(t, v.expectedError, err)
|
||||||
|
|
||||||
assert.Equal(t, v.expected, test)
|
assert.Equal(t, v.expected, test)
|
||||||
@ -93,10 +91,10 @@ func TestRemove(t *testing.T) {
|
|||||||
|
|
||||||
for i := range cases {
|
for i := range cases {
|
||||||
v := &cases[i]
|
v := &cases[i]
|
||||||
err := v.pool.Remove(context.Background(), v.key)
|
err := v.pool.Remove(v.key)
|
||||||
assert.Equal(t, v.expectedError, err)
|
assert.Equal(t, v.expectedError, err)
|
||||||
|
|
||||||
test, err := v.pool.Get(context.Background(), v.key)
|
test, err := v.pool.Get(v.key)
|
||||||
assert.Equal(t, v.expectedError, err)
|
assert.Equal(t, v.expectedError, err)
|
||||||
|
|
||||||
assert.Equal(t, v.expected, test)
|
assert.Equal(t, v.expected, test)
|
@ -8,10 +8,8 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"storj.io/storj/pkg/dht"
|
"storj.io/storj/pkg/dht"
|
||||||
"storj.io/storj/pkg/pb"
|
"storj.io/storj/pkg/pb"
|
||||||
"storj.io/storj/pkg/pool"
|
|
||||||
"storj.io/storj/pkg/transport"
|
"storj.io/storj/pkg/transport"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,12 +18,12 @@ type Node struct {
|
|||||||
dht dht.DHT
|
dht dht.DHT
|
||||||
self pb.Node
|
self pb.Node
|
||||||
tc transport.Client
|
tc transport.Client
|
||||||
cache pool.Pool
|
cache *ConnectionPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lookup queries nodes looking for a particular node in the network
|
// Lookup queries nodes looking for a particular node in the network
|
||||||
func (n *Node) Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error) {
|
func (n *Node) Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node, error) {
|
||||||
v, err := n.cache.Get(ctx, to.GetId())
|
v, err := n.cache.Get(to.GetId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -39,7 +37,7 @@ func (n *Node) Lookup(ctx context.Context, to pb.Node, find pb.Node) ([]*pb.Node
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := n.cache.Add(ctx, to.GetId(), c); err != nil {
|
if err := n.cache.Add(to.GetId(), c); err != nil {
|
||||||
log.Printf("Error %s occurred adding %s to cache", err, to.GetId())
|
log.Printf("Error %s occurred adding %s to cache", err, to.GetId())
|
||||||
}
|
}
|
||||||
conn = c
|
conn = c
|
||||||
|
@ -1,15 +0,0 @@
|
|||||||
// Copyright (C) 2018 Storj Labs, Inc.
|
|
||||||
// See LICENSE for copying information
|
|
||||||
|
|
||||||
package pool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Pool is a set of actions for maintaining a node connection pool
|
|
||||||
type Pool interface {
|
|
||||||
Add(ctx context.Context, key string, value interface{}) error
|
|
||||||
Get(ctx context.Context, key string) (interface{}, error)
|
|
||||||
Remove(ctx context.Context, key string) error
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user