storj/pkg/rpc/dial_grpc.go

109 lines
2.9 KiB
Go
Raw Normal View History

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
2019-11-01 21:43:24 +00:00
// +build grpc
package rpc
import (
"context"
"crypto/tls"
"net"
"sync"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// dial opens a blocking gRPC connection to address secured with tlsConfig.
// A positive d.DialTimeout bounds the whole dial. The TLS connection state
// recorded during the handshake is returned with the raw connection; if no
// state was recorded, the connection is closed and an error is returned.
func (d Dialer) dial(ctx context.Context, address string, tlsConfig *tls.Config) (_ *Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	if d.DialTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, d.DialTimeout)
		defer cancel()
	}

	// Wrap the TLS credentials so the handshake's connection state can be
	// read back once the dial completes.
	capture := &captureStateCreds{TransportCredentials: credentials.NewTLS(tlsConfig)}
	opts := []grpc.DialOption{
		grpc.WithTransportCredentials(capture),
		grpc.WithBlock(),
		grpc.FailOnNonTempDialError(true),
		grpc.WithContextDialer(d.dialContext),
	}

	raw, err := grpc.DialContext(ctx, address, opts...)
	if err != nil {
		return nil, Error.Wrap(err)
	}

	state, captured := capture.Get()
	if !captured {
		// Best-effort close: the missing TLS state is the error reported.
		_ = raw.Close()
		return nil, Error.New("unable to get tls connection state when dialing")
	}

	return &Conn{raw: raw, state: state}, nil
}
// dialUnencrypted opens a blocking gRPC connection to address without any
// transport security. A positive d.DialTimeout bounds the whole dial. The
// returned Conn carries no TLS connection state.
func (d Dialer) dialUnencrypted(ctx context.Context, address string) (_ *Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	if d.DialTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, d.DialTimeout)
		defer cancel()
	}

	opts := []grpc.DialOption{
		grpc.WithInsecure(),
		grpc.WithBlock(),
		grpc.FailOnNonTempDialError(true),
		grpc.WithContextDialer(d.dialContext),
	}

	raw, err := grpc.DialContext(ctx, address, opts...)
	if err != nil {
		return nil, Error.Wrap(err)
	}
	return &Conn{raw: raw}, nil
}
// captureStateCreds captures the tls connection state from a client/server handshake.
//
// It embeds a set of gRPC transport credentials. ClientHandshake and
// ServerHandshake are overridden to record the state from the first handshake
// whose AuthInfo is a credentials.TLSInfo. Any method not overridden here is
// promoted unchanged from the embedded credentials.
//
// NOTE(review): the promoted Clone returns a clone of the embedded
// credentials, not a captureStateCreds. A cloned copy therefore records no
// state; confirm no caller depends on capturing through a clone.
type captureStateCreds struct {
	credentials.TransportCredentials
	// once allows state/ok to be written at most once. Get also runs it, so
	// no write can occur after the state has been read.
	once sync.Once
	// state is the connection state recorded from the handshake.
	state tls.ConnectionState
	// ok reports whether state was recorded.
	ok bool
}
// Get returns the tls connection state recorded by a handshake, with ok set
// to false when nothing was recorded. After Get has been called, later
// handshakes can no longer record a state.
func (c *captureStateCreds) Get() (state tls.ConnectionState, ok bool) {
	// Running an empty function through the once does two things. It
	// synchronizes with any store made earlier by a handshake's once.Do. It
	// also seals the once, so no later handshake can change what is returned.
	c.once.Do(func() {})
	state, ok = c.state, c.ok
	return state, ok
}
// ClientHandshake performs the client side of the handshake using the
// embedded credentials. If the handshake succeeds and its AuthInfo is a
// credentials.TLSInfo, the resulting tls connection state is recorded. Only
// the first recorded state is kept, and nothing is recorded after Get. The
// underlying connection, auth info and error are returned unchanged.
func (c *captureStateCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (
	net.Conn, credentials.AuthInfo, error) {
	conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, authority, rawConn)
	// Record only successful handshakes. Consuming the once on a failed
	// attempt would stop a later successful attempt from being captured.
	if err == nil {
		if tlsInfo, ok := auth.(credentials.TLSInfo); ok {
			c.once.Do(func() { c.state, c.ok = tlsInfo.State, true })
		}
	}
	return conn, auth, err
}
// ServerHandshake performs the server side of the handshake using the
// embedded credentials. If the handshake succeeds and its AuthInfo is a
// credentials.TLSInfo, the resulting tls connection state is recorded. Only
// the first recorded state is kept, and nothing is recorded after Get. The
// underlying connection, auth info and error are returned unchanged.
func (c *captureStateCreds) ServerHandshake(rawConn net.Conn) (
	net.Conn, credentials.AuthInfo, error) {
	conn, auth, err := c.TransportCredentials.ServerHandshake(rawConn)
	// Record only successful handshakes. Consuming the once on a failed
	// attempt would stop a later successful attempt from being captured.
	if err == nil {
		if tlsInfo, ok := auth.(credentials.TLSInfo); ok {
			c.once.Do(func() { c.state, c.ok = tlsInfo.State, true })
		}
	}
	return conn, auth, err
}