From 4fab22d691cabb7e996db05f89a5b0472927a7a1 Mon Sep 17 00:00:00 2001 From: Jeff Wendling Date: Fri, 4 Oct 2019 11:12:37 -0600 Subject: [PATCH] pkg/rpc: don't leak goroutines during a drpc dial we spawned a goroutine to wait on the context's done channel sending the error afterward, but we forgot to ensure the context was eventually done, so the goroutine would be leaked until then. instead, we can just do a select on two channels to get the error rather than spawn a goroutine which makes it impossible to leak a goroutine. Change-Id: I2fdba206ae6ff7a3441b00708b86b36dfeece2b5 --- pkg/rpc/dial_drpc.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/rpc/dial_drpc.go b/pkg/rpc/dial_drpc.go index 74cc8647a..75674cb8a 100644 --- a/pkg/rpc/dial_drpc.go +++ b/pkg/rpc/dial_drpc.go @@ -18,13 +18,6 @@ const drpcHeader = "DRPC!!!1" func (d Dialer) dial(ctx context.Context, address string, tlsConfig *tls.Config) (_ *Conn, err error) { defer mon.Task()(&ctx)(&err) - // set up an error to expire when the context is canceled - errCh := make(chan error, 2) - go func() { - <-ctx.Done() - errCh <- ctx.Err() - }() - // open the tcp socket to the address rawConn, err := d.dialContext(ctx, address) if err != nil { @@ -37,14 +30,21 @@ func (d Dialer) dial(ctx context.Context, address string, tlsConfig *tls.Config) return nil, Error.Wrap(err) } - // perform the handshake racing with the context closing + // perform the handshake racing with the context closing. we use a buffer + // of size 1 so that the handshake can proceed even if no one is reading. + errCh := make(chan error, 1) conn := tls.Client(rawConn, tlsConfig) go func() { errCh <- conn.Handshake() }() // see which wins and close the raw conn if there was any error. we can't // close the tls connection concurrently with handshakes or it sometimes // will panic. cool, huh? - if err := <-errCh; err != nil { + select { + case <-ctx.Done(): + err = ctx.Err() + case err = <-errCh: + } + if err != nil { _ = rawConn.Close() return nil, Error.Wrap(err) }