storj/pkg/transport/transport.go
JT Olio 2a59679766 pkg/transport: require tls configuration for dialing (#1286)
* separate TLS options from server options (because we need them for dialing too)
* stop creating transports in multiple places
* ensure that we actually check revocation, whitelists, certificate signing, etc, for all connections.
2019-02-11 13:17:32 +02:00

126 lines
3.6 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package transport
import (
"context"
"time"
"github.com/zeebo/errs"
"google.golang.org/grpc"
monkit "gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/pkg/identity"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/peertls/tlsopts"
"storj.io/storj/pkg/storj"
)
var (
mon = monkit.Package()
//Error is the errs class of standard Transport Client errors
Error = errs.Class("transport error")
// default time to wait for a connection to be established
timeout = 20 * time.Second
)
// Observer receives the outcome of Transport.DialNode calls, so that Discovery
// and other services can react to node connection successes and failures.
type Observer interface {
	// ConnFailure is called when dialing node failed with err.
	ConnFailure(ctx context.Context, node *pb.Node, err error)
	// ConnSuccess is called when a connection to node was established.
	ConnSuccess(ctx context.Context, node *pb.Node)
}
// Client defines the interface to an transport client.
type Client interface {
DialNode(ctx context.Context, node *pb.Node, opts ...grpc.DialOption) (*grpc.ClientConn, error)
DialAddress(ctx context.Context, address string, opts ...grpc.DialOption) (*grpc.ClientConn, error)
Identity() *identity.FullIdentity
WithObservers(obs ...Observer) *Transport
}
// Transport is the concrete Client implementation. It dials over gRPC using
// the dial options produced by tlsOpts, and reports DialNode outcomes to
// observers.
type Transport struct {
	// tlsOpts supplies the TLS dial option and the identity (see Identity).
	tlsOpts *tlsopts.Options
	// observers are notified, in order, of DialNode successes and failures.
	observers []Observer
}
// NewClient builds a Client backed by a *Transport that dials using tlsOpts
// and reports DialNode outcomes to obs.
func NewClient(tlsOpts *tlsopts.Options, obs ...Observer) Client {
	t := new(Transport)
	t.tlsOpts = tlsOpts
	t.observers = obs
	return t
}
// DialNode returns a gRPC connection with TLS to node.
//
// The TLS dial option is built from node.Id, the ID of the node we intend to
// reach. The dial blocks (grpc.WithBlock) until the connection is ready, a
// non-temporary dial error occurs, or the package-level timeout elapses.
//
// It returns an Error if node is nil or has no address. If the dial fails
// because ctx was canceled, context.Canceled is returned unwrapped and
// observers are not notified. Other failures are reported to the observers
// and returned wrapped in Error. On success, observers are notified and the
// connection is returned.
func (transport *Transport) DialNode(ctx context.Context, node *pb.Node, opts ...grpc.DialOption) (conn *grpc.ClientConn, err error) {
	defer mon.Task()(&ctx)(&err)

	// Reject nil before touching any field. Previously a nil node skipped the
	// type check and then panicked on node.Address.
	if node == nil {
		return nil, Error.New("no node")
	}
	node.Type.DPanicOnInvalid("transport dial node")

	if node.Address == nil || node.Address.Address == "" {
		return nil, Error.New("no address")
	}

	// add ID of node we are wanting to connect to
	dialOpt := transport.tlsOpts.DialOption(node.Id)
	options := append([]grpc.DialOption{dialOpt, grpc.WithBlock(), grpc.FailOnNonTempDialError(true)}, opts...)

	timedCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	conn, err = grpc.DialContext(timedCtx, node.Address.Address, options...)
	if err != nil {
		if err == context.Canceled {
			return nil, err
		}
		// Notify with the caller's ctx, not timedCtx. After a timeout, timedCtx
		// is already done, and any ctx-aware observer work would fail at once.
		alertFail(ctx, transport.observers, node, err)
		return nil, Error.Wrap(err)
	}

	alertSuccess(ctx, transport.observers, node)
	return conn, nil
}
// DialAddress returns a gRPC connection with TLS to address.
//
// No peer node ID is supplied, so the TLS dial option is built from the zero
// storj.NodeID. The dial is bounded by the same package-level timeout as
// DialNode. Because grpc.WithBlock is used, it would otherwise hang
// indefinitely when ctx carries no deadline.
//
// It returns an Error if address is empty. If the dial fails because ctx was
// canceled, context.Canceled is returned unwrapped. Other failures are wrapped
// in Error. Observers are not notified.
func (transport *Transport) DialAddress(ctx context.Context, address string, opts ...grpc.DialOption) (conn *grpc.ClientConn, err error) {
	defer mon.Task()(&ctx)(&err)

	if address == "" {
		return nil, Error.New("no address")
	}

	dialOpt := transport.tlsOpts.DialOption(storj.NodeID{})
	options := append([]grpc.DialOption{dialOpt, grpc.WithBlock(), grpc.FailOnNonTempDialError(true)}, opts...)

	timedCtx, cancel := context.WithTimeout(ctx, timeout)
	// Canceling after DialContext returns is safe: gRPC ignores the dial ctx
	// once the connection has been returned.
	defer cancel()

	conn, err = grpc.DialContext(timedCtx, address, options...)
	if err != nil {
		if err == context.Canceled {
			return nil, err
		}
		return nil, Error.Wrap(err)
	}
	return conn, nil
}
// Identity returns the full identity stored in the transport's TLS options.
func (transport *Transport) Identity() *identity.FullIdentity {
	opts := transport.tlsOpts
	return opts.Ident
}
// WithObservers returns a new Transport that shares this transport's TLS
// options. It notifies this transport's observers first, followed by obs.
// The observer list is freshly allocated, so the receiver's slice is never
// aliased or modified.
func (transport *Transport) WithObservers(obs ...Observer) *Transport {
	combined := append([]Observer(nil), transport.observers...)
	combined = append(combined, obs...)
	return &Transport{
		tlsOpts:   transport.tlsOpts,
		observers: combined,
	}
}
// alertFail calls ConnFailure on each observer in obs, in order, passing along
// node and the dial error err.
func alertFail(ctx context.Context, obs []Observer, node *pb.Node, err error) {
	for i := range obs {
		obs[i].ConnFailure(ctx, node, err)
	}
}
// alertSuccess calls ConnSuccess on each observer in obs, in order, for node.
func alertSuccess(ctx context.Context, obs []Observer, node *pb.Node) {
	for i := range obs {
		obs[i].ConnSuccess(ctx, node)
	}
}