diff --git a/pkg/rpc/dial_drpc.go b/pkg/rpc/dial_drpc.go index 74cc8647a..75674cb8a 100644 --- a/pkg/rpc/dial_drpc.go +++ b/pkg/rpc/dial_drpc.go @@ -18,13 +18,6 @@ const drpcHeader = "DRPC!!!1" func (d Dialer) dial(ctx context.Context, address string, tlsConfig *tls.Config) (_ *Conn, err error) { defer mon.Task()(&ctx)(&err) - // set up an error to expire when the context is canceled - errCh := make(chan error, 2) - go func() { - <-ctx.Done() - errCh <- ctx.Err() - }() - // open the tcp socket to the address rawConn, err := d.dialContext(ctx, address) if err != nil { @@ -37,14 +30,21 @@ func (d Dialer) dial(ctx context.Context, address string, tlsConfig *tls.Config) return nil, Error.Wrap(err) } - // perform the handshake racing with the context closing + // perform the handshake racing with the context closing. we use a buffer + // of size 1 so that the handshake can proceed even if no one is reading. + errCh := make(chan error, 1) conn := tls.Client(rawConn, tlsConfig) go func() { errCh <- conn.Handshake() }() // see which wins and close the raw conn if there was any error. we can't // close the tls connection concurrently with handshakes or it sometimes // will panic. cool, huh? - if err := <-errCh; err != nil { + select { + case <-ctx.Done(): + err = ctx.Err() + case err = <-errCh: + } + if err != nil { _ = rawConn.Close() return nil, Error.Wrap(err) }