adds slow kad dialer tests, adds timeout interceptor to transport (#1545)

This commit is contained in:
Natalie Villasana 2019-03-22 13:09:37 -04:00 committed by GitHub
parent 9236ac4bdf
commit ea4a61f0c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 231 additions and 11 deletions

View File

@ -4,18 +4,24 @@
package kademlia_test package kademlia_test
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/stretchr/testify/require"
"github.com/zeebo/errs" "github.com/zeebo/errs"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"storj.io/storj/pkg/peertls/tlsopts"
"storj.io/storj/internal/memory"
"storj.io/storj/internal/testcontext" "storj.io/storj/internal/testcontext"
"storj.io/storj/internal/testplanet" "storj.io/storj/internal/testplanet"
"storj.io/storj/pkg/kademlia" "storj.io/storj/pkg/kademlia"
"storj.io/storj/pkg/pb" "storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj" "storj.io/storj/pkg/storj"
"storj.io/storj/pkg/transport"
) )
func TestDialer(t *testing.T) { func TestDialer(t *testing.T) {
@ -152,6 +158,127 @@ func TestDialer(t *testing.T) {
}) })
} }
func TestSlowDialerHasTimeout(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 4, UplinkCount: 0,
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
// TODO: also use satellites
peers := planet.StorageNodes
{ // PingNode
self := planet.StorageNodes[0]
tlsOpts, err := tlsopts.NewOptions(self.Identity, tlsopts.Config{})
require.NoError(t, err)
self.Transport = transport.NewClientWithTimeout(tlsOpts, 20*time.Millisecond)
network := &transport.SimulatedNetwork{
DialLatency: 200 * time.Second,
BytesPerSecond: 1 * memory.KB,
}
slowClient := network.NewClient(self.Transport)
require.NotNil(t, slowClient)
dialer := kademlia.NewDialer(zaptest.NewLogger(t), slowClient)
defer ctx.Check(dialer.Close)
var group errgroup.Group
defer ctx.Check(group.Wait)
for _, peer := range peers {
peer := peer
group.Go(func() error {
_, err := dialer.PingNode(ctx, peer.Local())
require.Error(t, err, context.DeadlineExceeded)
require.True(t, transport.Error.Has(err))
return nil
})
}
}
{ // FetchPeerIdentity
self := planet.StorageNodes[1]
tlsOpts, err := tlsopts.NewOptions(self.Identity, tlsopts.Config{})
require.NoError(t, err)
self.Transport = transport.NewClientWithTimeout(tlsOpts, 20*time.Millisecond)
network := &transport.SimulatedNetwork{
DialLatency: 200 * time.Second,
BytesPerSecond: 1 * memory.KB,
}
slowClient := network.NewClient(self.Transport)
require.NotNil(t, slowClient)
dialer := kademlia.NewDialer(zaptest.NewLogger(t), slowClient)
defer ctx.Check(dialer.Close)
var group errgroup.Group
defer ctx.Check(group.Wait)
group.Go(func() error {
_, err := dialer.FetchPeerIdentity(ctx, planet.Satellites[0].Local())
require.Error(t, err, context.DeadlineExceeded)
require.True(t, transport.Error.Has(err))
_, err = dialer.FetchPeerIdentityUnverified(ctx, planet.Satellites[0].Addr())
require.Error(t, err, context.DeadlineExceeded)
require.True(t, transport.Error.Has(err))
return nil
})
}
{ // Lookup
self := planet.StorageNodes[2]
tlsOpts, err := tlsopts.NewOptions(self.Identity, tlsopts.Config{})
require.NoError(t, err)
self.Transport = transport.NewClientWithTimeout(tlsOpts, 20*time.Millisecond)
network := &transport.SimulatedNetwork{
DialLatency: 200 * time.Second,
BytesPerSecond: 1 * memory.KB,
}
slowClient := network.NewClient(self.Transport)
require.NotNil(t, slowClient)
dialer := kademlia.NewDialer(zaptest.NewLogger(t), slowClient)
defer ctx.Check(dialer.Close)
var group errgroup.Group
defer ctx.Check(group.Wait)
for _, peer := range peers {
peer := peer
group.Go(func() error {
for _, target := range peers {
errTag := fmt.Errorf("lookup peer:%s target:%s", peer.ID(), target.ID())
peer.Local().Type.DPanicOnInvalid("test client peer")
target.Local().Type.DPanicOnInvalid("test client target")
_, err := dialer.Lookup(ctx, self.Local(), peer.Local(), target.Local())
require.Error(t, err, context.DeadlineExceeded, errTag)
require.True(t, transport.Error.Has(err), errTag)
return nil
}
return nil
})
}
}
})
}
func containsResult(nodes []*pb.Node, target storj.NodeID) bool { func containsResult(nodes []*pb.Node, target storj.NodeID) bool {
for _, node := range nodes { for _, node := range nodes {
if node.Id == target { if node.Id == target {

View File

@ -4,12 +4,20 @@
package kademlia_test package kademlia_test
import ( import (
"context"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"storj.io/storj/internal/memory"
"storj.io/storj/internal/testcontext" "storj.io/storj/internal/testcontext"
"storj.io/storj/internal/testplanet" "storj.io/storj/internal/testplanet"
"storj.io/storj/pkg/kademlia"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/peertls/tlsopts"
"storj.io/storj/pkg/transport"
) )
func TestFetchPeerIdentity(t *testing.T) { func TestFetchPeerIdentity(t *testing.T) {
@ -39,3 +47,49 @@ func TestRequestInfo(t *testing.T) {
require.Equal(t, node.Local().Restrictions.GetFreeBandwidth(), info.GetCapacity().GetFreeBandwidth()) require.Equal(t, node.Local().Restrictions.GetFreeBandwidth(), info.GetCapacity().GetFreeBandwidth())
}) })
} }
func TestPingTimeout(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 4, UplinkCount: 0,
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
self := planet.StorageNodes[0]
routingTable := self.Kademlia.RoutingTable
tlsOpts, err := tlsopts.NewOptions(self.Identity, tlsopts.Config{})
require.NoError(t, err)
self.Transport = transport.NewClientWithTimeout(tlsOpts, 1*time.Millisecond)
network := &transport.SimulatedNetwork{
DialLatency: 300 * time.Second,
BytesPerSecond: 1 * memory.KB,
}
slowClient := network.NewClient(self.Transport)
require.NotNil(t, slowClient)
node := pb.Node{
Id: self.ID(),
Address: &pb.NodeAddress{
Transport: pb.NodeTransport_TCP_TLS_GRPC,
},
}
newService, err := kademlia.NewService(zaptest.NewLogger(t), node, slowClient, routingTable, kademlia.Config{})
require.NoError(t, err)
target := pb.Node{
Id: planet.StorageNodes[2].ID(),
Address: &pb.NodeAddress{
Transport: pb.NodeTransport_TCP_TLS_GRPC,
Address: planet.StorageNodes[2].Addr(),
},
}
_, err = newService.Ping(ctx, target)
require.Error(t, err, context.DeadlineExceeded)
require.True(t, kademlia.NodeErr.Has(err) && transport.Error.Has(err))
})
}

View File

@ -14,6 +14,11 @@ var (
mon = monkit.Package() mon = monkit.Package()
//Error is the errs class of standard Transport Client errors //Error is the errs class of standard Transport Client errors
Error = errs.Class("transport error") Error = errs.Class("transport error")
// default time to wait for a connection to be established )
connWaitTimeout = 20 * time.Second
const (
// default time to wait for a connection to be established
defaultDialTimeout = 20 * time.Second
// default time to wait for a response
defaultRequestTimeout = 20 * time.Second
) )

View File

@ -22,7 +22,7 @@ func DialAddressInsecure(ctx context.Context, address string, opts ...grpc.DialO
grpc.FailOnNonTempDialError(true), grpc.FailOnNonTempDialError(true),
}, opts...) }, opts...)
timedCtx, cf := context.WithTimeout(ctx, connWaitTimeout) timedCtx, cf := context.WithTimeout(ctx, defaultDialTimeout)
defer cf() defer cf()
conn, err = grpc.DialContext(timedCtx, address, options...) conn, err = grpc.DialContext(timedCtx, address, options...)

24
pkg/transport/timeout.go Normal file
View File

@ -0,0 +1,24 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package transport
import (
"context"
"time"
"google.golang.org/grpc"
)
// InvokeTimeout enables timeouts for requests that take too long
type InvokeTimeout struct {
Timeout time.Duration
}
// Intercept adds a context timeout to a method call
func (it InvokeTimeout) Intercept(ctx context.Context, method string, req interface{}, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
timedCtx, cancel := context.WithTimeout(ctx, it.Timeout)
defer cancel()
return invoker(timedCtx, method, req, reply, cc, opts...)
}

View File

@ -5,6 +5,7 @@ package transport
import ( import (
"context" "context"
"time"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -30,15 +31,22 @@ type Client interface {
// Transport interface structure // Transport interface structure
type Transport struct { type Transport struct {
tlsOpts *tlsopts.Options tlsOpts *tlsopts.Options
observers []Observer observers []Observer
requestTimeout time.Duration
} }
// NewClient returns a newly instantiated Transport Client // NewClient returns a transport client with a default timeout for requests
func NewClient(tlsOpts *tlsopts.Options, obs ...Observer) Client { func NewClient(tlsOpts *tlsopts.Options, obs ...Observer) Client {
return NewClientWithTimeout(tlsOpts, defaultRequestTimeout, obs...)
}
// NewClientWithTimeout returns a transport client with a specified timeout for requests
func NewClientWithTimeout(tlsOpts *tlsopts.Options, requestTimeout time.Duration, obs ...Observer) Client {
return &Transport{ return &Transport{
tlsOpts: tlsOpts, tlsOpts: tlsOpts,
observers: obs, requestTimeout: requestTimeout,
observers: obs,
} }
} }
@ -65,9 +73,10 @@ func (transport *Transport) DialNode(ctx context.Context, node *pb.Node, opts ..
dialOption, dialOption,
grpc.WithBlock(), grpc.WithBlock(),
grpc.FailOnNonTempDialError(true), grpc.FailOnNonTempDialError(true),
grpc.WithUnaryInterceptor(InvokeTimeout{transport.requestTimeout}.Intercept),
}, opts...) }, opts...)
timedCtx, cancel := context.WithTimeout(ctx, connWaitTimeout) timedCtx, cancel := context.WithTimeout(ctx, defaultDialTimeout)
defer cancel() defer cancel()
conn, err = grpc.DialContext(timedCtx, node.GetAddress().Address, options...) conn, err = grpc.DialContext(timedCtx, node.GetAddress().Address, options...)
@ -96,9 +105,10 @@ func (transport *Transport) DialAddress(ctx context.Context, address string, opt
transport.tlsOpts.DialUnverifiedIDOption(), transport.tlsOpts.DialUnverifiedIDOption(),
grpc.WithBlock(), grpc.WithBlock(),
grpc.FailOnNonTempDialError(true), grpc.FailOnNonTempDialError(true),
grpc.WithUnaryInterceptor(InvokeTimeout{transport.requestTimeout}.Intercept),
}, opts...) }, opts...)
timedCtx, cancel := context.WithTimeout(ctx, connWaitTimeout) timedCtx, cancel := context.WithTimeout(ctx, defaultDialTimeout)
defer cancel() defer cancel()
conn, err = grpc.DialContext(timedCtx, address, options...) conn, err = grpc.DialContext(timedCtx, address, options...)
@ -115,7 +125,7 @@ func (transport *Transport) Identity() *identity.FullIdentity {
// WithObservers returns a new transport including the listed observers. // WithObservers returns a new transport including the listed observers.
func (transport *Transport) WithObservers(obs ...Observer) *Transport { func (transport *Transport) WithObservers(obs ...Observer) *Transport {
tr := &Transport{tlsOpts: transport.tlsOpts} tr := &Transport{tlsOpts: transport.tlsOpts, requestTimeout: transport.requestTimeout}
tr.observers = append(tr.observers, transport.observers...) tr.observers = append(tr.observers, transport.observers...)
tr.observers = append(tr.observers, obs...) tr.observers = append(tr.observers, obs...)
return tr return tr