private/tagsql: implement wrapper for sql.DB
Wrapper adds tracing and fixes context usage issues. Change-Id: Ie6f7650eac87e2a2b64b760198498ba5857ad535
This commit is contained in:
parent
8bbb9083f0
commit
5d80e22af9
@ -6,9 +6,9 @@ package pgutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
@ -18,9 +18,10 @@ import (
|
||||
// CreateRandomTestingSchemaName creates a random schema name string.
|
||||
func CreateRandomTestingSchemaName(n int) string {
|
||||
data := make([]byte, n)
|
||||
|
||||
// math/rand.Read() always returns a nil error so there's no need to handle the error.
|
||||
_, _ = rand.Read(data)
|
||||
_, err := rand.Read(data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return hex.EncodeToString(data)
|
||||
}
|
||||
|
||||
|
105
private/tagsql/basic_test.go
Normal file
105
private/tagsql/basic_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package tagsql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/storj/private/tagsql"
|
||||
)
|
||||
|
||||
func TestDetect(t *testing.T) {
|
||||
run(t, func(parentctx *testcontext.Context, t *testing.T, rawdb *sql.DB, support tagsql.ContextSupport) {
|
||||
db := tagsql.Wrap(rawdb)
|
||||
|
||||
_, err := db.ExecContext(parentctx, "CREATE TABLE example (num INT)")
|
||||
require.NoError(t, err)
|
||||
_, err = db.ExecContext(parentctx, "INSERT INTO example (num) values (1)")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(parentctx)
|
||||
cancel()
|
||||
|
||||
var verify func(t require.TestingT, err error, msgAndArgs ...interface{})
|
||||
if support.Basic() {
|
||||
verify = require.Error
|
||||
} else {
|
||||
verify = require.NoError
|
||||
}
|
||||
|
||||
err = db.PingContext(ctx)
|
||||
verify(t, err)
|
||||
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO example (num) values (1)")
|
||||
verify(t, err)
|
||||
|
||||
row := db.QueryRowContext(ctx, "select num from example")
|
||||
var value int64
|
||||
err = row.Scan(&value)
|
||||
verify(t, err)
|
||||
|
||||
var rows *sql.Rows
|
||||
rows, err = db.QueryContext(ctx, "select num from example")
|
||||
verify(t, err)
|
||||
if rows != nil {
|
||||
require.NoError(t, rows.Close())
|
||||
}
|
||||
|
||||
if support.Transactions() {
|
||||
var tx tagsql.Tx
|
||||
tx, err = db.Begin(ctx)
|
||||
require.Error(t, err)
|
||||
if tx != nil {
|
||||
require.NoError(t, tx.Rollback())
|
||||
}
|
||||
|
||||
tx, err = db.BeginTx(ctx)
|
||||
require.Error(t, err)
|
||||
if tx != nil {
|
||||
require.NoError(t, tx.Rollback())
|
||||
}
|
||||
}
|
||||
|
||||
var verifyTx func(t require.TestingT, err error, msgAndArgs ...interface{})
|
||||
if support.Transactions() {
|
||||
verifyTx = require.Error
|
||||
} else {
|
||||
verifyTx = require.NoError
|
||||
}
|
||||
|
||||
for _, alt := range []bool{false, true} {
|
||||
t.Log("Transactions", alt)
|
||||
var tx tagsql.Tx
|
||||
if alt {
|
||||
tx, err = db.Begin(parentctx)
|
||||
} else {
|
||||
tx, err = db.BeginTx(parentctx)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tx.ExecContext(ctx, "INSERT INTO example (num) values (1)")
|
||||
verifyTx(t, err)
|
||||
|
||||
var rows *sql.Rows
|
||||
rows, err = tx.QueryContext(ctx, "select num from example")
|
||||
verifyTx(t, err)
|
||||
if rows != nil {
|
||||
require.NoError(t, rows.Close())
|
||||
}
|
||||
|
||||
row := tx.QueryRowContext(ctx, "select num from example")
|
||||
var value int64
|
||||
// lib/pq seems to stall here for some reason?
|
||||
err = row.Scan(&value)
|
||||
verifyTx(t, err)
|
||||
|
||||
require.NoError(t, tx.Commit())
|
||||
}
|
||||
})
|
||||
}
|
102
private/tagsql/conn.go
Normal file
102
private/tagsql/conn.go
Normal file
@ -0,0 +1,102 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package tagsql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"storj.io/storj/pkg/traces"
|
||||
"storj.io/storj/private/context2"
|
||||
)
|
||||
|
||||
// Conn is an interface for *sql.Conn-like connections.
|
||||
type Conn interface {
|
||||
BeginTx(ctx context.Context) (Tx, error)
|
||||
Close() error
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
PingContext(ctx context.Context) error
|
||||
PrepareContext(ctx context.Context, query string) (Stmt, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
Raw(ctx context.Context, f func(driverConn interface{}) error) (err error)
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// Is there a way to call non-context versions on *sql.Conn?
|
||||
// The pessimistic and safer assumption is that using any context may break
|
||||
// lib/pq internally. It might be fine, however it's unclear, how fine it is.
|
||||
|
||||
// sqlConn implements Conn, which optionally disables contexts.
|
||||
type sqlConn struct {
|
||||
conn *sql.Conn
|
||||
useContext bool
|
||||
useTxContext bool
|
||||
}
|
||||
|
||||
func (s *sqlConn) BeginTx(ctx context.Context) (Tx, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
ctx = context2.WithoutCancellation(ctx)
|
||||
}
|
||||
|
||||
tx, err := s.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sqlTx{tx: tx, useContext: s.useContext && s.useTxContext}, nil
|
||||
}
|
||||
|
||||
func (s *sqlConn) Close() error {
|
||||
return s.conn.Close()
|
||||
}
|
||||
|
||||
func (s *sqlConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
ctx = context2.WithoutCancellation(ctx)
|
||||
}
|
||||
return s.conn.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlConn) PingContext(ctx context.Context) error {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
ctx = context2.WithoutCancellation(ctx)
|
||||
}
|
||||
return s.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (s *sqlConn) PrepareContext(ctx context.Context, query string) (Stmt, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
ctx = context2.WithoutCancellation(ctx)
|
||||
}
|
||||
stmt, err := s.conn.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sqlStmt{stmt: stmt, useContext: s.useContext}, nil
|
||||
}
|
||||
|
||||
func (s *sqlConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
ctx = context2.WithoutCancellation(ctx)
|
||||
}
|
||||
return s.conn.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlConn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
ctx = context2.WithoutCancellation(ctx)
|
||||
}
|
||||
return s.conn.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlConn) Raw(ctx context.Context, f func(driverConn interface{}) error) (err error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.conn.Raw(f)
|
||||
}
|
229
private/tagsql/db.go
Normal file
229
private/tagsql/db.go
Normal file
@ -0,0 +1,229 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
// Package tagsql implements a tagged wrapper for databases.
|
||||
//
|
||||
// This package also handles hides context cancellation from database drivers
|
||||
// that don't support it.
|
||||
package tagsql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
|
||||
"storj.io/storj/pkg/traces"
|
||||
"storj.io/storj/private/context2"
|
||||
)
|
||||
|
||||
// Wrap turns a *sql.DB into a DB-matching interface.
|
||||
func Wrap(db *sql.DB) DB {
|
||||
support, err := DetectContextSupport(db)
|
||||
if err != nil {
|
||||
// When we reach here it is definitely a programmer error.
|
||||
// Add any new database drivers into DetectContextSupport
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &sqlDB{
|
||||
db: db,
|
||||
useContext: support.Basic(),
|
||||
useTxContext: support.Transactions(),
|
||||
}
|
||||
}
|
||||
|
||||
// WithoutContext turns a *sql.DB into a DB-matching that redirects context calls to regular calls.
|
||||
func WithoutContext(db *sql.DB) DB {
|
||||
return &sqlDB{db: db, useContext: false, useTxContext: false}
|
||||
}
|
||||
|
||||
// AllowContext turns a *sql.DB into a DB which uses context calls.
|
||||
func AllowContext(db *sql.DB) DB {
|
||||
return &sqlDB{db: db, useContext: true, useTxContext: true}
|
||||
}
|
||||
|
||||
// DB implements a wrapper for *sql.DB-like database.
|
||||
//
|
||||
// The wrapper adds tracing to all calls.
|
||||
// It also adds context handling compatibility for different databases.
|
||||
type DB interface {
|
||||
// To be deprecated, the following take ctx as argument,
|
||||
// however do not pass it forward to the underlying database.
|
||||
Begin(ctx context.Context) (Tx, error)
|
||||
Driver(ctx context.Context) driver.Driver
|
||||
Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
Ping(ctx context.Context) error
|
||||
Prepare(ctx context.Context, query string) (Stmt, error)
|
||||
Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
|
||||
BeginTx(ctx context.Context) (Tx, error)
|
||||
Conn(ctx context.Context) (Conn, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
PingContext(ctx context.Context) error
|
||||
PrepareContext(ctx context.Context, query string) (Stmt, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
|
||||
Close() error
|
||||
|
||||
SetConnMaxLifetime(d time.Duration)
|
||||
SetMaxIdleConns(n int)
|
||||
SetMaxOpenConns(n int)
|
||||
Stats() sql.DBStats
|
||||
}
|
||||
|
||||
// sqlDB implements DB, which optionally disables contexts.
|
||||
type sqlDB struct {
|
||||
db *sql.DB
|
||||
useContext bool
|
||||
useTxContext bool
|
||||
}
|
||||
|
||||
func (s *sqlDB) Begin(ctx context.Context) (Tx, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sqlTx{tx: tx, useContext: s.useContext && s.useTxContext}, err
|
||||
}
|
||||
|
||||
func (s *sqlDB) BeginTx(ctx context.Context) (Tx, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
|
||||
var tx *sql.Tx
|
||||
var err error
|
||||
if !s.useContext {
|
||||
tx, err = s.db.Begin()
|
||||
} else {
|
||||
tx, err = s.db.BeginTx(ctx, nil)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sqlTx{tx: tx, useContext: s.useContext && s.useTxContext}, err
|
||||
}
|
||||
|
||||
func (s *sqlDB) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func (s *sqlDB) Conn(ctx context.Context) (Conn, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
var conn *sql.Conn
|
||||
var err error
|
||||
if !s.useContext {
|
||||
// Uses WithoutCancellation, because there isn't an underlying call that doesn't take a context.
|
||||
conn, err = s.db.Conn(context2.WithoutCancellation(ctx))
|
||||
} else {
|
||||
conn, err = s.db.Conn(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sqlConn{conn: conn, useContext: s.useContext, useTxContext: s.useTxContext}, nil
|
||||
}
|
||||
|
||||
func (s *sqlDB) Driver(ctx context.Context) driver.Driver {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.db.Driver()
|
||||
}
|
||||
|
||||
func (s *sqlDB) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.db.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
return s.db.Exec(query, args...)
|
||||
}
|
||||
return s.db.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlDB) Ping(ctx context.Context) error {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.db.Ping()
|
||||
}
|
||||
|
||||
func (s *sqlDB) PingContext(ctx context.Context) error {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
return s.db.Ping()
|
||||
}
|
||||
return s.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (s *sqlDB) Prepare(ctx context.Context, query string) (Stmt, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
stmt, err := s.db.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sqlStmt{stmt: stmt, useContext: s.useContext}, nil
|
||||
}
|
||||
|
||||
func (s *sqlDB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
var stmt *sql.Stmt
|
||||
var err error
|
||||
if !s.useContext {
|
||||
stmt, err = s.db.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
stmt, err = s.db.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &sqlStmt{stmt: stmt, useContext: s.useContext}, nil
|
||||
}
|
||||
|
||||
func (s *sqlDB) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.db.Query(query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
return s.db.Query(query, args...)
|
||||
}
|
||||
return s.db.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlDB) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.db.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
return s.db.QueryRow(query, args...)
|
||||
}
|
||||
return s.db.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlDB) SetConnMaxLifetime(d time.Duration) {
|
||||
s.db.SetConnMaxLifetime(d)
|
||||
}
|
||||
|
||||
func (s *sqlDB) SetMaxIdleConns(n int) {
|
||||
s.db.SetMaxIdleConns(n)
|
||||
}
|
||||
|
||||
func (s *sqlDB) SetMaxOpenConns(n int) {
|
||||
s.db.SetMaxOpenConns(n)
|
||||
}
|
||||
|
||||
func (s *sqlDB) Stats() sql.DBStats {
|
||||
return s.db.Stats()
|
||||
}
|
66
private/tagsql/db_test.go
Normal file
66
private/tagsql/db_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package tagsql_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/storj/private/dbutil/cockroachutil"
|
||||
"storj.io/storj/private/dbutil/pgutil"
|
||||
"storj.io/storj/private/dbutil/pgutil/pgtest"
|
||||
"storj.io/storj/private/tagsql"
|
||||
)
|
||||
|
||||
func run(t *testing.T, fn func(*testcontext.Context, *testing.T, *sql.DB, tagsql.ContextSupport)) {
|
||||
t.Helper()
|
||||
|
||||
t.Run("mattn-sqlite3", func(t *testing.T) {
|
||||
ctx := testcontext.New(t)
|
||||
defer ctx.Cleanup()
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ctx.Check(db.Close)
|
||||
|
||||
fn(ctx, t, db, tagsql.SupportBasic)
|
||||
})
|
||||
|
||||
t.Run("lib-pq-postgres", func(t *testing.T) {
|
||||
ctx := testcontext.New(t)
|
||||
defer ctx.Cleanup()
|
||||
|
||||
if *pgtest.ConnStr == "" {
|
||||
t.Skipf("postgresql flag missing, example:\n-postgres-test-db=%s", pgtest.DefaultConnStr)
|
||||
}
|
||||
|
||||
db, err := pgutil.OpenUnique(ctx, *pgtest.ConnStr, "detect")
|
||||
require.NoError(t, err)
|
||||
defer ctx.Check(db.Close)
|
||||
|
||||
fn(ctx, t, db.DB, tagsql.SupportNone)
|
||||
})
|
||||
|
||||
t.Run("lib-pq-cockroach", func(t *testing.T) {
|
||||
ctx := testcontext.New(t)
|
||||
defer ctx.Cleanup()
|
||||
|
||||
if *pgtest.CrdbConnStr == "" {
|
||||
t.Skipf("postgresql flag missing, example:\n-cockroach-test-db=%s", pgtest.DefaultCrdbConnStr)
|
||||
}
|
||||
|
||||
db, err := cockroachutil.OpenUnique(ctx, *pgtest.CrdbConnStr, "detect")
|
||||
require.NoError(t, err)
|
||||
defer ctx.Check(db.Close)
|
||||
|
||||
fn(ctx, t, db.DB, tagsql.SupportNone)
|
||||
})
|
||||
}
|
68
private/tagsql/detect.go
Normal file
68
private/tagsql/detect.go
Normal file
@ -0,0 +1,68 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package tagsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"github.com/zeebo/errs"
|
||||
)
|
||||
|
||||
// Currently lib/pq has known issues with contexts in general.
|
||||
// For lib/pq context methods will be completely disabled.
|
||||
//
|
||||
// A few issues:
|
||||
// https://github.com/lib/pq/issues/874
|
||||
// https://github.com/lib/pq/issues/908
|
||||
// https://github.com/lib/pq/issues/731
|
||||
//
|
||||
// mattn/go-sqlite3 seems to work with contexts on the most part,
|
||||
// except in transactions. For them, we need to disable.
|
||||
// https://github.com/mattn/go-sqlite3/issues/769
|
||||
//
|
||||
// Currently we don't have data on whether github.com/jackc/pgx supports them properly.
|
||||
|
||||
// ContextSupport returns the level of context support a driver has.
|
||||
type ContextSupport byte
|
||||
|
||||
// Constants for defining context level support.
|
||||
const (
|
||||
SupportBasic ContextSupport = 1 << 0
|
||||
SupportTransactions ContextSupport = 1 << 1
|
||||
|
||||
SupportNone ContextSupport = 0
|
||||
SupportAll ContextSupport = SupportBasic | SupportTransactions
|
||||
)
|
||||
|
||||
// Basic returns true when driver supports basic contexts.
|
||||
func (v ContextSupport) Basic() bool {
|
||||
return v&SupportBasic == SupportBasic
|
||||
}
|
||||
|
||||
// Transactions returns true when driver supports contexts inside transactions.
|
||||
func (v ContextSupport) Transactions() bool {
|
||||
return v&SupportTransactions == SupportTransactions
|
||||
}
|
||||
|
||||
// DetectContextSupport detects *sql.DB driver without importing the specific packages.
|
||||
func DetectContextSupport(db *sql.DB) (ContextSupport, error) {
|
||||
// We're using reflect so we don't have to import these packages
|
||||
// into the binary.
|
||||
typ := reflect.TypeOf(db.Driver())
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
switch {
|
||||
case typ.PkgPath() == "github.com/mattn/go-sqlite3" && typ.Name() == "SQLiteDriver":
|
||||
return SupportBasic, nil
|
||||
case typ.PkgPath() == "github.com/lib/pq" && typ.Name() == "Driver" ||
|
||||
// internally uses lib/pq
|
||||
typ.PkgPath() == "storj.io/storj/private/dbutil/cockroachutil" && typ.Name() == "Driver":
|
||||
return SupportNone, nil
|
||||
default:
|
||||
return SupportNone, errs.New("sql driver %q %q unsupported", typ.PkgPath(), typ.Name())
|
||||
}
|
||||
}
|
79
private/tagsql/stmt.go
Normal file
79
private/tagsql/stmt.go
Normal file
@ -0,0 +1,79 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package tagsql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"storj.io/storj/pkg/traces"
|
||||
)
|
||||
|
||||
// Stmt is an interface for *sql.Stmt.
|
||||
type Stmt interface {
|
||||
// Exec and other methods take a context for tracing
|
||||
// purposes, but do not pass the context to the underlying database query.
|
||||
Exec(ctx context.Context, args ...interface{}) (sql.Result, error)
|
||||
Query(ctx context.Context, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(ctx context.Context, args ...interface{}) *sql.Row
|
||||
|
||||
// ExecContext and other Context methods take a context for tracing and also
|
||||
// pass the context to the underlying database, if this tagsql instance is
|
||||
// configured to do so. (By default, lib/pq does not ever, and
|
||||
// mattn/go-sqlite3 does not for transactions).
|
||||
ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
// sqlStmt implements Stmt, which optionally disables contexts.
|
||||
type sqlStmt struct {
|
||||
stmt *sql.Stmt
|
||||
useContext bool
|
||||
}
|
||||
|
||||
func (s *sqlStmt) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
func (s *sqlStmt) Exec(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.stmt.Exec(args...)
|
||||
}
|
||||
|
||||
func (s *sqlStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
return s.stmt.Exec(args...)
|
||||
}
|
||||
return s.stmt.ExecContext(ctx, args...)
|
||||
}
|
||||
|
||||
func (s *sqlStmt) Query(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.stmt.Query(args...)
|
||||
}
|
||||
|
||||
func (s *sqlStmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
return s.stmt.Query(args...)
|
||||
}
|
||||
return s.stmt.QueryContext(ctx, args...)
|
||||
}
|
||||
|
||||
func (s *sqlStmt) QueryRow(ctx context.Context, args ...interface{}) *sql.Row {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.stmt.QueryRow(args...)
|
||||
}
|
||||
|
||||
func (s *sqlStmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
return s.stmt.QueryRow(args...)
|
||||
}
|
||||
return s.stmt.QueryRowContext(ctx, args...)
|
||||
}
|
113
private/tagsql/tx.go
Normal file
113
private/tagsql/tx.go
Normal file
@ -0,0 +1,113 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package tagsql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"storj.io/storj/pkg/traces"
|
||||
)
|
||||
|
||||
// Tx is an interface for *sql.Tx-like transactions.
|
||||
type Tx interface {
|
||||
// Exec and other methods take a context for tracing
|
||||
// purposes, but do not pass the context to the underlying database query
|
||||
Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(ctx context.Context, query string) (Stmt, error)
|
||||
Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
|
||||
// ExecContext and other Context methods take a context for tracing and also
|
||||
// pass the context to the underlying database, if this tagsql instance is
|
||||
// configured to do so. (By default, lib/pq does not ever, and
|
||||
// mattn/go-sqlite3 does not for transactions).
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
PrepareContext(ctx context.Context, query string) (Stmt, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// sqlTx implements Tx, which optionally disables contexts.
|
||||
type sqlTx struct {
|
||||
tx *sql.Tx
|
||||
useContext bool
|
||||
}
|
||||
|
||||
func (s *sqlTx) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.tx.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
return s.tx.Exec(query, args...)
|
||||
}
|
||||
return s.tx.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlTx) Prepare(ctx context.Context, query string) (Stmt, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
stmt, err := s.tx.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sqlStmt{stmt: stmt, useContext: s.useContext}, nil
|
||||
}
|
||||
|
||||
func (s *sqlTx) PrepareContext(ctx context.Context, query string) (Stmt, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
var stmt *sql.Stmt
|
||||
var err error
|
||||
if !s.useContext {
|
||||
stmt, err = s.tx.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
stmt, err = s.tx.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &sqlStmt{stmt: stmt, useContext: s.useContext}, err
|
||||
}
|
||||
|
||||
func (s *sqlTx) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.tx.Query(query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
return s.tx.Query(query, args...)
|
||||
}
|
||||
return s.tx.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlTx) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
return s.tx.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
traces.Tag(ctx, traces.TagDB)
|
||||
if !s.useContext {
|
||||
return s.tx.QueryRow(query, args...)
|
||||
}
|
||||
return s.tx.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (s *sqlTx) Commit() error {
|
||||
return s.tx.Commit()
|
||||
}
|
||||
|
||||
func (s *sqlTx) Rollback() error {
|
||||
return s.tx.Rollback()
|
||||
}
|
Loading…
Reference in New Issue
Block a user