diff --git a/private/tagsql/db.go b/private/tagsql/db.go index 7fc625d5f..c17468c6e 100644 --- a/private/tagsql/db.go +++ b/private/tagsql/db.go @@ -105,7 +105,7 @@ func (s *sqlDB) Begin(ctx context.Context) (Tx, error) { if err != nil { return nil, err } - return &sqlTx{tx: tx, useContext: s.useContext && s.useTxContext}, err + return &sqlTx{tx: leakCheckTx(tx), useContext: s.useContext && s.useTxContext}, err } func (s *sqlDB) BeginTx(ctx context.Context, txOptions *sql.TxOptions) (Tx, error) { @@ -126,7 +126,7 @@ func (s *sqlDB) BeginTx(ctx context.Context, txOptions *sql.TxOptions) (Tx, erro return nil, err } - return &sqlTx{tx: tx, useContext: s.useContext && s.useTxContext}, err + return &sqlTx{tx: leakCheckTx(tx), useContext: s.useContext && s.useTxContext}, err } func (s *sqlDB) Close() error { @@ -146,7 +146,7 @@ func (s *sqlDB) Conn(ctx context.Context) (Conn, error) { if err != nil { return nil, err } - return &sqlConn{conn: conn, useContext: s.useContext, useTxContext: s.useTxContext}, nil + return &sqlConn{conn: leakCheckConn(conn), useContext: s.useContext, useTxContext: s.useTxContext}, nil } func (s *sqlDB) Driver() driver.Driver { @@ -185,7 +185,7 @@ func (s *sqlDB) Prepare(ctx context.Context, query string) (Stmt, error) { if err != nil { return nil, err } - return &sqlStmt{stmt: stmt, useContext: s.useContext}, nil + return &sqlStmt{stmt: leakCheckStmt(stmt), useContext: s.useContext}, nil } func (s *sqlDB) PrepareContext(ctx context.Context, query string) (Stmt, error) { @@ -203,33 +203,38 @@ func (s *sqlDB) PrepareContext(ctx context.Context, query string) (Stmt, error) return nil, err } } - return &sqlStmt{stmt: stmt, useContext: s.useContext}, nil + return &sqlStmt{stmt: leakCheckStmt(stmt), useContext: s.useContext}, nil } -func (s *sqlDB) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (s *sqlDB) Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { traces.Tag(ctx, traces.TagDB) - return s.db.Query(query, args...) + rows, err = s.db.Query(query, args...) + return leakCheckRows(rows), err } -func (s *sqlDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (s *sqlDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { traces.Tag(ctx, traces.TagDB) if !s.useContext { - return s.db.Query(query, args...) + rows, err = s.db.Query(query, args...) + } else { + rows, err = s.db.QueryContext(ctx, query, args...) } - return s.db.QueryContext(ctx, query, args...) + return leakCheckRows(rows), err } -func (s *sqlDB) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { +func (s *sqlDB) QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row) { traces.Tag(ctx, traces.TagDB) - return s.db.QueryRow(query, args...) + return leakCheckRow(s.db.QueryRow(query, args...)) } -func (s *sqlDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { +func (s *sqlDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) (row *sql.Row) { traces.Tag(ctx, traces.TagDB) if !s.useContext { - return s.db.QueryRow(query, args...) + row = s.db.QueryRow(query, args...) + } else { + row = s.db.QueryRowContext(ctx, query, args...) } - return s.db.QueryRowContext(ctx, query, args...) + return leakCheckRow(row) } func (s *sqlDB) SetConnMaxLifetime(d time.Duration) { diff --git a/private/tagsql/leak_check.go b/private/tagsql/leak_check.go new file mode 100644 index 000000000..4d252d20b --- /dev/null +++ b/private/tagsql/leak_check.go @@ -0,0 +1,111 @@ +// Copyright (C) 2020 Storj Labs, Inc. +// See LICENSE for copying information. + +package tagsql + +import ( + "database/sql" + "reflect" + "runtime" + + "storj.io/private/version" +) + +func leakCheckRows(rows *sql.Rows) *sql.Rows { + if !version.Build.Release && rows != nil { + runtime.SetFinalizer(rows, ensureRowsClosed) + } + return rows +} + +func ensureRowsClosed(rows *sql.Rows) { + // this field is protected by a mutex, but fortunately for us, we don't + // have to worry because we know we are the only ones with a reference + // to this object since it's running in the finalizer. race free! + if !reflect.ValueOf(rows).Elem().FieldByName("closed").Bool() { + panic("leaked *sql.Rows value without being closed") + } +} + +func leakCheckTx(tx *sql.Tx) *sql.Tx { + if !version.Build.Release && tx != nil { + runtime.SetFinalizer(tx, ensureTxComplete) + } + return tx +} + +func ensureTxComplete(tx *sql.Tx) { + // from the docs of the struct: + // + // done transitions from 0 to 1 exactly once, on Commit + // or Rollback. once done, all operations fail with + // ErrTxDone. + // Use atomic operations on value when checking value. + // + // fortunately, we're the only reference to this tx, so we don't + // have to worry about atomics. + + if reflect.ValueOf(tx).Elem().FieldByName("done").Int() != 1 { + panic("leaked *sql.Tx value without being complete") + } +} + +func leakCheckConn(conn *sql.Conn) *sql.Conn { + if !version.Build.Release && conn != nil { + runtime.SetFinalizer(conn, ensureConnComplete) + } + return conn +} + +func ensureConnComplete(conn *sql.Conn) { + // from the docs of the struct: + // + // done transitions from 0 to 1 exactly once, on close. + // Once done, all operations fail with ErrConnDone. + // Use atomic operations on value when checking value. + // + // fortunately, we're the only reference to this tx, so we don't + // have to worry about atomics. + + if reflect.ValueOf(conn).Elem().FieldByName("done").Int() != 1 { + panic("leaked *sql.Conn value without being complete") + } +} + +func leakCheckStmt(stmt *sql.Stmt) *sql.Stmt { + if !version.Build.Release && stmt != nil { + runtime.SetFinalizer(stmt, ensureStmtClosed) + } + return stmt +} + +func ensureStmtClosed(stmt *sql.Stmt) { + // this field is protected by a mutex, but fortunately for us, we don't + // have to worry because we know we are the only ones with a reference + // to this object since it's running in the finalizer. race free! + if !reflect.ValueOf(stmt).Elem().FieldByName("closed").Bool() { + panic("leaked *sql.Stmt value without being closed") + } +} + +func leakCheckRow(row *sql.Row) *sql.Row { + if !version.Build.Release && row != nil { + runtime.SetFinalizer(row, ensureRowClosed) + } + return row +} + +func ensureRowClosed(row *sql.Row) { + // check the underlying rows field, avoiding issue if it is nil. + rows := reflect.ValueOf(row).Elem().FieldByName("rows") + if rows.IsNil() { + return + } + + // this field is protected by a mutex, but fortunately for us, we don't + // have to worry because we know we are the only ones with a reference + // to this object since it's running in the finalizer. race free! + if !rows.Elem().FieldByName("closed").Bool() { + panic("leaked *sql.Rows value without being closed") + } +}