storj/pkg/node/connection_pool_test.go
Bryan White 2a0c4e60d2
preparing for use of customtype gogo extension with NodeID type (#693)
* preparing for use of `customtype` gogo extension with `NodeID` type

* review changes

* preparing for use of `customtype` gogo extension with `NodeID` type

* review changes

* wip

* tests passing

* wip fixing tests

* more wip test fixing

* remove NodeIDList from proto files

* linter fixes

* linter fixes

* linter/review fixes

* more freaking linter fixes

* omg just kill me - linterrrrrrrr

* travis linter, i will murder you and your family in your sleep

* goimports everything - burn in hell travis

* goimports update

* go mod tidy
2018-11-29 19:39:27 +01:00

119 lines
2.5 KiB
Go

// Copyright (C) 2018 Storj Labs, Inc.
// See LICENSE for copying information.
package node
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"storj.io/storj/internal/teststorj"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
)
// fooID is the node ID built from the string "foo". The tests below share it
// as the key for pool entries and as the ID of the node being dialed.
var fooID = teststorj.NodeIDFromString("foo")
// TestGet checks that ConnectionPool.Get hands back the connection that was
// stored in the pool under the requested node ID, together with the expected
// error value.
func TestGet(t *testing.T) {
	cases := []struct {
		pool          *ConnectionPool
		nodeID        storj.NodeID
		expected      Conn
		expectedError error
	}{
		{
			// A freshly initialised pool seeded with a single entry for fooID.
			pool: func() *ConnectionPool {
				p := NewConnectionPool(newTestIdentity(t))
				p.Init()
				p.items[fooID] = &Conn{addr: "foo"}
				return p
			}(),
			nodeID:        fooID,
			expected:      Conn{addr: "foo"},
			expectedError: nil,
		},
	}

	for i := range cases {
		// Work through a pointer to the slice element so the case (which
		// holds a Conn by value) is never copied.
		tc := &cases[i]

		got, err := tc.pool.Get(tc.nodeID)
		assert.Equal(t, tc.expectedError, err)

		// Use the checked form of the type assertion: the unchecked form
		// panics when Get returns nil or some other type, which would abort
		// the whole test binary instead of reporting a plain failure.
		conn, ok := got.(*Conn)
		if !assert.True(t, ok, "Get returned %T, want *Conn", got) {
			continue
		}
		// A typed nil *Conn passes the assertion above but cannot be
		// dereferenced, so guard against it as well.
		if !assert.NotNil(t, conn) {
			continue
		}
		assert.Equal(t, tc.expected.addr, conn.addr)
	}
}
// TestDisconnect checks that Disconnect returns the expected error for a node
// ID and that a subsequent Get for the same ID yields the expected value
// (nil for the case listed here).
func TestDisconnect(t *testing.T) {
	conn, err := grpc.Dial("127.0.0.1:0", grpc.WithInsecure())
	// The cases below are built around conn; if dialing failed there is no
	// client connection to put in the pool, so stop instead of running the
	// cases against a nil connection.
	if !assert.NoError(t, err) {
		return
	}

	cases := []struct {
		pool          ConnectionPool
		nodeID        storj.NodeID
		expected      interface{}
		expectedError error
	}{
		{
			// The pool's mutex is left at its zero value, which is an
			// unlocked mutex ready for use.
			pool: ConnectionPool{
				items: map[storj.NodeID]*Conn{fooID: {grpc: conn}},
			},
			nodeID:        fooID,
			expected:      nil,
			expectedError: nil,
		},
	}

	for i := range cases {
		// The pool is stored by value in the case struct, so operate on the
		// slice element through a pointer rather than on a copy of it.
		tc := &cases[i]

		err := tc.pool.Disconnect(tc.nodeID)
		assert.Equal(t, tc.expectedError, err)

		// After disconnecting, look the same ID up again and compare both
		// results with the expectations of the case.
		got, err := tc.pool.Get(tc.nodeID)
		assert.Equal(t, tc.expectedError, err)
		assert.Equal(t, tc.expected, got)
	}
}
// TestDial starts several goroutines that dial the same node through the same
// pool at once and waits for all of them to finish. The test is currently
// skipped, so nothing after the Skip call runs.
func TestDial(t *testing.T) {
	t.Skip()

	// dialers is how many concurrent testDial calls are started per case.
	const dialers = 4

	cases := []struct {
		pool          *ConnectionPool
		node          *pb.Node
		expectedError error
		expected      *Conn
	}{
		{
			pool:          NewConnectionPool(newTestIdentity(t)),
			node:          &pb.Node{Id: fooID, Address: &pb.NodeAddress{Address: "127.0.0.1:0"}},
			expected:      nil,
			expectedError: nil,
		},
	}

	for _, tc := range cases {
		// One WaitGroup per case; every testDial call marks itself done.
		var wg sync.WaitGroup
		wg.Add(dialers)
		for i := 0; i < dialers; i++ {
			go testDial(t, &wg, tc.pool, tc.node, tc.expectedError)
		}
		wg.Wait()
	}
}
// testDial dials node n through pool p once using a background context. It
// asserts that the returned error equals eerr and that the returned value is
// not nil, and marks wg as done when it returns.
func testDial(t *testing.T, wg *sync.WaitGroup, p *ConnectionPool, n *pb.Node, eerr error) {
	defer wg.Done()

	conn, err := p.Dial(context.Background(), n)
	assert.Equal(t, eerr, err)
	assert.NotNil(t, conn)
}