f62107d3e9
grpc doesn't exit dials right away if the context dialer returns an error. Since that's the only spot where we were enforcing dial timeouts, dials could leak for an unknown amount of time. Add a timeout above the grpc dial, because that's the documented way grpc expects a dial to be canceled. Change-Id: Ic47ac61ce8a5f721510cc2c4584f63d43fe4f2d5
109 lines
2.9 KiB
Go
109 lines
2.9 KiB
Go
// Copyright (C) 2019 Storj Labs, Inc.
|
|
// See LICENSE for copying information.
|
|
|
|
// +build grpc
|
|
|
|
package rpc
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"net"
|
|
"sync"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
)
|
|
|
|
// dial performs the dialing to the grpc endpoint with tls.
|
|
func (d Dialer) dial(ctx context.Context, address string, tlsConfig *tls.Config) (_ *Conn, err error) {
|
|
defer mon.Task()(&ctx)(&err)
|
|
|
|
if d.DialTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, d.DialTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
creds := &captureStateCreds{TransportCredentials: credentials.NewTLS(tlsConfig)}
|
|
conn, err := grpc.DialContext(ctx, address,
|
|
grpc.WithTransportCredentials(creds),
|
|
grpc.WithBlock(),
|
|
grpc.FailOnNonTempDialError(true),
|
|
grpc.WithContextDialer(d.dialContext))
|
|
if err != nil {
|
|
return nil, Error.Wrap(err)
|
|
}
|
|
|
|
state, ok := creds.Get()
|
|
if !ok {
|
|
_ = conn.Close()
|
|
return nil, Error.New("unable to get tls connection state when dialing")
|
|
}
|
|
|
|
return &Conn{
|
|
raw: conn,
|
|
state: state,
|
|
}, nil
|
|
}
|
|
|
|
// dialUnencrypted performs dialing to the grpc endpoint with no tls.
|
|
func (d Dialer) dialUnencrypted(ctx context.Context, address string) (_ *Conn, err error) {
|
|
defer mon.Task()(&ctx)(&err)
|
|
|
|
if d.DialTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, d.DialTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
conn, err := grpc.DialContext(ctx, address,
|
|
grpc.WithInsecure(),
|
|
grpc.WithBlock(),
|
|
grpc.FailOnNonTempDialError(true),
|
|
grpc.WithContextDialer(d.dialContext))
|
|
if err != nil {
|
|
return nil, Error.Wrap(err)
|
|
}
|
|
|
|
return &Conn{raw: conn}, nil
|
|
}
|
|
|
|
// captureStateCreds captures the tls connection state from a client/server handshake.
|
|
type captureStateCreds struct {
|
|
credentials.TransportCredentials
|
|
once sync.Once
|
|
state tls.ConnectionState
|
|
ok bool
|
|
}
|
|
|
|
// Get returns the stored tls connection state.
|
|
func (c *captureStateCreds) Get() (state tls.ConnectionState, ok bool) {
|
|
c.once.Do(func() {})
|
|
return c.state, c.ok
|
|
}
|
|
|
|
// ClientHandshake dispatches to the underlying credentials and tries to store the
|
|
// connection state if possible.
|
|
func (c *captureStateCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (
|
|
net.Conn, credentials.AuthInfo, error) {
|
|
|
|
conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, authority, rawConn)
|
|
if tlsInfo, ok := auth.(credentials.TLSInfo); ok {
|
|
c.once.Do(func() { c.state, c.ok = tlsInfo.State, true })
|
|
}
|
|
return conn, auth, err
|
|
}
|
|
|
|
// ServerHandshake dispatches to the underlying credentials and tries to store the
|
|
// connection state if possible.
|
|
func (c *captureStateCreds) ServerHandshake(rawConn net.Conn) (
|
|
net.Conn, credentials.AuthInfo, error) {
|
|
|
|
conn, auth, err := c.TransportCredentials.ServerHandshake(rawConn)
|
|
if tlsInfo, ok := auth.(credentials.TLSInfo); ok {
|
|
c.once.Do(func() { c.state, c.ok = tlsInfo.State, true })
|
|
}
|
|
return conn, auth, err
|
|
}
|