From 5d80e22af93adb0661ac766f0f033d469752c5c7 Mon Sep 17 00:00:00 2001 From: Egon Elbre Date: Wed, 15 Jan 2020 23:59:22 +0200 Subject: [PATCH] private/tagsql: implement wrapper for sql.DB Wrapper adds tracing and fixes context usage issues. Change-Id: Ie6f7650eac87e2a2b64b760198498ba5857ad535 --- private/dbutil/pgutil/schema.go | 9 +- private/tagsql/basic_test.go | 105 +++++++++++++++ private/tagsql/conn.go | 102 ++++++++++++++ private/tagsql/db.go | 229 ++++++++++++++++++++++++++++++++ private/tagsql/db_test.go | 66 +++++++++ private/tagsql/detect.go | 68 ++++++++++ private/tagsql/stmt.go | 79 +++++++++++ private/tagsql/tx.go | 113 ++++++++++++++++ 8 files changed, 767 insertions(+), 4 deletions(-) create mode 100644 private/tagsql/basic_test.go create mode 100644 private/tagsql/conn.go create mode 100644 private/tagsql/db.go create mode 100644 private/tagsql/db_test.go create mode 100644 private/tagsql/detect.go create mode 100644 private/tagsql/stmt.go create mode 100644 private/tagsql/tx.go diff --git a/private/dbutil/pgutil/schema.go b/private/dbutil/pgutil/schema.go index f1aa68b45..4d7805929 100644 --- a/private/dbutil/pgutil/schema.go +++ b/private/dbutil/pgutil/schema.go @@ -6,9 +6,9 @@ package pgutil import ( "context" + "crypto/rand" "database/sql" "encoding/hex" - "math/rand" "net/url" "strings" @@ -18,9 +18,10 @@ import ( // CreateRandomTestingSchemaName creates a random schema name string. func CreateRandomTestingSchemaName(n int) string { data := make([]byte, n) - - // math/rand.Read() always returns a nil error so there's no need to handle the error. - _, _ = rand.Read(data) + _, err := rand.Read(data) + if err != nil { + panic(err) + } return hex.EncodeToString(data) } diff --git a/private/tagsql/basic_test.go b/private/tagsql/basic_test.go new file mode 100644 index 000000000..ec2fcfd45 --- /dev/null +++ b/private/tagsql/basic_test.go @@ -0,0 +1,105 @@ +// Copyright (C) 2020 Storj Labs, Inc. +// See LICENSE for copying information. + +package tagsql_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/require" + + "storj.io/common/testcontext" + "storj.io/storj/private/tagsql" +) + +func TestDetect(t *testing.T) { + run(t, func(parentctx *testcontext.Context, t *testing.T, rawdb *sql.DB, support tagsql.ContextSupport) { + db := tagsql.Wrap(rawdb) + + _, err := db.ExecContext(parentctx, "CREATE TABLE example (num INT)") + require.NoError(t, err) + _, err = db.ExecContext(parentctx, "INSERT INTO example (num) values (1)") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(parentctx) + cancel() + + var verify func(t require.TestingT, err error, msgAndArgs ...interface{}) + if support.Basic() { + verify = require.Error + } else { + verify = require.NoError + } + + err = db.PingContext(ctx) + verify(t, err) + + _, err = db.ExecContext(ctx, "INSERT INTO example (num) values (1)") + verify(t, err) + + row := db.QueryRowContext(ctx, "select num from example") + var value int64 + err = row.Scan(&value) + verify(t, err) + + var rows *sql.Rows + rows, err = db.QueryContext(ctx, "select num from example") + verify(t, err) + if rows != nil { + require.NoError(t, rows.Close()) + } + + if support.Transactions() { + var tx tagsql.Tx + tx, err = db.Begin(ctx) + require.Error(t, err) + if tx != nil { + require.NoError(t, tx.Rollback()) + } + + tx, err = db.BeginTx(ctx) + require.Error(t, err) + if tx != nil { + require.NoError(t, tx.Rollback()) + } + } + + var verifyTx func(t require.TestingT, err error, msgAndArgs ...interface{}) + if support.Transactions() { + verifyTx = require.Error + } else { + verifyTx = require.NoError + } + + for _, alt := range []bool{false, true} { + t.Log("Transactions", alt) + var tx tagsql.Tx + if alt { + tx, err = db.Begin(parentctx) + } else { + tx, err = db.BeginTx(parentctx) + } + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, "INSERT INTO example (num) values (1)") + verifyTx(t, err) + + var rows *sql.Rows + rows, err = tx.QueryContext(ctx, "select num from example") + verifyTx(t, err) + if rows != nil { + require.NoError(t, rows.Close()) + } + + row := tx.QueryRowContext(ctx, "select num from example") + var value int64 + // lib/pq seems to stall here for some reason? + err = row.Scan(&value) + verifyTx(t, err) + + require.NoError(t, tx.Commit()) + } + }) +} diff --git a/private/tagsql/conn.go b/private/tagsql/conn.go new file mode 100644 index 000000000..b483086f9 --- /dev/null +++ b/private/tagsql/conn.go @@ -0,0 +1,102 @@ +// Copyright (C) 2020 Storj Labs, Inc. +// See LICENSE for copying information. + +package tagsql + +import ( + "context" + "database/sql" + + "storj.io/storj/pkg/traces" + "storj.io/storj/private/context2" +) + +// Conn is an interface for *sql.Conn-like connections. +type Conn interface { + BeginTx(ctx context.Context) (Tx, error) + Close() error + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PingContext(ctx context.Context) error + PrepareContext(ctx context.Context, query string) (Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + Raw(ctx context.Context, f func(driverConn interface{}) error) (err error) +} + +// TODO: +// Is there a way to call non-context versions on *sql.Conn? +// The pessimistic and safer assumption is that using any context may break +// lib/pq internally. It might be fine, however it's unclear, how fine it is. + +// sqlConn implements Conn, which optionally disables contexts. +type sqlConn struct { + conn *sql.Conn + useContext bool + useTxContext bool +} + +func (s *sqlConn) BeginTx(ctx context.Context) (Tx, error) { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + ctx = context2.WithoutCancellation(ctx) + } + + tx, err := s.conn.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + return &sqlTx{tx: tx, useContext: s.useContext && s.useTxContext}, nil +} + +func (s *sqlConn) Close() error { + return s.conn.Close() +} + +func (s *sqlConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + ctx = context2.WithoutCancellation(ctx) + } + return s.conn.ExecContext(ctx, query, args...) +} + +func (s *sqlConn) PingContext(ctx context.Context) error { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + ctx = context2.WithoutCancellation(ctx) + } + return s.conn.PingContext(ctx) +} + +func (s *sqlConn) PrepareContext(ctx context.Context, query string) (Stmt, error) { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + ctx = context2.WithoutCancellation(ctx) + } + stmt, err := s.conn.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return &sqlStmt{stmt: stmt, useContext: s.useContext}, nil +} + +func (s *sqlConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + ctx = context2.WithoutCancellation(ctx) + } + return s.conn.QueryContext(ctx, query, args...) +} + +func (s *sqlConn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + ctx = context2.WithoutCancellation(ctx) + } + return s.conn.QueryRowContext(ctx, query, args...) +} + +func (s *sqlConn) Raw(ctx context.Context, f func(driverConn interface{}) error) (err error) { + traces.Tag(ctx, traces.TagDB) + return s.conn.Raw(f) +} diff --git a/private/tagsql/db.go b/private/tagsql/db.go new file mode 100644 index 000000000..e01cf7ecc --- /dev/null +++ b/private/tagsql/db.go @@ -0,0 +1,229 @@ +// Copyright (C) 2020 Storj Labs, Inc. +// See LICENSE for copying information. + +// Package tagsql implements a tagged wrapper for databases. +// +// This package also handles hides context cancellation from database drivers +// that don't support it. +package tagsql + +import ( + "context" + "database/sql" + "database/sql/driver" + "time" + + "storj.io/storj/pkg/traces" + "storj.io/storj/private/context2" +) + +// Wrap turns a *sql.DB into a DB-matching interface. +func Wrap(db *sql.DB) DB { + support, err := DetectContextSupport(db) + if err != nil { + // When we reach here it is definitely a programmer error. + // Add any new database drivers into DetectContextSupport + panic(err) + } + + return &sqlDB{ + db: db, + useContext: support.Basic(), + useTxContext: support.Transactions(), + } +} + +// WithoutContext turns a *sql.DB into a DB-matching that redirects context calls to regular calls. +func WithoutContext(db *sql.DB) DB { + return &sqlDB{db: db, useContext: false, useTxContext: false} +} + +// AllowContext turns a *sql.DB into a DB which uses context calls. +func AllowContext(db *sql.DB) DB { + return &sqlDB{db: db, useContext: true, useTxContext: true} +} + +// DB implements a wrapper for *sql.DB-like database. +// +// The wrapper adds tracing to all calls. +// It also adds context handling compatibility for different databases. +type DB interface { + // To be deprecated, the following take ctx as argument, + // however do not pass it forward to the underlying database. + Begin(ctx context.Context) (Tx, error) + Driver(ctx context.Context) driver.Driver + Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + Ping(ctx context.Context) error + Prepare(ctx context.Context, query string) (Stmt, error) + Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row + + BeginTx(ctx context.Context) (Tx, error) + Conn(ctx context.Context) (Conn, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PingContext(ctx context.Context) error + PrepareContext(ctx context.Context, query string) (Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + + Close() error + + SetConnMaxLifetime(d time.Duration) + SetMaxIdleConns(n int) + SetMaxOpenConns(n int) + Stats() sql.DBStats +} + +// sqlDB implements DB, which optionally disables contexts. +type sqlDB struct { + db *sql.DB + useContext bool + useTxContext bool +} + +func (s *sqlDB) Begin(ctx context.Context) (Tx, error) { + traces.Tag(ctx, traces.TagDB) + tx, err := s.db.Begin() + if err != nil { + return nil, err + } + return &sqlTx{tx: tx, useContext: s.useContext && s.useTxContext}, err +} + +func (s *sqlDB) BeginTx(ctx context.Context) (Tx, error) { + traces.Tag(ctx, traces.TagDB) + + var tx *sql.Tx + var err error + if !s.useContext { + tx, err = s.db.Begin() + } else { + tx, err = s.db.BeginTx(ctx, nil) + } + + if err != nil { + return nil, err + } + + return &sqlTx{tx: tx, useContext: s.useContext && s.useTxContext}, err +} + +func (s *sqlDB) Close() error { + return s.db.Close() +} + +func (s *sqlDB) Conn(ctx context.Context) (Conn, error) { + traces.Tag(ctx, traces.TagDB) + var conn *sql.Conn + var err error + if !s.useContext { + // Uses WithoutCancellation, because there isn't an underlying call that doesn't take a context. + conn, err = s.db.Conn(context2.WithoutCancellation(ctx)) + } else { + conn, err = s.db.Conn(ctx) + } + if err != nil { + return nil, err + } + return &sqlConn{conn: conn, useContext: s.useContext, useTxContext: s.useTxContext}, nil +} + +func (s *sqlDB) Driver(ctx context.Context) driver.Driver { + traces.Tag(ctx, traces.TagDB) + return s.db.Driver() +} + +func (s *sqlDB) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + traces.Tag(ctx, traces.TagDB) + return s.db.Exec(query, args...) +} + +func (s *sqlDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + return s.db.Exec(query, args...) + } + return s.db.ExecContext(ctx, query, args...) +} + +func (s *sqlDB) Ping(ctx context.Context) error { + traces.Tag(ctx, traces.TagDB) + return s.db.Ping() +} + +func (s *sqlDB) PingContext(ctx context.Context) error { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + return s.db.Ping() + } + return s.db.PingContext(ctx) +} + +func (s *sqlDB) Prepare(ctx context.Context, query string) (Stmt, error) { + traces.Tag(ctx, traces.TagDB) + stmt, err := s.db.Prepare(query) + if err != nil { + return nil, err + } + return &sqlStmt{stmt: stmt, useContext: s.useContext}, nil +} + +func (s *sqlDB) PrepareContext(ctx context.Context, query string) (Stmt, error) { + traces.Tag(ctx, traces.TagDB) + var stmt *sql.Stmt + var err error + if !s.useContext { + stmt, err = s.db.Prepare(query) + if err != nil { + return nil, err + } + } else { + stmt, err = s.db.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + } + return &sqlStmt{stmt: stmt, useContext: s.useContext}, nil +} + +func (s *sqlDB) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + traces.Tag(ctx, traces.TagDB) + return s.db.Query(query, args...) +} + +func (s *sqlDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + return s.db.Query(query, args...) + } + return s.db.QueryContext(ctx, query, args...) +} + +func (s *sqlDB) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { + traces.Tag(ctx, traces.TagDB) + return s.db.QueryRow(query, args...) +} + +func (s *sqlDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + return s.db.QueryRow(query, args...) + } + return s.db.QueryRowContext(ctx, query, args...) +} + +func (s *sqlDB) SetConnMaxLifetime(d time.Duration) { + s.db.SetConnMaxLifetime(d) +} + +func (s *sqlDB) SetMaxIdleConns(n int) { + s.db.SetMaxIdleConns(n) +} + +func (s *sqlDB) SetMaxOpenConns(n int) { + s.db.SetMaxOpenConns(n) +} + +func (s *sqlDB) Stats() sql.DBStats { + return s.db.Stats() +} diff --git a/private/tagsql/db_test.go b/private/tagsql/db_test.go new file mode 100644 index 000000000..e86f3e6b8 --- /dev/null +++ b/private/tagsql/db_test.go @@ -0,0 +1,66 @@ +// Copyright (C) 2020 Storj Labs, Inc. +// See LICENSE for copying information. + +package tagsql_test + +import ( + "database/sql" + "testing" + + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + + "storj.io/common/testcontext" + "storj.io/storj/private/dbutil/cockroachutil" + "storj.io/storj/private/dbutil/pgutil" + "storj.io/storj/private/dbutil/pgutil/pgtest" + "storj.io/storj/private/tagsql" +) + +func run(t *testing.T, fn func(*testcontext.Context, *testing.T, *sql.DB, tagsql.ContextSupport)) { + t.Helper() + + t.Run("mattn-sqlite3", func(t *testing.T) { + ctx := testcontext.New(t) + defer ctx.Cleanup() + + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer ctx.Check(db.Close) + + fn(ctx, t, db, tagsql.SupportBasic) + }) + + t.Run("lib-pq-postgres", func(t *testing.T) { + ctx := testcontext.New(t) + defer ctx.Cleanup() + + if *pgtest.ConnStr == "" { + t.Skipf("postgresql flag missing, example:\n-postgres-test-db=%s", pgtest.DefaultConnStr) + } + + db, err := pgutil.OpenUnique(ctx, *pgtest.ConnStr, "detect") + require.NoError(t, err) + defer ctx.Check(db.Close) + + fn(ctx, t, db.DB, tagsql.SupportNone) + }) + + t.Run("lib-pq-cockroach", func(t *testing.T) { + ctx := testcontext.New(t) + defer ctx.Cleanup() + + if *pgtest.CrdbConnStr == "" { + t.Skipf("postgresql flag missing, example:\n-cockroach-test-db=%s", pgtest.DefaultCrdbConnStr) + } + + db, err := cockroachutil.OpenUnique(ctx, *pgtest.CrdbConnStr, "detect") + require.NoError(t, err) + defer ctx.Check(db.Close) + + fn(ctx, t, db.DB, tagsql.SupportNone) + }) +} diff --git a/private/tagsql/detect.go b/private/tagsql/detect.go new file mode 100644 index 000000000..6332f85e5 --- /dev/null +++ b/private/tagsql/detect.go @@ -0,0 +1,68 @@ +// Copyright (C) 2020 Storj Labs, Inc. +// See LICENSE for copying information. + +package tagsql + +import ( + "database/sql" + "reflect" + + "github.com/zeebo/errs" +) + +// Currently lib/pq has known issues with contexts in general. +// For lib/pq context methods will be completely disabled. +// +// A few issues: +// https://github.com/lib/pq/issues/874 +// https://github.com/lib/pq/issues/908 +// https://github.com/lib/pq/issues/731 +// +// mattn/go-sqlite3 seems to work with contexts on the most part, +// except in transactions. For them, we need to disable. +// https://github.com/mattn/go-sqlite3/issues/769 +// +// Currently we don't have data on whether github.com/jackc/pgx supports them properly. + +// ContextSupport returns the level of context support a driver has. +type ContextSupport byte + +// Constants for defining context level support. +const ( + SupportBasic ContextSupport = 1 << 0 + SupportTransactions ContextSupport = 1 << 1 + + SupportNone ContextSupport = 0 + SupportAll ContextSupport = SupportBasic | SupportTransactions +) + +// Basic returns true when driver supports basic contexts. +func (v ContextSupport) Basic() bool { + return v&SupportBasic == SupportBasic +} + +// Transactions returns true when driver supports contexts inside transactions. +func (v ContextSupport) Transactions() bool { + return v&SupportTransactions == SupportTransactions +} + +// DetectContextSupport detects *sql.DB driver without importing the specific packages. +func DetectContextSupport(db *sql.DB) (ContextSupport, error) { + // We're using reflect so we don't have to import these packages + // into the binary. + typ := reflect.TypeOf(db.Driver()) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + switch { + case typ.PkgPath() == "github.com/mattn/go-sqlite3" && typ.Name() == "SQLiteDriver": + return SupportBasic, nil + case typ.PkgPath() == "github.com/lib/pq" && typ.Name() == "Driver" || + // internally uses lib/pq + typ.PkgPath() == "storj.io/storj/private/dbutil/cockroachutil" && typ.Name() == "Driver": + return SupportNone, nil + default: + return SupportNone, errs.New("sql driver %q %q unsupported", typ.PkgPath(), typ.Name()) + } +} diff --git a/private/tagsql/stmt.go b/private/tagsql/stmt.go new file mode 100644 index 000000000..1268408cb --- /dev/null +++ b/private/tagsql/stmt.go @@ -0,0 +1,79 @@ +// Copyright (C) 2020 Storj Labs, Inc. +// See LICENSE for copying information. + +package tagsql + +import ( + "context" + "database/sql" + + "storj.io/storj/pkg/traces" +) + +// Stmt is an interface for *sql.Stmt. +type Stmt interface { + // Exec and other methods take a context for tracing + // purposes, but do not pass the context to the underlying database query. + Exec(ctx context.Context, args ...interface{}) (sql.Result, error) + Query(ctx context.Context, args ...interface{}) (*sql.Rows, error) + QueryRow(ctx context.Context, args ...interface{}) *sql.Row + + // ExecContext and other Context methods take a context for tracing and also + // pass the context to the underlying database, if this tagsql instance is + // configured to do so. (By default, lib/pq does not ever, and + // mattn/go-sqlite3 does not for transactions). + ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row + + Close() error +} + +// sqlStmt implements Stmt, which optionally disables contexts. +type sqlStmt struct { + stmt *sql.Stmt + useContext bool +} + +func (s *sqlStmt) Close() error { + return s.stmt.Close() +} + +func (s *sqlStmt) Exec(ctx context.Context, args ...interface{}) (sql.Result, error) { + traces.Tag(ctx, traces.TagDB) + return s.stmt.Exec(args...) +} + +func (s *sqlStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + return s.stmt.Exec(args...) + } + return s.stmt.ExecContext(ctx, args...) +} + +func (s *sqlStmt) Query(ctx context.Context, args ...interface{}) (*sql.Rows, error) { + traces.Tag(ctx, traces.TagDB) + return s.stmt.Query(args...) +} + +func (s *sqlStmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + return s.stmt.Query(args...) + } + return s.stmt.QueryContext(ctx, args...) +} + +func (s *sqlStmt) QueryRow(ctx context.Context, args ...interface{}) *sql.Row { + traces.Tag(ctx, traces.TagDB) + return s.stmt.QueryRow(args...) +} + +func (s *sqlStmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + return s.stmt.QueryRow(args...) + } + return s.stmt.QueryRowContext(ctx, args...) +} diff --git a/private/tagsql/tx.go b/private/tagsql/tx.go new file mode 100644 index 000000000..b06bc129d --- /dev/null +++ b/private/tagsql/tx.go @@ -0,0 +1,113 @@ +// Copyright (C) 2020 Storj Labs, Inc. +// See LICENSE for copying information. + +package tagsql + +import ( + "context" + "database/sql" + + "storj.io/storj/pkg/traces" +) + +// Tx is an interface for *sql.Tx-like transactions. +type Tx interface { + // Exec and other methods take a context for tracing + // purposes, but do not pass the context to the underlying database query + Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + Prepare(ctx context.Context, query string) (Stmt, error) + Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row + + // ExecContext and other Context methods take a context for tracing and also + // pass the context to the underlying database, if this tagsql instance is + // configured to do so. (By default, lib/pq does not ever, and + // mattn/go-sqlite3 does not for transactions). + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + + Commit() error + Rollback() error +} + +// sqlTx implements Tx, which optionally disables contexts. +type sqlTx struct { + tx *sql.Tx + useContext bool +} + +func (s *sqlTx) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + traces.Tag(ctx, traces.TagDB) + return s.tx.Exec(query, args...) +} + +func (s *sqlTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + return s.tx.Exec(query, args...) + } + return s.tx.ExecContext(ctx, query, args...) +} + +func (s *sqlTx) Prepare(ctx context.Context, query string) (Stmt, error) { + traces.Tag(ctx, traces.TagDB) + stmt, err := s.tx.Prepare(query) + if err != nil { + return nil, err + } + return &sqlStmt{stmt: stmt, useContext: s.useContext}, nil +} + +func (s *sqlTx) PrepareContext(ctx context.Context, query string) (Stmt, error) { + traces.Tag(ctx, traces.TagDB) + var stmt *sql.Stmt + var err error + if !s.useContext { + stmt, err = s.tx.Prepare(query) + if err != nil { + return nil, err + } + } else { + stmt, err = s.tx.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + } + return &sqlStmt{stmt: stmt, useContext: s.useContext}, err +} + +func (s *sqlTx) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + traces.Tag(ctx, traces.TagDB) + return s.tx.Query(query, args...) +} + +func (s *sqlTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + return s.tx.Query(query, args...) + } + return s.tx.QueryContext(ctx, query, args...) +} + +func (s *sqlTx) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { + traces.Tag(ctx, traces.TagDB) + return s.tx.QueryRow(query, args...) +} + +func (s *sqlTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + traces.Tag(ctx, traces.TagDB) + if !s.useContext { + return s.tx.QueryRow(query, args...) + } + return s.tx.QueryRowContext(ctx, query, args...) +} + +func (s *sqlTx) Commit() error { + return s.tx.Commit() +} + +func (s *sqlTx) Rollback() error { + return s.tx.Rollback() +}