pkg/rpc: build tag based selection of rpc details

It provides an abstraction around the rpc details so that one
can use dprc or gprc with the same code. It subsumes using the
protobuf package directly for client interfaces as well as
the pkg/transport package to perform dials.

Change-Id: I8f5688bd71be8b0c766f13029128a77e5d46320b
This commit is contained in:
Jeff Wendling 2019-09-18 21:34:19 -06:00
parent 9ceff9f9c6
commit a20a7db793
9 changed files with 1001 additions and 0 deletions

View File

@ -48,6 +48,12 @@ func (opts *Options) ClientTLSConfig(id storj.NodeID) *tls.Config {
return opts.tlsConfig(false, verifyIdentity(id))
}
// UnverifiedClientTLSConfig returns a TLSConfig for use as a client in handshaking with
// an unknown peer.
func (opts *Options) UnverifiedClientTLSConfig() *tls.Config {
return opts.tlsConfig(false)
}
func (opts *Options) tlsConfig(isServer bool, verificationFuncs ...peertls.PeerCertVerificationFunc) *tls.Config {
verificationFuncs = append(
[]peertls.PeerCertVerificationFunc{

78
pkg/rpc/common.go Normal file
View File

@ -0,0 +1,78 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package rpc
import (
"net"
"time"
"github.com/zeebo/errs"
"gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/internal/memory"
)
//go:generate go run gen.go ../pb drpc compat_drpc.go
//go:generate go run gen.go ../pb grpc compat_grpc.go
var mon = monkit.Package()
// Error wraps all of the errors returned by this package.
var Error = errs.Class("rpccompat")
// timedConn wraps a net.Conn so that all reads and writes get the specified timeout and
// return bytes no faster than the rate. If the timeout or rate are zero, they are
// ignored.
type timedConn struct {
net.Conn
timeout time.Duration
rate memory.Size
}
// now returns time.Now if there's a nonzero rate.
func (t *timedConn) now() (now time.Time) {
if t.rate > 0 {
now = time.Now()
}
return now
}
// delay ensures that we sleep to keep the rate if it is nonzero. n is the number of
// bytes in the read or write operation we need to delay.
func (t *timedConn) delay(start time.Time, n int) {
if t.rate > 0 {
expected := time.Duration(n * int(time.Second) / t.rate.Int())
if actual := time.Since(start); expected > actual {
time.Sleep(expected - actual)
}
}
}
// Read wraps the connection read setting the timeout and sleeping to ensure the rate.
func (t *timedConn) Read(p []byte) (int, error) {
if t.timeout > 0 {
if err := t.SetReadDeadline(time.Now().Add(t.timeout)); err != nil {
return 0, err
}
}
start := t.now()
n, err := t.Conn.Read(p)
t.delay(start, n)
return n, err
}
// Write wraps the connection write setting the timeout and sleeping to ensure the rate.
func (t *timedConn) Write(p []byte) (int, error) {
if t.timeout > 0 {
if err := t.SetWriteDeadline(time.Now().Add(t.timeout)); err != nil {
return 0, err
}
}
start := t.now()
n, err := t.Conn.Write(p)
t.delay(start, n)
return n, err
}

198
pkg/rpc/compat_drpc.go Normal file
View File

@ -0,0 +1,198 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// +build drpc
package rpc
import (
"storj.io/drpc/drpcconn"
"storj.io/storj/pkg/pb"
)
// RawConn is a type alias to a drpc client connection
type RawConn = drpcconn.Conn
type (
// CertificatesClient is an alias to the drpc client interface
CertificatesClient = pb.DRPCCertificatesClient
// ContactClient is an alias to the drpc client interface
ContactClient = pb.DRPCContactClient
// HealthInspectorClient is an alias to the drpc client interface
HealthInspectorClient = pb.DRPCHealthInspectorClient
// IrreparableInspectorClient is an alias to the drpc client interface
IrreparableInspectorClient = pb.DRPCIrreparableInspectorClient
// KadInspectorClient is an alias to the drpc client interface
KadInspectorClient = pb.DRPCKadInspectorClient
// MetainfoClient is an alias to the drpc client interface
MetainfoClient = pb.DRPCMetainfoClient
// NodeClient is an alias to the drpc client interface
NodeClient = pb.DRPCNodeClient
// NodeStatsClient is an alias to the drpc client interface
NodeStatsClient = pb.DRPCNodeStatsClient
// NodesClient is an alias to the drpc client interface
NodesClient = pb.DRPCNodesClient
// OrdersClient is an alias to the drpc client interface
OrdersClient = pb.DRPCOrdersClient
// OverlayInspectorClient is an alias to the drpc client interface
OverlayInspectorClient = pb.DRPCOverlayInspectorClient
// PieceStoreInspectorClient is an alias to the drpc client interface
PieceStoreInspectorClient = pb.DRPCPieceStoreInspectorClient
// PiecestoreClient is an alias to the drpc client interface
PiecestoreClient = pb.DRPCPiecestoreClient
// VouchersClient is an alias to the drpc client interface
VouchersClient = pb.DRPCVouchersClient
)
// NewCertificatesClient returns the drpc version of a CertificatesClient
func NewCertificatesClient(rc *RawConn) CertificatesClient {
return pb.NewDRPCCertificatesClient(rc)
}
// CertificatesClient returns a CertificatesClient for this connection
func (c *Conn) CertificatesClient() CertificatesClient {
return NewCertificatesClient(c.raw)
}
// NewContactClient returns the drpc version of a ContactClient
func NewContactClient(rc *RawConn) ContactClient {
return pb.NewDRPCContactClient(rc)
}
// ContactClient returns a ContactClient for this connection
func (c *Conn) ContactClient() ContactClient {
return NewContactClient(c.raw)
}
// NewHealthInspectorClient returns the drpc version of a HealthInspectorClient
func NewHealthInspectorClient(rc *RawConn) HealthInspectorClient {
return pb.NewDRPCHealthInspectorClient(rc)
}
// HealthInspectorClient returns a HealthInspectorClient for this connection
func (c *Conn) HealthInspectorClient() HealthInspectorClient {
return NewHealthInspectorClient(c.raw)
}
// NewIrreparableInspectorClient returns the drpc version of a IrreparableInspectorClient
func NewIrreparableInspectorClient(rc *RawConn) IrreparableInspectorClient {
return pb.NewDRPCIrreparableInspectorClient(rc)
}
// IrreparableInspectorClient returns a IrreparableInspectorClient for this connection
func (c *Conn) IrreparableInspectorClient() IrreparableInspectorClient {
return NewIrreparableInspectorClient(c.raw)
}
// NewKadInspectorClient returns the drpc version of a KadInspectorClient
func NewKadInspectorClient(rc *RawConn) KadInspectorClient {
return pb.NewDRPCKadInspectorClient(rc)
}
// KadInspectorClient returns a KadInspectorClient for this connection
func (c *Conn) KadInspectorClient() KadInspectorClient {
return NewKadInspectorClient(c.raw)
}
// NewMetainfoClient returns the drpc version of a MetainfoClient
func NewMetainfoClient(rc *RawConn) MetainfoClient {
return pb.NewDRPCMetainfoClient(rc)
}
// MetainfoClient returns a MetainfoClient for this connection
func (c *Conn) MetainfoClient() MetainfoClient {
return NewMetainfoClient(c.raw)
}
// NewNodeClient returns the drpc version of a NodeClient
func NewNodeClient(rc *RawConn) NodeClient {
return pb.NewDRPCNodeClient(rc)
}
// NodeClient returns a NodeClient for this connection
func (c *Conn) NodeClient() NodeClient {
return NewNodeClient(c.raw)
}
// NewNodeStatsClient returns the drpc version of a NodeStatsClient
func NewNodeStatsClient(rc *RawConn) NodeStatsClient {
return pb.NewDRPCNodeStatsClient(rc)
}
// NodeStatsClient returns a NodeStatsClient for this connection
func (c *Conn) NodeStatsClient() NodeStatsClient {
return NewNodeStatsClient(c.raw)
}
// NewNodesClient returns the drpc version of a NodesClient
func NewNodesClient(rc *RawConn) NodesClient {
return pb.NewDRPCNodesClient(rc)
}
// NodesClient returns a NodesClient for this connection
func (c *Conn) NodesClient() NodesClient {
return NewNodesClient(c.raw)
}
// NewOrdersClient returns the drpc version of a OrdersClient
func NewOrdersClient(rc *RawConn) OrdersClient {
return pb.NewDRPCOrdersClient(rc)
}
// OrdersClient returns a OrdersClient for this connection
func (c *Conn) OrdersClient() OrdersClient {
return NewOrdersClient(c.raw)
}
// NewOverlayInspectorClient returns the drpc version of a OverlayInspectorClient
func NewOverlayInspectorClient(rc *RawConn) OverlayInspectorClient {
return pb.NewDRPCOverlayInspectorClient(rc)
}
// OverlayInspectorClient returns a OverlayInspectorClient for this connection
func (c *Conn) OverlayInspectorClient() OverlayInspectorClient {
return NewOverlayInspectorClient(c.raw)
}
// NewPieceStoreInspectorClient returns the drpc version of a PieceStoreInspectorClient
func NewPieceStoreInspectorClient(rc *RawConn) PieceStoreInspectorClient {
return pb.NewDRPCPieceStoreInspectorClient(rc)
}
// PieceStoreInspectorClient returns a PieceStoreInspectorClient for this connection
func (c *Conn) PieceStoreInspectorClient() PieceStoreInspectorClient {
return NewPieceStoreInspectorClient(c.raw)
}
// NewPiecestoreClient returns the drpc version of a PiecestoreClient
func NewPiecestoreClient(rc *RawConn) PiecestoreClient {
return pb.NewDRPCPiecestoreClient(rc)
}
// PiecestoreClient returns a PiecestoreClient for this connection
func (c *Conn) PiecestoreClient() PiecestoreClient {
return NewPiecestoreClient(c.raw)
}
// NewVouchersClient returns the drpc version of a VouchersClient
func NewVouchersClient(rc *RawConn) VouchersClient {
return pb.NewDRPCVouchersClient(rc)
}
// VouchersClient returns a VouchersClient for this connection
func (c *Conn) VouchersClient() VouchersClient {
return NewVouchersClient(c.raw)
}

199
pkg/rpc/compat_grpc.go Normal file
View File

@ -0,0 +1,199 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// +build !drpc
package rpc
import (
"google.golang.org/grpc"
"storj.io/storj/pkg/pb"
)
// RawConn is a type alias to a grpc client connection
type RawConn = grpc.ClientConn
type (
// CertificatesClient is an alias to the grpc client interface
CertificatesClient = pb.CertificatesClient
// ContactClient is an alias to the grpc client interface
ContactClient = pb.ContactClient
// HealthInspectorClient is an alias to the grpc client interface
HealthInspectorClient = pb.HealthInspectorClient
// IrreparableInspectorClient is an alias to the grpc client interface
IrreparableInspectorClient = pb.IrreparableInspectorClient
// KadInspectorClient is an alias to the grpc client interface
KadInspectorClient = pb.KadInspectorClient
// MetainfoClient is an alias to the grpc client interface
MetainfoClient = pb.MetainfoClient
// NodeClient is an alias to the grpc client interface
NodeClient = pb.NodeClient
// NodeStatsClient is an alias to the grpc client interface
NodeStatsClient = pb.NodeStatsClient
// NodesClient is an alias to the grpc client interface
NodesClient = pb.NodesClient
// OrdersClient is an alias to the grpc client interface
OrdersClient = pb.OrdersClient
// OverlayInspectorClient is an alias to the grpc client interface
OverlayInspectorClient = pb.OverlayInspectorClient
// PieceStoreInspectorClient is an alias to the grpc client interface
PieceStoreInspectorClient = pb.PieceStoreInspectorClient
// PiecestoreClient is an alias to the grpc client interface
PiecestoreClient = pb.PiecestoreClient
// VouchersClient is an alias to the grpc client interface
VouchersClient = pb.VouchersClient
)
// NewCertificatesClient returns the grpc version of a CertificatesClient
func NewCertificatesClient(rc *RawConn) CertificatesClient {
return pb.NewCertificatesClient(rc)
}
// CertificatesClient returns a CertificatesClient for this connection
func (c *Conn) CertificatesClient() CertificatesClient {
return NewCertificatesClient(c.raw)
}
// NewContactClient returns the grpc version of a ContactClient
func NewContactClient(rc *RawConn) ContactClient {
return pb.NewContactClient(rc)
}
// ContactClient returns a ContactClient for this connection
func (c *Conn) ContactClient() ContactClient {
return NewContactClient(c.raw)
}
// NewHealthInspectorClient returns the grpc version of a HealthInspectorClient
func NewHealthInspectorClient(rc *RawConn) HealthInspectorClient {
return pb.NewHealthInspectorClient(rc)
}
// HealthInspectorClient returns a HealthInspectorClient for this connection
func (c *Conn) HealthInspectorClient() HealthInspectorClient {
return NewHealthInspectorClient(c.raw)
}
// NewIrreparableInspectorClient returns the grpc version of a IrreparableInspectorClient
func NewIrreparableInspectorClient(rc *RawConn) IrreparableInspectorClient {
return pb.NewIrreparableInspectorClient(rc)
}
// IrreparableInspectorClient returns a IrreparableInspectorClient for this connection
func (c *Conn) IrreparableInspectorClient() IrreparableInspectorClient {
return NewIrreparableInspectorClient(c.raw)
}
// NewKadInspectorClient returns the grpc version of a KadInspectorClient
func NewKadInspectorClient(rc *RawConn) KadInspectorClient {
return pb.NewKadInspectorClient(rc)
}
// KadInspectorClient returns a KadInspectorClient for this connection
func (c *Conn) KadInspectorClient() KadInspectorClient {
return NewKadInspectorClient(c.raw)
}
// NewMetainfoClient returns the grpc version of a MetainfoClient
func NewMetainfoClient(rc *RawConn) MetainfoClient {
return pb.NewMetainfoClient(rc)
}
// MetainfoClient returns a MetainfoClient for this connection
func (c *Conn) MetainfoClient() MetainfoClient {
return NewMetainfoClient(c.raw)
}
// NewNodeClient returns the grpc version of a NodeClient
func NewNodeClient(rc *RawConn) NodeClient {
return pb.NewNodeClient(rc)
}
// NodeClient returns a NodeClient for this connection
func (c *Conn) NodeClient() NodeClient {
return NewNodeClient(c.raw)
}
// NewNodeStatsClient returns the grpc version of a NodeStatsClient
func NewNodeStatsClient(rc *RawConn) NodeStatsClient {
return pb.NewNodeStatsClient(rc)
}
// NodeStatsClient returns a NodeStatsClient for this connection
func (c *Conn) NodeStatsClient() NodeStatsClient {
return NewNodeStatsClient(c.raw)
}
// NewNodesClient returns the grpc version of a NodesClient
func NewNodesClient(rc *RawConn) NodesClient {
return pb.NewNodesClient(rc)
}
// NodesClient returns a NodesClient for this connection
func (c *Conn) NodesClient() NodesClient {
return NewNodesClient(c.raw)
}
// NewOrdersClient returns the grpc version of a OrdersClient
func NewOrdersClient(rc *RawConn) OrdersClient {
return pb.NewOrdersClient(rc)
}
// OrdersClient returns a OrdersClient for this connection
func (c *Conn) OrdersClient() OrdersClient {
return NewOrdersClient(c.raw)
}
// NewOverlayInspectorClient returns the grpc version of a OverlayInspectorClient
func NewOverlayInspectorClient(rc *RawConn) OverlayInspectorClient {
return pb.NewOverlayInspectorClient(rc)
}
// OverlayInspectorClient returns a OverlayInspectorClient for this connection
func (c *Conn) OverlayInspectorClient() OverlayInspectorClient {
return NewOverlayInspectorClient(c.raw)
}
// NewPieceStoreInspectorClient returns the grpc version of a PieceStoreInspectorClient
func NewPieceStoreInspectorClient(rc *RawConn) PieceStoreInspectorClient {
return pb.NewPieceStoreInspectorClient(rc)
}
// PieceStoreInspectorClient returns a PieceStoreInspectorClient for this connection
func (c *Conn) PieceStoreInspectorClient() PieceStoreInspectorClient {
return NewPieceStoreInspectorClient(c.raw)
}
// NewPiecestoreClient returns the grpc version of a PiecestoreClient
func NewPiecestoreClient(rc *RawConn) PiecestoreClient {
return pb.NewPiecestoreClient(rc)
}
// PiecestoreClient returns a PiecestoreClient for this connection
func (c *Conn) PiecestoreClient() PiecestoreClient {
return NewPiecestoreClient(c.raw)
}
// NewVouchersClient returns the grpc version of a VouchersClient
func NewVouchersClient(rc *RawConn) VouchersClient {
return pb.NewVouchersClient(rc)
}
// VouchersClient returns a VouchersClient for this connection
func (c *Conn) VouchersClient() VouchersClient {
return NewVouchersClient(c.raw)
}

30
pkg/rpc/conn.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package rpc
import (
"crypto/tls"
"storj.io/storj/pkg/identity"
)
// Conn is a wrapper around a drpc client connection.
type Conn struct {
raw *RawConn
state tls.ConnectionState
}
// Close closes the connection.
func (c *Conn) Close() error { return c.raw.Close() }
// RawConn returns the underlying connection.
func (c *Conn) RawConn() *RawConn { return c.raw }
// ConnectionState returns the tls connection state.
func (c *Conn) ConnectionState() tls.ConnectionState { return c.state }
// PeerIdentity returns the peer identity on the other end of the connection.
func (c *Conn) PeerIdentity() (*identity.PeerIdentity, error) {
return identity.PeerIdentityFromChain(c.state.PeerCertificates)
}

118
pkg/rpc/dial.go Normal file
View File

@ -0,0 +1,118 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package rpc
import (
"context"
"net"
"time"
"storj.io/storj/internal/memory"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/peertls/tlsopts"
"storj.io/storj/pkg/storj"
)
// Dialer holds configuration for dialing.
type Dialer struct {
// TLSOptions controls the tls options for dialing. If it is nil, only
// insecure connections can be made.
TLSOptions *tlsopts.Options
// RequestTimeout causes any read/write operations on the raw socket
// to error if they take longer than it if it is non-zero.
RequestTimeout time.Duration
// DialTimeout causes all the tcp dials to error if they take longer
// than it if it is non-zero.
DialTimeout time.Duration
// DialLatency sleeps this amount if it is non-zero before every dial.
// The timeout runs while the sleep is happening.
DialLatency time.Duration
// TransferRate limits all read/write operations to go slower than
// the size per second if it is non-zero.
TransferRate memory.Size
}
// NewDefaultDialer returns a Dialer with default timeouts set.
func NewDefaultDialer(tlsOptions *tlsopts.Options) Dialer {
return Dialer{
TLSOptions: tlsOptions,
RequestTimeout: 10 * time.Minute,
DialTimeout: 20 * time.Second,
}
}
// dialContext does a raw tcp dial to the address and wraps the connection with the
// provided timeout.
func (d Dialer) dialContext(ctx context.Context, address string) (net.Conn, error) {
if d.DialTimeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, d.DialTimeout)
defer cancel()
}
if d.DialLatency > 0 {
timer := time.NewTimer(d.DialLatency)
select {
case <-timer.C:
case <-ctx.Done():
timer.Stop()
return nil, Error.Wrap(ctx.Err())
}
}
conn, err := new(net.Dialer).DialContext(ctx, "tcp", address)
if err != nil {
return nil, Error.Wrap(err)
}
return &timedConn{
Conn: conn,
timeout: d.RequestTimeout,
rate: d.TransferRate,
}, nil
}
// DialNode creates an rpc connection to the specified node.
func (d Dialer) DialNode(ctx context.Context, node *pb.Node) (_ *Conn, err error) {
defer mon.Task()(&ctx)(&err)
if d.TLSOptions == nil {
return nil, Error.New("tls options not set when required for this dial")
}
return d.dial(ctx, node.GetAddress().GetAddress(), d.TLSOptions.ClientTLSConfig(node.Id))
}
// DialAddressID dials to the specified address and asserts it has the given node id.
func (d Dialer) DialAddressID(ctx context.Context, address string, id storj.NodeID) (_ *Conn, err error) {
defer mon.Task()(&ctx)(&err)
if d.TLSOptions == nil {
return nil, Error.New("tls options not set when required for this dial")
}
return d.dial(ctx, address, d.TLSOptions.ClientTLSConfig(id))
}
// DialAddressInsecure dials to the specified address and does not check the node id.
func (d Dialer) DialAddressInsecure(ctx context.Context, address string) (_ *Conn, err error) {
defer mon.Task()(&ctx)(&err)
if d.TLSOptions == nil {
return nil, Error.New("tls options not set when required for this dial")
}
return d.dial(ctx, address, d.TLSOptions.UnverifiedClientTLSConfig())
}
// DialAddressUnencrypted dials to the specified address without tls.
func (d Dialer) DialAddressUnencrypted(ctx context.Context, address string) (_ *Conn, err error) {
defer mon.Task()(&ctx)(&err)
return d.dialInsecure(ctx, address)
}

75
pkg/rpc/dial_drpc.go Normal file
View File

@ -0,0 +1,75 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// +build drpc
package rpc
import (
"context"
"crypto/tls"
"storj.io/drpc/drpcconn"
)
const drpcHeader = "DRPC!!!1"
// dial performs the dialing to the drpc endpoint with tls.
func (d Dialer) dial(ctx context.Context, address string, tlsConfig *tls.Config) (_ *Conn, err error) {
defer mon.Task()(&ctx)(&err)
// set up an error to expire when the context is canceled
errCh := make(chan error, 2)
go func() {
<-ctx.Done()
errCh <- ctx.Err()
}()
// open the tcp socket to the address
rawConn, err := d.dialContext(ctx, address)
if err != nil {
return nil, Error.Wrap(err)
}
// write the header bytes before the tls handshake
if _, err := rawConn.Write([]byte(drpcHeader)); err != nil {
_ = rawConn.Close()
return nil, Error.Wrap(err)
}
// perform the handshake racing with the context closing
conn := tls.Client(rawConn, tlsConfig)
go func() { errCh <- conn.Handshake() }()
// see which wins and close the raw conn if there was any error. we can't
// close the tls connection concurrently with handshakes or it sometimes
// will panic. cool, huh?
if err := <-errCh; err != nil {
_ = rawConn.Close()
return nil, Error.Wrap(err)
}
return &Conn{
raw: drpcconn.New(conn),
state: conn.ConnectionState(),
}, nil
}
// dialInsecure performs dialing to the drpc endpoint with no tls.
func (d Dialer) dialInsecure(ctx context.Context, address string) (_ *Conn, err error) {
defer mon.Task()(&ctx)(&err)
// open the tcp socket to the address
conn, err := d.dialContext(ctx, address)
if err != nil {
return nil, Error.Wrap(err)
}
// write the header bytes before the tls handshake
if _, err := conn.Write([]byte(drpcHeader)); err != nil {
_ = conn.Close()
return nil, Error.Wrap(err)
}
return &Conn{raw: drpcconn.New(conn)}, nil
}

94
pkg/rpc/dial_grpc.go Normal file
View File

@ -0,0 +1,94 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// +build !drpc
package rpc
import (
"context"
"crypto/tls"
"net"
"sync"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// dial performs the dialing to the grpc endpoint with tls.
func (d Dialer) dial(ctx context.Context, address string, tlsConfig *tls.Config) (_ *Conn, err error) {
defer mon.Task()(&ctx)(&err)
creds := &captureStateCreds{TransportCredentials: credentials.NewTLS(tlsConfig)}
conn, err := grpc.DialContext(ctx, address,
grpc.WithTransportCredentials(creds),
grpc.WithBlock(),
grpc.FailOnNonTempDialError(true),
grpc.WithContextDialer(d.dialContext))
if err != nil {
return nil, Error.Wrap(err)
}
state, ok := creds.Get()
if !ok {
_ = conn.Close()
return nil, Error.New("unable to get tls connection state when dialing")
}
return &Conn{
raw: conn,
state: state,
}, nil
}
// dialInsecure performs dialing to the grpc endpoint with no tls.
func (d Dialer) dialInsecure(ctx context.Context, address string) (_ *Conn, err error) {
defer mon.Task()(&ctx)(&err)
conn, err := grpc.DialContext(ctx, address,
grpc.WithBlock(),
grpc.FailOnNonTempDialError(true),
grpc.WithContextDialer(d.dialContext))
if err != nil {
return nil, Error.Wrap(err)
}
return &Conn{raw: conn}, nil
}
// captureStateCreds captures the tls connection state from a client/server handshake.
type captureStateCreds struct {
credentials.TransportCredentials
once sync.Once
state tls.ConnectionState
}
// Get returns the stored tls connection state.
func (c *captureStateCreds) Get() (state tls.ConnectionState, ok bool) {
c.once.Do(func() { ok = true })
return c.state, ok
}
// ClientHandshake dispatches to the underlying credentials and tries to store the
// connection state if possible.
func (c *captureStateCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (
net.Conn, credentials.AuthInfo, error) {
conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, authority, rawConn)
if tlsInfo, ok := auth.(credentials.TLSInfo); ok {
c.once.Do(func() { c.state = tlsInfo.State })
}
return conn, auth, err
}
// ServerHandshake dispatches to the underlying credentials and tries to store the
// connection state if possible.
func (c *captureStateCreds) ServerHandshake(rawConn net.Conn) (
net.Conn, credentials.AuthInfo, error) {
conn, auth, err := c.TransportCredentials.ServerHandshake(rawConn)
if tlsInfo, ok := auth.(credentials.TLSInfo); ok {
c.once.Do(func() { c.state = tlsInfo.State })
}
return conn, auth, err
}

203
pkg/rpc/gen.go Normal file
View File

@ -0,0 +1,203 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
//
// This code generates the compat_drpc and compat_grpc files by reading in
// protobuf definitions. Its purpose is to generate a bunch of type aliases
// and forwarding functions so that a build tag transparently swaps out the
// concrete implementations of the rpcs.
// +build ignore
package main
import (
"bytes"
"fmt"
"go/format"
"io"
"io/ioutil"
"log"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"github.com/zeebo/errs"
)
func main() {
if err := run(); err != nil {
log.Fatalf("%+v", err)
}
}
func usage() error {
return errs.New("usage: %s <dir> <drpc|grpc> <output file>", os.Args[0])
}
func run() error {
if len(os.Args) < 4 {
return usage()
}
clients, err := findClientsInDir(os.Args[1])
if err != nil {
return errs.Wrap(err)
}
info, ok := infos[os.Args[2]]
if !ok {
return usage()
}
return generate(clients, info, os.Args[3])
}
//
// info about the difference between generated files
//
type generateInfo struct {
Name string
Import string
Prefix string
Conn string
Tag string
}
var infos = map[string]generateInfo{
"drpc": {
Name: "drpc",
Import: "storj.io/drpc/drpcconn",
Prefix: "DRPC",
Conn: "drpcconn.Conn",
Tag: "drpc",
},
"grpc": {
Name: "grpc",
Import: "google.golang.org/grpc", // the saddest newline
Prefix: "",
Conn: "grpc.ClientConn",
Tag: "!drpc",
},
}
//
// main code to generate a compatability file
//
func generate(clients []string, info generateInfo, output string) (err error) {
var buf bytes.Buffer
p := printer{w: &buf}
P := p.P
Pf := p.Pf
P("// Copyright (C) 2019 Storj Labs, Inc.")
P("// See LICENSE for copying information.")
P()
P("// +build", info.Tag)
P()
P("package rpc")
P()
P("import (")
Pf("%q", info.Import)
if !strings.HasPrefix(info.Import, "storj.io/") {
P()
}
Pf("%q", "storj.io/storj/pkg/pb")
P(")")
P()
P("// RawConn is a type alias to a", info.Name, "client connection")
P("type RawConn =", info.Conn)
P()
P("type (")
for _, client := range clients {
P("//", client, "is an alias to the", info.Name, "client interface")
Pf("%s = pb.%s%s", client, info.Prefix, client)
P()
}
P(")")
for _, client := range clients {
P()
Pf("// New%s returns the %s version of a %s", client, info.Name, client)
Pf("func New%s(rc *RawConn) %s {", client, client)
Pf("return pb.New%s%s(rc)", info.Prefix, client)
P("}")
P()
Pf("// %s returns a %s for this connection", client, client)
Pf("func (c *Conn) %s() %s {", client, client)
Pf("return New%s(c.raw)", client)
P("}")
}
if err := p.Err(); err != nil {
return errs.Wrap(err)
}
fmtd, err := format.Source(buf.Bytes())
if err != nil {
return errs.Wrap(err)
}
return errs.Wrap(ioutil.WriteFile(output, fmtd, 0644))
}
//
// hacky code to find all the rpc clients in a go package
//
var clientRegex = regexp.MustCompile("^type (.*Client) interface {$")
func findClientsInDir(dir string) (clients []string, err error) {
files, err := filepath.Glob(filepath.Join(dir, "*.pb.go"))
if err != nil {
return nil, errs.Wrap(err)
}
for _, file := range files {
fileClients, err := findClientsInFile(file)
if err != nil {
return nil, errs.Wrap(err)
}
clients = append(clients, fileClients...)
}
sort.Strings(clients)
return clients, nil
}
func findClientsInFile(file string) (clients []string, err error) {
data, err := ioutil.ReadFile(file)
if err != nil {
return nil, errs.Wrap(err)
}
for _, line := range bytes.Split(data, []byte("\n")) {
switch client := clientRegex.FindSubmatch(line); {
case client == nil:
case bytes.HasPrefix(client[1], []byte("DRPC")):
case bytes.Contains(client[1], []byte("_")):
default:
clients = append(clients, string(client[1]))
}
}
return clients, nil
}
//
// helper to check errors while printing
//
type printer struct {
w io.Writer
err error
}
func (p *printer) P(args ...interface{}) {
if p.err == nil {
_, p.err = fmt.Fprintln(p.w, args...)
}
}
func (p *printer) Pf(format string, args ...interface{}) {
if p.err == nil {
_, p.err = fmt.Fprintf(p.w, format+"\n", args...)
}
}
func (p *printer) Err() error {
return p.err
}