storj/pkg/rpc/dial_grpc.go
Jeff Wendling f62107d3e9
pkg/rpc: fix grpc dial timeouts (#3517)
grpc doesn't abort a dial right away if the context dialer
returns an error. Since the context dialer was the only place where
dial timeouts were enforced, dials could leak for an
unknown amount of time.

Add a timeout on the context passed to the grpc dial, because that is
the documented way that grpc expects a dial to be canceled.

Change-Id: Ic47ac61ce8a5f721510cc2c4584f63d43fe4f2d5
2019-11-06 16:42:20 -07:00

109 lines
2.9 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// +build grpc
package rpc
import (
"context"
"crypto/tls"
"net"
"sync"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// dial opens a grpc connection to address secured with tlsConfig and returns it
// together with the tls connection state negotiated during the handshake.
// When d.DialTimeout is positive, the whole dial is bounded by that duration.
func (d Dialer) dial(ctx context.Context, address string, tlsConfig *tls.Config) (_ *Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	// The deadline has to be on the context handed to grpc.DialContext: grpc
	// does not give up on a blocking dial just because the context dialer
	// returned an error, so only this deadline reliably ends the dial.
	if timeout := d.DialTimeout; timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	// Wrap the tls credentials so the handshake's connection state is recorded.
	capture := &captureStateCreds{TransportCredentials: credentials.NewTLS(tlsConfig)}
	options := []grpc.DialOption{
		grpc.WithTransportCredentials(capture),
		grpc.WithBlock(), // return only once the connection is up or the dial fails
		grpc.FailOnNonTempDialError(true),
		grpc.WithContextDialer(d.dialContext),
	}

	cc, err := grpc.DialContext(ctx, address, options...)
	if err != nil {
		return nil, Error.Wrap(err)
	}

	state, ok := capture.Get()
	if !ok {
		// Don't leak the connection when its tls state is unavailable.
		_ = cc.Close()
		return nil, Error.New("unable to get tls connection state when dialing")
	}

	return &Conn{raw: cc, state: state}, nil
}
// dialUnencrypted opens a plaintext (no tls) grpc connection to address. The
// returned Conn carries no tls connection state. When d.DialTimeout is
// positive, the whole dial is bounded by that duration.
func (d Dialer) dialUnencrypted(ctx context.Context, address string) (_ *Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	// Bound the entire grpc dial, not just the raw socket dial: grpc does not
	// abandon a blocking dial merely because the context dialer failed.
	if timeout := d.DialTimeout; timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	options := []grpc.DialOption{
		grpc.WithInsecure(),
		grpc.WithBlock(), // return only once the connection is up or the dial fails
		grpc.FailOnNonTempDialError(true),
		grpc.WithContextDialer(d.dialContext),
	}

	cc, err := grpc.DialContext(ctx, address, options...)
	if err != nil {
		return nil, Error.Wrap(err)
	}
	return &Conn{raw: cc}, nil
}
// captureStateCreds captures the tls connection state from a client/server handshake.
//
// It embeds the credentials that perform the actual handshake and overrides
// ClientHandshake/ServerHandshake to record the resulting state. It contains a
// sync.Once, so it must not be copied after first use; all of its methods use
// pointer receivers and it is constructed with &captureStateCreds{...}.
type captureStateCreds struct {
	// TransportCredentials performs the real handshakes; its methods that are
	// not overridden here are promoted unchanged.
	credentials.TransportCredentials
	// once ensures state/ok are written at most once; Get also runs it (with a
	// no-op) to stop any later writes.
	once sync.Once
	// state is the tls connection state recorded from a handshake.
	state tls.ConnectionState
	// ok reports whether state has been recorded.
	ok bool
}
// Get returns the stored tls connection state.
//
// ok is false if no handshake recorded a state before Get was called. Get runs
// c.once with a no-op. If no state was stored yet, that permanently prevents a
// later handshake from storing one, so the returned values never change
// afterwards. If a handshake is storing a state concurrently, sync.Once makes
// Get wait for that store to finish before reading.
func (c *captureStateCreds) Get() (state tls.ConnectionState, ok bool) {
	// Consume the once (a no-op if already done) so the read below is ordered
	// after any store made inside a handshake's once.Do.
	c.once.Do(func() {})
	return c.state, c.ok
}
// ClientHandshake dispatches to the underlying credentials and, if the
// handshake succeeds with a credentials.TLSInfo, stores its connection state.
// The underlying results are returned unchanged.
func (c *captureStateCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (
	net.Conn, credentials.AuthInfo, error) {
	conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, authority, rawConn)
	c.record(auth, err)
	return conn, auth, err
}

// ServerHandshake dispatches to the underlying credentials and, if the
// handshake succeeds with a credentials.TLSInfo, stores its connection state.
// The underlying results are returned unchanged.
func (c *captureStateCreds) ServerHandshake(rawConn net.Conn) (
	net.Conn, credentials.AuthInfo, error) {
	conn, auth, err := c.TransportCredentials.ServerHandshake(rawConn)
	c.record(auth, err)
	return conn, auth, err
}

// record stores the tls connection state carried by auth when the handshake
// succeeded (err == nil) and auth is a credentials.TLSInfo. Writes go through
// c.once, so only the first recorded state is kept, and nothing is stored
// after Get has run.
func (c *captureStateCreds) record(auth credentials.AuthInfo, err error) {
	// A failed handshake must not use up the once: that would stop the state
	// of a later, successful handshake on the same credentials (e.g. if grpc
	// retries the connection) from ever being captured.
	if err != nil {
		return
	}
	if tlsInfo, ok := auth.(credentials.TLSInfo); ok {
		c.once.Do(func() { c.state, c.ok = tlsInfo.State, true })
	}
}