Revert "dbutil: statically require all databases accesses to use contexts"

This reverts commit 8e242cd012.

Revert because lib/pq has known issues with context cancellation.
These issues need to be resolved before these changes can be merged.

Change-Id: I160af51dbc2d67c5449aafa406a403e5367bb555
This commit is contained in:
Egon Elbre 2020-01-15 09:25:26 +02:00
parent c01cbe0130
commit 64fb2d3d2f
41 changed files with 588 additions and 827 deletions

2
go.sum
View File

@ -597,5 +597,7 @@ storj.io/common v0.0.0-20200114152414-8dcdd4c9d250 h1:JDScdUShGqfHyiSWtfXhYEK1Ok
storj.io/common v0.0.0-20200114152414-8dcdd4c9d250/go.mod h1:0yn1ANoDXETNBREGQHq8d7m1Kq0vWMu6Ul7C2YPZo/E=
storj.io/drpc v0.0.7-0.20191115031725-2171c57838d2 h1:8SgLYEhe99R8QlAD1EAOBPRyIR+cn2hqkXtWlAUPf/c=
storj.io/drpc v0.0.7-0.20191115031725-2171c57838d2/go.mod h1:/ascUDbzNAv0A3Jj7wUIKFBH2JdJ2uJIBO/b9+2yHgQ=
storj.io/uplink v0.0.0-20200108132132-c2c5e0d46c1a h1:w/588H+U5IfTXCHA2GTFVLzpUbworS0DtoB4sR9h/8M=
storj.io/uplink v0.0.0-20200108132132-c2c5e0d46c1a/go.mod h1:3498FK1ewiOxrVTbPwGJmE/kwIWA3q9ULtAU/WAreys=
storj.io/uplink v0.0.0-20200109100422-69086b6ee4a8 h1:WG1rX2uc815ZkUz1xrebuZA+JWFBF9Y2n64gvVKZFko=
storj.io/uplink v0.0.0-20200109100422-69086b6ee4a8/go.mod h1:3498FK1ewiOxrVTbPwGJmE/kwIWA3q9ULtAU/WAreys=

View File

@ -16,7 +16,6 @@ import (
"gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/private/dbutil"
"storj.io/storj/private/dbutil/dbwrap"
)
var mon = monkit.Package()
@ -72,7 +71,7 @@ func OpenUnique(ctx context.Context, connStr string, schemaPrefix string) (db *d
return nil, errs.Combine(errs.Wrap(err), cleanup(masterDB))
}
dbutil.Configure(dbwrap.SQLDB(sqlDB), mon)
dbutil.Configure(sqlDB, mon)
return &dbutil.TempDatabase{
DB: sqlDB,
ConnStr: modifiedConnStr,

View File

@ -4,7 +4,6 @@
package dbschema
import (
"context"
"sort"
"strings"
@ -87,7 +86,7 @@ func QueryData(db Queryer, schema *Schema, quoteColumn func(string) string) (*Da
query := `SELECT ` + strings.Join(quotedColumns, ", ") + ` FROM ` + table.Name
err := func() (err error) {
rows, err := db.QueryContext(context.TODO(), query)
rows, err := db.Query(query)
if err != nil {
return err
}

View File

@ -12,6 +12,8 @@ import (
// Queryer is a representation for something that can query.
type Queryer interface {
// Query executes a query that returns rows, typically a SELECT.
Query(query string, args ...interface{}) (*sql.Rows, error)
// QueryContext executes a query that returns rows, typically a SELECT.
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}

View File

@ -1,194 +0,0 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.

// Package dbwrap wraps the standard database/sql types behind interfaces
// that only expose the context-taking (…Context) method variants, so that
// callers are statically required to pass a context for every database
// access. Each wrapper also tags the context with traces.TagDB so database
// work is attributable in traces.
package dbwrap

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"time"

	"storj.io/storj/pkg/traces"
)

// DB implements a wrapper interface for *sql.DB-like databases which
// require contexts.
type DB interface {
	// DriverContext returns the database driver; the context is used
	// only for trace tagging.
	DriverContext(context.Context) driver.Driver

	// BeginTx starts a transaction wrapped in the Tx interface.
	BeginTx(context.Context, *sql.TxOptions) (Tx, error)

	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
	PrepareContext(ctx context.Context, query string) (Stmt, error)

	// Connection-pool tuning and stats, delegated to *sql.DB.
	SetMaxIdleConns(n int)
	SetMaxOpenConns(n int)
	SetConnMaxLifetime(time.Duration)
	Stats() sql.DBStats

	// Conn returns a single connection wrapped in the Conn interface.
	Conn(ctx context.Context) (Conn, error)
	Close() error
}

// sqlDB adapts a *sql.DB to the DB interface, tagging every context
// with traces.TagDB before delegating.
type sqlDB struct {
	DB *sql.DB
}

// DriverContext tags the context for tracing and returns the underlying
// driver. The context is not otherwise used (sql.DB.Driver takes none).
func (s sqlDB) DriverContext(ctx context.Context) driver.Driver {
	traces.Tag(ctx, traces.TagDB)
	return s.DB.Driver()
}

// The following methods delegate directly to *sql.DB; they take no
// context because the underlying sql.DB methods do not.
func (s sqlDB) Close() error { return s.DB.Close() }

func (s sqlDB) SetMaxIdleConns(n int)              { s.DB.SetMaxIdleConns(n) }
func (s sqlDB) SetMaxOpenConns(n int)              { s.DB.SetMaxOpenConns(n) }
func (s sqlDB) SetConnMaxLifetime(d time.Duration) { s.DB.SetConnMaxLifetime(d) }
func (s sqlDB) Stats() sql.DBStats                 { return s.DB.Stats() }

// BeginTx tags the context, starts a transaction, and wraps it so only
// context-taking methods are exposed to the caller.
func (s sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
	traces.Tag(ctx, traces.TagDB)
	tx, err := s.DB.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	return sqlTx{Tx: tx}, nil
}

// ExecContext tags the context and delegates to sql.DB.ExecContext.
func (s sqlDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	traces.Tag(ctx, traces.TagDB)
	return s.DB.ExecContext(ctx, query, args...)
}

// QueryContext tags the context and delegates to sql.DB.QueryContext.
func (s sqlDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	traces.Tag(ctx, traces.TagDB)
	return s.DB.QueryContext(ctx, query, args...)
}

// QueryRowContext tags the context and delegates to sql.DB.QueryRowContext.
func (s sqlDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	traces.Tag(ctx, traces.TagDB)
	return s.DB.QueryRowContext(ctx, query, args...)
}

// PrepareContext tags the context and returns the prepared statement
// behind the Stmt interface.
func (s sqlDB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
	traces.Tag(ctx, traces.TagDB)
	return s.DB.PrepareContext(ctx, query)
}

// Conn tags the context and returns a single connection wrapped in the
// Conn interface.
func (s sqlDB) Conn(ctx context.Context) (Conn, error) {
	traces.Tag(ctx, traces.TagDB)
	conn, err := s.DB.Conn(ctx)
	if err != nil {
		return nil, err
	}
	return sqlConn{Conn: conn}, nil
}

// SQLDB turns a *sql.DB into a DB-matching interface
func SQLDB(db *sql.DB) DB {
	return sqlDB{DB: db}
}

// Tx is an interface for *sql.Tx-like transactions
type Tx interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
	PrepareContext(ctx context.Context, query string) (Stmt, error)

	Commit() error
	Rollback() error
}

// Conn is an interface for *sql.Conn-like connections
type Conn interface {
	BeginTx(context.Context, *sql.TxOptions) (Tx, error)
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
	// RawContext runs f with the raw driver connection; the context is
	// used for trace tagging only (sql.Conn.Raw takes no context).
	RawContext(ctx context.Context, f func(driverConn interface{}) error) error
	Close() error
}

// sqlConn adapts a *sql.Conn to the Conn interface, tagging every
// context with traces.TagDB before delegating.
type sqlConn struct {
	Conn *sql.Conn
}

// BeginTx tags the context, starts a transaction on this connection,
// and wraps it in the Tx interface.
func (s sqlConn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
	traces.Tag(ctx, traces.TagDB)
	tx, err := s.Conn.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	return sqlTx{Tx: tx}, nil
}

// ExecContext tags the context and delegates to sql.Conn.ExecContext.
func (s sqlConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	traces.Tag(ctx, traces.TagDB)
	return s.Conn.ExecContext(ctx, query, args...)
}

// QueryContext tags the context and delegates to sql.Conn.QueryContext.
func (s sqlConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	traces.Tag(ctx, traces.TagDB)
	return s.Conn.QueryContext(ctx, query, args...)
}

// QueryRowContext tags the context and delegates to sql.Conn.QueryRowContext.
func (s sqlConn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	traces.Tag(ctx, traces.TagDB)
	return s.Conn.QueryRowContext(ctx, query, args...)
}

// RawContext tags the context and runs f with the raw driver connection.
// Note: the underlying sql.Conn.Raw does not accept a context, so ctx
// here only serves trace tagging and does not cancel f.
func (s sqlConn) RawContext(ctx context.Context, f func(driverConn interface{}) error) error {
	traces.Tag(ctx, traces.TagDB)
	return s.Conn.Raw(f)
}

// Close returns the connection to the pool.
func (s sqlConn) Close() error {
	return s.Conn.Close()
}

// Stmt is an interface for *sql.Stmt-like prepared statements.
type Stmt interface {
	ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
	Close() error
}

// sqlTx adapts a *sql.Tx to the Tx interface, tagging every context
// with traces.TagDB before delegating.
type sqlTx struct {
	Tx *sql.Tx
}

// ExecContext tags the context and delegates to sql.Tx.ExecContext.
func (s sqlTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	traces.Tag(ctx, traces.TagDB)
	return s.Tx.ExecContext(ctx, query, args...)
}

// QueryContext tags the context and delegates to sql.Tx.QueryContext.
func (s sqlTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	traces.Tag(ctx, traces.TagDB)
	return s.Tx.QueryContext(ctx, query, args...)
}

// QueryRowContext tags the context and delegates to sql.Tx.QueryRowContext.
func (s sqlTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	traces.Tag(ctx, traces.TagDB)
	return s.Tx.QueryRowContext(ctx, query, args...)
}

// PrepareContext tags the context and returns the prepared statement
// behind the Stmt interface.
func (s sqlTx) PrepareContext(ctx context.Context, query string) (Stmt, error) {
	traces.Tag(ctx, traces.TagDB)
	return s.Tx.PrepareContext(ctx, query)
}

// Commit and Rollback delegate directly; sql.Tx takes no context for them.
func (s sqlTx) Commit() error   { return s.Tx.Commit() }
func (s sqlTx) Rollback() error { return s.Tx.Rollback() }

View File

@ -4,11 +4,10 @@
package dbutil
import (
"database/sql"
"flag"
monkit "gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/private/dbutil/dbwrap"
)
var (
@ -18,7 +17,7 @@ var (
)
// Configure Sets Connection Boundaries and adds db_stats monitoring to monkit
func Configure(db dbwrap.DB, mon *monkit.Scope) {
func Configure(db *sql.DB, mon *monkit.Scope) {
if *maxIdleConns >= 0 {
db.SetMaxIdleConns(*maxIdleConns)
}

View File

@ -14,7 +14,6 @@ import (
"storj.io/storj/private/dbutil"
"storj.io/storj/private/dbutil/dbschema"
"storj.io/storj/private/dbutil/dbwrap"
)
var (
@ -52,7 +51,7 @@ func OpenUnique(ctx context.Context, connstr string, schemaPrefix string) (*dbut
return DropSchema(ctx, cleanupDB, schemaName)
}
dbutil.Configure(dbwrap.SQLDB(db), mon)
dbutil.Configure(db, mon)
return &dbutil.TempDatabase{
DB: db,
ConnStr: connStrWithSchema,

View File

@ -21,7 +21,7 @@ func QuerySchema(ctx context.Context, db dbschema.Queryer) (*dbschema.Schema, er
// find tables
err := func() error {
rows, err := db.QueryContext(ctx, `
rows, err := db.Query(`
SELECT table_name, column_name, is_nullable, data_type
FROM information_schema.columns
WHERE table_schema = CURRENT_SCHEMA
@ -54,7 +54,7 @@ func QuerySchema(ctx context.Context, db dbschema.Queryer) (*dbschema.Schema, er
// find constraints
err = func() error {
rows, err := db.QueryContext(ctx, `
rows, err := db.Query(`
SELECT
pg_class.relname AS table_name,
pg_constraint.conname AS constraint_name,

View File

@ -64,13 +64,13 @@ func QuoteSchema(schema string) string {
// Execer is for executing sql
type Execer interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Exec(query string, args ...interface{}) (sql.Result, error)
}
// CreateSchema creates a schema if it doesn't exist.
func CreateSchema(ctx context.Context, db Execer, schema string) (err error) {
for try := 0; try < 5; try++ {
_, err = db.ExecContext(ctx, `create schema if not exists `+QuoteSchema(schema)+`;`)
_, err = db.Exec(`create schema if not exists ` + QuoteSchema(schema) + `;`)
// Postgres `CREATE SCHEMA IF NOT EXISTS` may return "duplicate key value violates unique constraint".
// In that case, we will retry rather than doing anything more complicated.
@ -87,6 +87,6 @@ func CreateSchema(ctx context.Context, db Execer, schema string) (err error) {
// DropSchema drops the named schema
func DropSchema(ctx context.Context, db Execer, schema string) error {
_, err := db.ExecContext(ctx, `drop schema `+QuoteSchema(schema)+` cascade;`)
_, err := db.Exec(`drop schema ` + QuoteSchema(schema) + ` cascade;`)
return err
}

View File

@ -9,10 +9,9 @@ import (
"database/sql/driver"
"fmt"
sqlite3 "github.com/mattn/go-sqlite3"
"github.com/mattn/go-sqlite3"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/txutil"
"storj.io/storj/private/migrate"
)
@ -20,7 +19,8 @@ import (
// DB is the minimal interface required to perform migrations.
type DB interface {
migrate.DB
Conn(ctx context.Context) (dbwrap.Conn, error)
Conn(ctx context.Context) (*sql.Conn, error)
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
@ -38,9 +38,7 @@ func getSqlite3Conn(conn interface{}) (*sqlite3.SQLiteConn, error) {
switch c := conn.(type) {
case *sqlite3.SQLiteConn:
return c, nil
case interface {
Unwrap() driver.Conn
}:
case interface{ Unwrap() driver.Conn }:
conn = c.Unwrap()
default:
return nil, ErrMigrateTables.New("unable to get raw database connection")
@ -84,13 +82,13 @@ func backupDBs(ctx context.Context, srcDB, destDB DB) error {
// The references to the driver connections are only guaranteed to be valid
// for the life of the callback so we must do the work within both callbacks.
err = srcConn.RawContext(ctx, func(srcDriverConn interface{}) error {
err = srcConn.Raw(func(srcDriverConn interface{}) error {
srcSqliteConn, err := getSqlite3Conn(srcDriverConn)
if err != nil {
return err
}
err = destConn.RawContext(ctx, func(destDriverConn interface{}) error {
err = destConn.Raw(func(destDriverConn interface{}) error {
destSqliteConn, err := getSqlite3Conn(destDriverConn)
if err != nil {
return err
@ -189,7 +187,7 @@ func dropTables(ctx context.Context, db DB, tablesToKeep ...string) (err error)
if err != nil {
return err
}
err = txutil.ExecuteInTx(ctx, db.DriverContext(ctx), tx, func() error {
err = txutil.ExecuteInTx(ctx, db.Driver(), tx, func() error {
// Get a list of tables excluding sqlite3 system tables.
rows, err := tx.QueryContext(ctx, "SELECT name FROM sqlite_master WHERE type ='table' AND name NOT LIKE 'sqlite_%';")
if err != nil {

View File

@ -12,8 +12,6 @@ import (
"github.com/stretchr/testify/require"
"storj.io/common/testcontext"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/sqliteutil"
)
@ -87,13 +85,13 @@ func TestKeepTables(t *testing.T) {
require.Equal(t, snapshot.Data, data)
}
func execSQL(ctx context.Context, t *testing.T, db dbwrap.DB, query string, args ...interface{}) {
func execSQL(ctx context.Context, t *testing.T, db *sql.DB, query string, args ...interface{}) {
_, err := db.ExecContext(ctx, query, args...)
require.NoError(t, err)
}
func newMemDB(t *testing.T) dbwrap.DB {
func newMemDB(t *testing.T) *sql.DB {
db, err := sql.Open("sqlite3", ":memory:")
require.NoError(t, err)
return dbwrap.SQLDB(db)
return db
}

View File

@ -14,7 +14,6 @@ import (
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil/cockroachutil"
"storj.io/storj/private/dbutil/dbwrap"
)
// txLike is the minimal interface for transaction-like objects to work with the necessary retry
@ -52,12 +51,12 @@ func ExecuteInTx(ctx context.Context, dbDriver driver.Driver, tx txLike, fn func
//
// If fn has any side effects outside of changes to the database, they must be idempotent! fn may
// be called more than one time.
func WithTx(ctx context.Context, db dbwrap.DB, txOpts *sql.TxOptions, fn func(context.Context, dbwrap.Tx) error) error {
func WithTx(ctx context.Context, db *sql.DB, txOpts *sql.TxOptions, fn func(context.Context, *sql.Tx) error) error {
tx, err := db.BeginTx(ctx, txOpts)
if err != nil {
return err
}
return ExecuteInTx(ctx, db.DriverContext(ctx), tx, func() error {
return ExecuteInTx(ctx, db.Driver(), tx, func() error {
return fn(ctx, tx)
})
}

View File

@ -8,8 +8,6 @@ import (
"database/sql"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil/dbwrap"
)
// Error is the default migrate errs class
@ -21,27 +19,27 @@ func Create(ctx context.Context, identifier string, db DBX) error {
// when the schemas match.
justRollbackPlease := errs.Class("only used to tell WithTx to do a rollback")
err := WithTx(ctx, db, func(ctx context.Context, tx dbwrap.Tx) (err error) {
err := WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) (err error) {
schema := db.Schema()
_, err = tx.ExecContext(ctx, db.Rebind(`CREATE TABLE IF NOT EXISTS table_schemas (id text, schemaText text);`))
_, err = tx.Exec(db.Rebind(`CREATE TABLE IF NOT EXISTS table_schemas (id text, schemaText text);`))
if err != nil {
return err
}
row := tx.QueryRowContext(ctx, db.Rebind(`SELECT schemaText FROM table_schemas WHERE id = ?;`), identifier)
row := tx.QueryRow(db.Rebind(`SELECT schemaText FROM table_schemas WHERE id = ?;`), identifier)
var previousSchema string
err = row.Scan(&previousSchema)
// not created yet
if err == sql.ErrNoRows {
_, err := tx.ExecContext(ctx, schema)
_, err := tx.Exec(schema)
if err != nil {
return err
}
_, err = tx.ExecContext(ctx, db.Rebind(`INSERT INTO table_schemas(id, schemaText) VALUES (?, ?);`), identifier, schema)
_, err = tx.Exec(db.Rebind(`INSERT INTO table_schemas(id, schemaText) VALUES (?, ?);`), identifier, schema)
if err != nil {
return err
}

View File

@ -14,7 +14,6 @@ import (
"github.com/stretchr/testify/require"
"storj.io/common/testcontext"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/pgutil/pgtest"
"storj.io/storj/private/dbutil/tempdb"
"storj.io/storj/private/migrate"
@ -24,12 +23,11 @@ func TestCreate_Sqlite(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
rdb, err := sql.Open("sqlite3", ":memory:")
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer func() { assert.NoError(t, rdb.Close()) }()
db := dbwrap.SQLDB(rdb)
defer func() { assert.NoError(t, db.Close()) }()
// should create table
err = migrate.Create(ctx, "example", &sqliteDB{db, "CREATE TABLE example_table (id text)"})
@ -69,33 +67,31 @@ func TestCreate_Cockroach(t *testing.T) {
}
func testCreateGeneric(ctx *testcontext.Context, t *testing.T, connStr string) {
rdb, err := tempdb.OpenUnique(ctx, connStr, "create-")
db, err := tempdb.OpenUnique(ctx, connStr, "create-")
if err != nil {
t.Fatal(err)
}
defer func() { assert.NoError(t, rdb.Close()) }()
db := dbwrap.SQLDB(rdb.DB)
defer func() { assert.NoError(t, db.Close()) }()
// should create table
err = migrate.Create(ctx, "example", &postgresDB{db, "CREATE TABLE example_table (id text)"})
err = migrate.Create(ctx, "example", &postgresDB{db.DB, "CREATE TABLE example_table (id text)"})
require.NoError(t, err)
// shouldn't create a new table
err = migrate.Create(ctx, "example", &postgresDB{db, "CREATE TABLE example_table (id text)"})
err = migrate.Create(ctx, "example", &postgresDB{db.DB, "CREATE TABLE example_table (id text)"})
require.NoError(t, err)
// should fail, because schema changed
err = migrate.Create(ctx, "example", &postgresDB{db, "CREATE TABLE example_table (id text, version integer)"})
err = migrate.Create(ctx, "example", &postgresDB{db.DB, "CREATE TABLE example_table (id text, version integer)"})
require.Error(t, err)
// should fail, because of trying to CREATE TABLE with same name
err = migrate.Create(ctx, "conflict", &postgresDB{db, "CREATE TABLE example_table (id text, version integer)"})
err = migrate.Create(ctx, "conflict", &postgresDB{db.DB, "CREATE TABLE example_table (id text, version integer)"})
require.Error(t, err)
}
type sqliteDB struct {
dbwrap.DB
*sql.DB
schema string
}
@ -103,7 +99,7 @@ func (db *sqliteDB) Rebind(s string) string { return s }
func (db *sqliteDB) Schema() string { return db.schema }
type postgresDB struct {
dbwrap.DB
*sql.DB
schema string
}

View File

@ -8,7 +8,6 @@ import (
"database/sql"
"database/sql/driver"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/txutil"
)
@ -16,8 +15,8 @@ import (
//
// DB can optionally have `Rebind(string) string` for translating `? queries for the specific database.
type DB interface {
BeginTx(ctx context.Context, txOptions *sql.TxOptions) (dbwrap.Tx, error)
DriverContext(context.Context) driver.Driver
BeginTx(ctx context.Context, txOptions *sql.TxOptions) (*sql.Tx, error)
Driver() driver.Driver
}
// DBX contains additional methods for migrations.
@ -29,21 +28,19 @@ type DBX interface {
// rebind uses Rebind method when the database has the func.
func rebind(db DB, s string) string {
if dbx, ok := db.(interface {
Rebind(string) string
}); ok {
if dbx, ok := db.(interface{ Rebind(string) string }); ok {
return dbx.Rebind(s)
}
return s
}
// WithTx runs the given callback in the context of a transaction.
func WithTx(ctx context.Context, db DB, fn func(ctx context.Context, tx dbwrap.Tx) error) error {
func WithTx(ctx context.Context, db DB, fn func(ctx context.Context, tx *sql.Tx) error) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
return txutil.ExecuteInTx(ctx, db.DriverContext(ctx), tx, func() error {
return txutil.ExecuteInTx(ctx, db.Driver(), tx, func() error {
return fn(ctx, tx)
})
}

View File

@ -13,8 +13,6 @@ import (
"github.com/zeebo/errs"
"go.uber.org/zap"
"storj.io/storj/private/dbutil/dbwrap"
)
var (
@ -67,7 +65,7 @@ type Step struct {
// Action is something that needs to be done
type Action interface {
Run(ctx context.Context, log *zap.Logger, db DB, tx dbwrap.Tx) error
Run(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) error
}
// TargetVersion returns migration with steps upto specified version
@ -166,7 +164,7 @@ func (migration *Migration) Run(ctx context.Context, log *zap.Logger) error {
stepLog.Info(step.Description)
}
err = WithTx(ctx, step.DB, func(ctx context.Context, tx dbwrap.Tx) error {
err = WithTx(ctx, step.DB, func(ctx context.Context, tx *sql.Tx) error {
err = step.Action.Run(ctx, stepLog, step.DB, tx)
if err != nil {
return err
@ -199,8 +197,8 @@ func (migration *Migration) Run(ctx context.Context, log *zap.Logger) error {
// createVersionTable creates a new version table
func (migration *Migration) ensureVersionTable(ctx context.Context, log *zap.Logger, db DB) error {
err := WithTx(ctx, db, func(ctx context.Context, tx dbwrap.Tx) error {
_, err := tx.ExecContext(ctx, rebind(db, `CREATE TABLE IF NOT EXISTS `+migration.Table+` (version int, commited_at text)`)) //nolint:misspell
err := WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(rebind(db, `CREATE TABLE IF NOT EXISTS `+migration.Table+` (version int, commited_at text)`)) //nolint:misspell
return err
})
return Error.Wrap(err)
@ -209,8 +207,8 @@ func (migration *Migration) ensureVersionTable(ctx context.Context, log *zap.Log
// getLatestVersion finds the latest version table
func (migration *Migration) getLatestVersion(ctx context.Context, log *zap.Logger, db DB) (int, error) {
var version sql.NullInt64
err := WithTx(ctx, db, func(ctx context.Context, tx dbwrap.Tx) error {
err := tx.QueryRowContext(ctx, rebind(db, `SELECT MAX(version) FROM `+migration.Table)).Scan(&version)
err := WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow(rebind(db, `SELECT MAX(version) FROM `+migration.Table)).Scan(&version)
if err == sql.ErrNoRows || !version.Valid {
version.Int64 = -1
return nil
@ -222,8 +220,8 @@ func (migration *Migration) getLatestVersion(ctx context.Context, log *zap.Logge
}
// addVersion adds information about a new migration
func (migration *Migration) addVersion(ctx context.Context, tx dbwrap.Tx, db DB, version int) error {
_, err := tx.ExecContext(ctx, rebind(db, `
func (migration *Migration) addVersion(ctx context.Context, tx *sql.Tx, db DB, version int) error {
_, err := tx.Exec(rebind(db, `
INSERT INTO `+migration.Table+` (version, commited_at) VALUES (?, ?)`), //nolint:misspell
version, time.Now().String(),
)
@ -243,9 +241,9 @@ func (migration *Migration) CurrentVersion(ctx context.Context, log *zap.Logger,
type SQL []string
// Run runs the SQL statements
func (sql SQL) Run(ctx context.Context, log *zap.Logger, db DB, tx dbwrap.Tx) (err error) {
func (sql SQL) Run(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) (err error) {
for _, query := range sql {
_, err := tx.ExecContext(ctx, rebind(db, query))
_, err := tx.Exec(rebind(db, query))
if err != nil {
return err
}
@ -254,9 +252,9 @@ func (sql SQL) Run(ctx context.Context, log *zap.Logger, db DB, tx dbwrap.Tx) (e
}
// Func is an arbitrary operation
type Func func(ctx context.Context, log *zap.Logger, db DB, tx dbwrap.Tx) error
type Func func(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) error
// Run runs the migration
func (fn Func) Run(ctx context.Context, log *zap.Logger, db DB, tx dbwrap.Tx) error {
func (fn Func) Run(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) error {
return fn(ctx, log, db, tx)
}

View File

@ -18,7 +18,7 @@ import (
"go.uber.org/zap"
"storj.io/common/testcontext"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/pgutil/pgtest"
"storj.io/storj/private/dbutil/tempdb"
"storj.io/storj/private/migrate"
@ -31,7 +31,7 @@ func TestBasicMigrationSqliteNoRebind(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
basicMigration(ctx, t, dbwrap.SQLDB(db), dbwrap.SQLDB(db))
basicMigration(ctx, t, db, db)
}
func TestBasicMigrationSqlite(t *testing.T) {
@ -41,7 +41,7 @@ func TestBasicMigrationSqlite(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
basicMigration(ctx, t, dbwrap.SQLDB(db), &sqliteDB{DB: dbwrap.SQLDB(db)})
basicMigration(ctx, t, db, &sqliteDB{DB: db})
}
func TestBasicMigrationPostgres(t *testing.T) {
@ -69,12 +69,12 @@ func testBasicMigrationGeneric(ctx *testcontext.Context, t *testing.T, connStr s
}
defer func() { assert.NoError(t, db.Close()) }()
basicMigration(ctx, t, dbwrap.SQLDB(db.DB), &postgresDB{DB: dbwrap.SQLDB(db.DB)})
basicMigration(ctx, t, db.DB, &postgresDB{DB: db.DB})
}
func basicMigration(ctx *testcontext.Context, t *testing.T, db dbwrap.DB, testDB migrate.DB) {
func basicMigration(ctx *testcontext.Context, t *testing.T, db *sql.DB, testDB migrate.DB) {
dbName := strings.ToLower(`versions_` + t.Name())
defer func() { assert.NoError(t, dropTables(ctx, db, dbName, "users")) }()
defer func() { assert.NoError(t, dropTables(db, dbName, "users")) }()
err := ioutil.WriteFile(ctx.File("alpha.txt"), []byte("test"), 0644)
require.NoError(t, err)
@ -94,7 +94,7 @@ func basicMigration(ctx *testcontext.Context, t *testing.T, db dbwrap.DB, testDB
DB: testDB,
Description: "Move files",
Version: 2,
Action: migrate.Func(func(_ context.Context, log *zap.Logger, _ migrate.DB, tx dbwrap.Tx) error {
Action: migrate.Func(func(_ context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
return os.Rename(ctx.File("alpha.txt"), ctx.File("beta.txt"))
}),
},
@ -126,12 +126,12 @@ func basicMigration(ctx *testcontext.Context, t *testing.T, db dbwrap.DB, testDB
assert.Equal(t, dbVersion, 2)
var version int
err = db.QueryRowContext(ctx, `SELECT MAX(version) FROM `+dbName).Scan(&version)
err = db.QueryRow(`SELECT MAX(version) FROM ` + dbName).Scan(&version)
assert.NoError(t, err)
assert.Equal(t, 2, version)
var id int
err = db.QueryRowContext(ctx, `SELECT MAX(id) FROM users`).Scan(&id)
err = db.QueryRow(`SELECT MAX(id) FROM users`).Scan(&id)
assert.NoError(t, err)
assert.Equal(t, 1, id)
@ -152,7 +152,7 @@ func TestMultipleMigrationSqlite(t *testing.T) {
require.NoError(t, err)
defer func() { assert.NoError(t, db.Close()) }()
multipleMigration(t, dbwrap.SQLDB(db), &sqliteDB{DB: dbwrap.SQLDB(db)})
multipleMigration(t, db, &sqliteDB{DB: db})
}
func TestMultipleMigrationPostgres(t *testing.T) {
@ -164,15 +164,15 @@ func TestMultipleMigrationPostgres(t *testing.T) {
require.NoError(t, err)
defer func() { assert.NoError(t, db.Close()) }()
multipleMigration(t, dbwrap.SQLDB(db), &postgresDB{DB: dbwrap.SQLDB(db)})
multipleMigration(t, db, &postgresDB{DB: db})
}
func multipleMigration(t *testing.T, db dbwrap.DB, testDB migrate.DB) {
func multipleMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
dbName := strings.ToLower(`versions_` + t.Name())
defer func() { assert.NoError(t, dropTables(ctx, db, dbName)) }()
defer func() { assert.NoError(t, dropTables(db, dbName)) }()
steps := 0
m := migrate.Migration{
@ -182,7 +182,7 @@ func multipleMigration(t *testing.T, db dbwrap.DB, testDB migrate.DB) {
DB: testDB,
Description: "Step 1",
Version: 1,
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx dbwrap.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
steps++
return nil
}),
@ -191,7 +191,7 @@ func multipleMigration(t *testing.T, db dbwrap.DB, testDB migrate.DB) {
DB: testDB,
Description: "Step 2",
Version: 2,
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx dbwrap.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
steps++
return nil
}),
@ -207,7 +207,7 @@ func multipleMigration(t *testing.T, db dbwrap.DB, testDB migrate.DB) {
DB: testDB,
Description: "Step 3",
Version: 3,
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx dbwrap.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
steps++
return nil
}),
@ -216,7 +216,7 @@ func multipleMigration(t *testing.T, db dbwrap.DB, testDB migrate.DB) {
assert.NoError(t, err)
var version int
err = db.QueryRowContext(ctx, `SELECT MAX(version) FROM `+dbName).Scan(&version)
err = db.QueryRow(`SELECT MAX(version) FROM ` + dbName).Scan(&version)
assert.NoError(t, err)
assert.Equal(t, 3, version)
@ -228,7 +228,7 @@ func TestFailedMigrationSqlite(t *testing.T) {
require.NoError(t, err)
defer func() { assert.NoError(t, db.Close()) }()
failedMigration(t, dbwrap.SQLDB(db), &sqliteDB{DB: dbwrap.SQLDB(db)})
failedMigration(t, db, &sqliteDB{DB: db})
}
func TestFailedMigrationPostgres(t *testing.T) {
@ -240,15 +240,15 @@ func TestFailedMigrationPostgres(t *testing.T) {
require.NoError(t, err)
defer func() { assert.NoError(t, db.Close()) }()
failedMigration(t, dbwrap.SQLDB(db), &postgresDB{DB: dbwrap.SQLDB(db)})
failedMigration(t, db, &postgresDB{DB: db})
}
func failedMigration(t *testing.T, db dbwrap.DB, testDB migrate.DB) {
func failedMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
dbName := strings.ToLower(`versions_` + t.Name())
defer func() { assert.NoError(t, dropTables(ctx, db, dbName)) }()
defer func() { assert.NoError(t, dropTables(db, dbName)) }()
m := migrate.Migration{
Table: dbName,
@ -257,7 +257,7 @@ func failedMigration(t *testing.T, db dbwrap.DB, testDB migrate.DB) {
DB: testDB,
Description: "Step 1",
Version: 1,
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx dbwrap.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
return fmt.Errorf("migration failed")
}),
},
@ -268,7 +268,7 @@ func failedMigration(t *testing.T, db dbwrap.DB, testDB migrate.DB) {
require.Error(t, err, "migration failed")
var version sql.NullInt64
err = db.QueryRowContext(ctx, `SELECT MAX(version) FROM `+dbName).Scan(&version)
err = db.QueryRow(`SELECT MAX(version) FROM ` + dbName).Scan(&version)
assert.NoError(t, err)
assert.Equal(t, false, version.Valid)
}
@ -326,10 +326,10 @@ func TestInvalidStepsOrder(t *testing.T) {
require.Error(t, err, "migrate: steps have incorrect order")
}
func dropTables(ctx context.Context, db dbwrap.DB, names ...string) error {
func dropTables(db *sql.DB, names ...string) error {
var errlist errs.Group
for _, name := range names {
_, err := db.ExecContext(ctx, `DROP TABLE `+name)
_, err := db.Exec(`DROP TABLE ` + name)
errlist.Add(err)
}

View File

@ -16,7 +16,10 @@ import (
_ "storj.io/storj/private/dbutil/cockroachutil"
)
//go:generate bash gen.sh
//go:generate dbx schema -d postgres -d cockroach satellitedb.dbx .
//go:generate dbx golang -d postgres -d cockroach -t templates satellitedb.dbx .
//go:generate bash -c "( echo '//lint:file-ignore * generated file'; cat satellitedb.dbx.go ) > satellitedb.dbx.go.tmp && mv satellitedb.dbx.go{.tmp,}"
//go:generate perl -p0i -e "s,^(\\s*\"github.com/lib/pq\")\\n\\n\\1,\\1,gm" satellitedb.dbx.go
var mon = monkit.Package()
@ -62,7 +65,7 @@ func (db *DB) WithTx(ctx context.Context, fn func(context.Context, *Tx) error) (
if err != nil {
return err
}
return txutil.ExecuteInTx(ctx, db.DriverContext(ctx), tx.Tx, func() error {
return txutil.ExecuteInTx(ctx, db.Driver(), tx.Tx, func() error {
return fn(ctx, tx)
})
}

View File

@ -1,15 +0,0 @@
#!/bin/bash

# Regenerates the dbx-generated satellite database code and rewrites it
# to use the dbwrap context-requiring wrappers instead of bare *sql.DB /
# *sql.Tx, so all generated database access goes through dbwrap.

# Generate the schema SQL and the Go bindings from satellitedb.dbx.
dbx schema -d postgres -d cockroach satellitedb.dbx .
dbx golang -d postgres -d cockroach -t templates satellitedb.dbx .

# Prepend a lint suppression header to the generated file.
( echo '//lint:file-ignore * generated file'; cat satellitedb.dbx.go ) > satellitedb.dbx.go.tmp && mv satellitedb.dbx.go{.tmp,}

# Rewrite every *sql.Tx reference to the dbwrap.Tx interface.
gofmt -r "*sql.Tx -> dbwrap.Tx" -w satellitedb.dbx.go

# Add the dbwrap import after the lib/pq import in the import block.
perl -0777 -pi \
	-e 's,\t"github.com/lib/pq"\n\),\t"github.com/lib/pq"\n\n\t"storj.io/storj/private/dbutil/dbwrap"\n\),' \
	satellitedb.dbx.go

# Change the generated DB struct to embed dbwrap.DB instead of *sql.DB.
perl -0777 -pi \
	-e 's/type DB struct \{\n\t\*sql\.DB/type DB struct \{\n\tdbwrap.DB/' \
	satellitedb.dbx.go

# Wrap the opened *sql.DB with dbwrap.SQLDB when constructing DB.
perl -0777 -pi \
	-e 's/\tdb = &DB\{\n\t\tDB: sql_db,/\tdb = &DB\{\n\t\tDB: dbwrap.SQLDB\(sql_db\),/' \
	satellitedb.dbx.go

View File

@ -540,7 +540,7 @@ model storagenode_bandwidth_rollup (
field interval_seconds uint
field action uint
field allocated uint64 ( updatable, nullable )
field allocated uint64 ( updatable )
field settled uint64 ( updatable )
)

View File

@ -221,7 +221,7 @@ CREATE TABLE storagenode_bandwidth_rollups (
interval_start timestamp NOT NULL,
interval_seconds integer NOT NULL,
action integer NOT NULL,
allocated bigint,
allocated bigint DEFAULT 0,
settled bigint NOT NULL,
PRIMARY KEY ( storagenode_id, interval_start, action )
);

File diff suppressed because it is too large Load Diff

View File

@ -221,7 +221,7 @@ CREATE TABLE storagenode_bandwidth_rollups (
interval_start timestamp NOT NULL,
interval_seconds integer NOT NULL,
action integer NOT NULL,
allocated bigint,
allocated bigint DEFAULT 0,
settled bigint NOT NULL,
PRIMARY KEY ( storagenode_id, interval_start, action )
);

View File

@ -29,7 +29,7 @@ func (db *gracefulexitDB) IncrementProgress(ctx context.Context, nodeID storj.No
statement := db.db.Rebind(
`INSERT INTO graceful_exit_progress (node_id, bytes_transferred, pieces_transferred, pieces_failed, updated_at) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(node_id)
DO UPDATE SET bytes_transferred = graceful_exit_progress.bytes_transferred + excluded.bytes_transferred,
DO UPDATE SET bytes_transferred = graceful_exit_progress.bytes_transferred + excluded.bytes_transferred,
pieces_transferred = graceful_exit_progress.pieces_transferred + excluded.pieces_transferred,
pieces_failed = graceful_exit_progress.pieces_failed + excluded.pieces_failed,
updated_at = excluded.updated_at;`,
@ -174,11 +174,11 @@ func (db *gracefulexitDB) GetTransferQueueItem(ctx context.Context, nodeID storj
func (db *gracefulexitDB) GetIncomplete(ctx context.Context, nodeID storj.NodeID, limit int, offset int64) (_ []*gracefulexit.TransferQueueItem, err error) {
defer mon.Task()(&ctx)(&err)
sql := `SELECT node_id, path, piece_num, root_piece_id, durability_ratio, queued_at, requested_at, last_failed_at, last_failed_code, failed_count, finished_at, order_limit_send_count
FROM graceful_exit_transfer_queue
WHERE node_id = ?
AND finished_at is NULL
FROM graceful_exit_transfer_queue
WHERE node_id = ?
AND finished_at is NULL
ORDER BY durability_ratio asc, queued_at asc LIMIT ? OFFSET ?`
rows, err := db.db.QueryContext(ctx, db.db.Rebind(sql), nodeID.Bytes(), limit, offset)
rows, err := db.db.Query(db.db.Rebind(sql), nodeID.Bytes(), limit, offset)
if err != nil {
return nil, Error.Wrap(err)
}
@ -199,12 +199,12 @@ func (db *gracefulexitDB) GetIncomplete(ctx context.Context, nodeID storj.NodeID
func (db *gracefulexitDB) GetIncompleteNotFailed(ctx context.Context, nodeID storj.NodeID, limit int, offset int64) (_ []*gracefulexit.TransferQueueItem, err error) {
defer mon.Task()(&ctx)(&err)
sql := `SELECT node_id, path, piece_num, root_piece_id, durability_ratio, queued_at, requested_at, last_failed_at, last_failed_code, failed_count, finished_at, order_limit_send_count
FROM graceful_exit_transfer_queue
WHERE node_id = ?
FROM graceful_exit_transfer_queue
WHERE node_id = ?
AND finished_at is NULL
AND last_failed_at is NULL
ORDER BY durability_ratio asc, queued_at asc LIMIT ? OFFSET ?`
rows, err := db.db.QueryContext(ctx, db.db.Rebind(sql), nodeID.Bytes(), limit, offset)
rows, err := db.db.Query(db.db.Rebind(sql), nodeID.Bytes(), limit, offset)
if err != nil {
return nil, Error.Wrap(err)
}
@ -225,13 +225,13 @@ func (db *gracefulexitDB) GetIncompleteNotFailed(ctx context.Context, nodeID sto
func (db *gracefulexitDB) GetIncompleteFailed(ctx context.Context, nodeID storj.NodeID, maxFailures int, limit int, offset int64) (_ []*gracefulexit.TransferQueueItem, err error) {
defer mon.Task()(&ctx)(&err)
sql := `SELECT node_id, path, piece_num, root_piece_id, durability_ratio, queued_at, requested_at, last_failed_at, last_failed_code, failed_count, finished_at, order_limit_send_count
FROM graceful_exit_transfer_queue
WHERE node_id = ?
FROM graceful_exit_transfer_queue
WHERE node_id = ?
AND finished_at is NULL
AND last_failed_at is not NULL
AND failed_count < ?
ORDER BY durability_ratio asc, queued_at asc LIMIT ? OFFSET ?`
rows, err := db.db.QueryContext(ctx, db.db.Rebind(sql), nodeID.Bytes(), maxFailures, limit, offset)
rows, err := db.db.Query(db.db.Rebind(sql), nodeID.Bytes(), maxFailures, limit, offset)
if err != nil {
return nil, Error.Wrap(err)
}

View File

@ -45,11 +45,11 @@ func (db *satelliteDB) CreateTables(ctx context.Context) error {
case dbutil.Cockroach:
var dbName string
if err := db.QueryRowContext(ctx, `SELECT current_database();`).Scan(&dbName); err != nil {
if err := db.QueryRow(`SELECT current_database();`).Scan(&dbName); err != nil {
return errs.New("error querying current database: %+v", err)
}
_, err := db.ExecContext(ctx, fmt.Sprintf(`CREATE DATABASE IF NOT EXISTS %s;`,
_, err := db.Exec(fmt.Sprintf(`CREATE DATABASE IF NOT EXISTS %s;`,
pq.QuoteIdentifier(dbName)))
if err != nil {
return errs.Wrap(err)

View File

@ -184,7 +184,7 @@ func pgMigrateTest(t *testing.T, connStr string) {
// insert data for new tables
if newdata := newData(expected); newdata != "" {
_, err = rawdb.ExecContext(ctx, newdata)
_, err = rawdb.Exec(newdata)
require.NoError(t, err, tag)
}

View File

@ -11,7 +11,6 @@ import (
"github.com/zeebo/errs"
"storj.io/storj/private/currency"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/txutil"
"storj.io/storj/satellite/rewards"
dbx "storj.io/storj/satellite/satellitedb/dbx"
@ -109,7 +108,7 @@ func (db *offersDB) Create(ctx context.Context, o *rewards.NewOffer) (*rewards.O
var id int64
err := txutil.WithTx(ctx, db.db.DB.DB, nil, func(ctx context.Context, tx dbwrap.Tx) error {
err := txutil.WithTx(ctx, db.db.DB.DB, nil, func(ctx context.Context, tx *sql.Tx) error {
// If there's an existing current offer, update its status to Done and set its expires_at to be NOW()
switch o.Type {
case rewards.Partner:
@ -131,7 +130,7 @@ func (db *offersDB) Create(ctx context.Context, o *rewards.NewOffer) (*rewards.O
}
}
statement := `
INSERT INTO offers (name, description, award_credit_in_cents, invitee_credit_in_cents, award_credit_duration_days,
INSERT INTO offers (name, description, award_credit_in_cents, invitee_credit_in_cents, award_credit_duration_days,
invitee_credit_duration_days, redeemable_cap, expires_at, created_at, status, type)
VALUES (?::TEXT, ?::TEXT, ?::INT, ?::INT, ?::INT, ?::INT, ?::INT, ?::timestamptz, ?::timestamptz, ?::INT, ?::INT)
RETURNING id;

View File

@ -18,7 +18,6 @@ import (
"storj.io/common/pb"
"storj.io/common/storj"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/pgutil"
"storj.io/storj/satellite/orders"
dbx "storj.io/storj/satellite/satellitedb/dbx"
@ -162,7 +161,7 @@ func (db *ordersDB) GetBucketBandwidth(ctx context.Context, projectID uuid.UUID,
defer mon.Task()(&ctx)(&err)
var sum *int64
query := `SELECT SUM(settled) FROM bucket_bandwidth_rollups WHERE bucket_name = ? AND project_id = ? AND interval_start > ? AND interval_start <= ?`
err = db.db.QueryRowContext(ctx, db.db.Rebind(query), bucketName, projectID[:], from, to).Scan(&sum)
err = db.db.QueryRow(db.db.Rebind(query), bucketName, projectID[:], from, to).Scan(&sum)
if err == sql.ErrNoRows || sum == nil {
return 0, nil
}
@ -174,7 +173,7 @@ func (db *ordersDB) GetStorageNodeBandwidth(ctx context.Context, nodeID storj.No
defer mon.Task()(&ctx)(&err)
var sum *int64
query := `SELECT SUM(settled) FROM storagenode_bandwidth_rollups WHERE storagenode_id = ? AND interval_start > ? AND interval_start <= ?`
err = db.db.QueryRowContext(ctx, db.db.Rebind(query), nodeID.Bytes(), from, to).Scan(&sum)
err = db.db.QueryRow(db.db.Rebind(query), nodeID.Bytes(), from, to).Scan(&sum)
if err == sql.ErrNoRows || sum == nil {
return 0, nil
}
@ -220,9 +219,7 @@ func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.Proces
return responses, errs.Wrap(err)
}
func (db *ordersDB) processOrdersInTx(ctx context.Context, requests []*orders.ProcessOrderRequest, storageNodeID storj.NodeID, now time.Time, tx dbwrap.Tx) (responses []*orders.ProcessOrderResponse, err error) {
defer mon.Task()(&ctx)(&err)
func (db *ordersDB) processOrdersInTx(ctx context.Context, requests []*orders.ProcessOrderRequest, storageNodeID storj.NodeID, now time.Time, tx *sql.Tx) (responses []*orders.ProcessOrderResponse, err error) {
now = now.UTC()
intervalStart := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
@ -231,7 +228,7 @@ func (db *ordersDB) processOrdersInTx(ctx context.Context, requests []*orders.Pr
// load the bucket id and insert into used serials table
for _, request := range requests {
row := tx.QueryRowContext(ctx, db.db.Rebind(`
row := tx.QueryRow(db.db.Rebind(`
SELECT id, bucket_id
FROM serial_numbers
WHERE serial_number = ?
@ -250,7 +247,7 @@ func (db *ordersDB) processOrdersInTx(ctx context.Context, requests []*orders.Pr
var count int64
// try to insert the serial number
result, err = tx.ExecContext(ctx, db.db.Rebind(`
result, err = tx.Exec(db.db.Rebind(`
INSERT INTO used_serials(serial_number_id, storage_node_id)
VALUES (?, ?)
ON CONFLICT DO NOTHING
@ -293,7 +290,7 @@ func (db *ordersDB) processOrdersInTx(ctx context.Context, requests []*orders.Pr
continue
}
_, err := tx.ExecContext(ctx, db.db.Rebind(`
_, err := tx.Exec(db.db.Rebind(`
INSERT INTO storagenode_bandwidth_rollups
(storagenode_id, interval_start, interval_seconds, action, settled)
VALUES (?, ?, ?, ?, ?)
@ -340,7 +337,7 @@ func (db *ordersDB) processOrdersInTx(ctx context.Context, requests []*orders.Pr
return nil, errs.Wrap(err)
}
_, err = tx.ExecContext(ctx, db.db.Rebind(`
_, err = tx.Exec(db.db.Rebind(`
INSERT INTO bucket_bandwidth_rollups
(bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)

View File

@ -149,7 +149,7 @@ func (cache *overlaycache) GetNodeIPs(ctx context.Context, nodeIDs []storj.NodeI
defer mon.Task()(&ctx)(&err)
var rows *sql.Rows
rows, err = cache.db.QueryContext(ctx, cache.db.Rebind(`
rows, err = cache.db.Query(cache.db.Rebind(`
SELECT last_net FROM nodes
WHERE id = any($1::bytea[])
`), postgresNodeIDList(nodeIDs),
@ -188,7 +188,7 @@ func (cache *overlaycache) queryNodes(ctx context.Context, excludedNodes []storj
args = append(args, count)
var rows *sql.Rows
rows, err = cache.db.QueryContext(ctx, cache.db.Rebind(`SELECT id, type, address, last_net,
rows, err = cache.db.Query(cache.db.Rebind(`SELECT id, type, address, last_net,
free_bandwidth, free_disk, total_audit_count, audit_success_count,
total_uptime_count, uptime_success_count, disqualified, audit_reputation_alpha,
audit_reputation_beta
@ -248,7 +248,7 @@ func (cache *overlaycache) queryNodesDistinct(ctx context.Context, excludedNodes
}
args = append(args, count)
rows, err := cache.db.QueryContext(ctx, cache.db.Rebind(`
rows, err := cache.db.Query(cache.db.Rebind(`
SELECT *
FROM (
SELECT DISTINCT ON (last_net) last_net, -- choose at max 1 node from this IP or network
@ -318,7 +318,7 @@ func (cache *overlaycache) KnownOffline(ctx context.Context, criteria *overlay.N
// get offline nodes
var rows *sql.Rows
rows, err = cache.db.QueryContext(ctx, cache.db.Rebind(`
rows, err = cache.db.Query(cache.db.Rebind(`
SELECT id FROM nodes
WHERE id = any($1::bytea[])
AND (
@ -352,7 +352,7 @@ func (cache *overlaycache) KnownUnreliableOrOffline(ctx context.Context, criteri
// get reliable and online nodes
var rows *sql.Rows
rows, err = cache.db.QueryContext(ctx, cache.db.Rebind(`
rows, err = cache.db.Query(cache.db.Rebind(`
SELECT id FROM nodes
WHERE id = any($1::bytea[])
AND disqualified IS NULL
@ -390,7 +390,7 @@ func (cache *overlaycache) KnownReliable(ctx context.Context, onlineWindow time.
}
// get online nodes
rows, err := cache.db.QueryContext(ctx, cache.db.Rebind(`
rows, err := cache.db.Query(cache.db.Rebind(`
SELECT id, last_net, address, protocol FROM nodes
WHERE id = any($1::bytea[])
AND disqualified IS NULL
@ -419,10 +419,8 @@ func (cache *overlaycache) KnownReliable(ctx context.Context, onlineWindow time.
// Reliable returns all reliable nodes.
func (cache *overlaycache) Reliable(ctx context.Context, criteria *overlay.NodeCriteria) (nodes storj.NodeIDList, err error) {
defer mon.Task()(&ctx)(&err)
// get reliable and online nodes
rows, err := cache.db.QueryContext(ctx, cache.db.Rebind(`
rows, err := cache.db.Query(cache.db.Rebind(`
SELECT id FROM nodes
WHERE disqualified IS NULL
AND last_contact_success > ?`),
@ -620,7 +618,7 @@ func (cache *overlaycache) BatchUpdateStats(ctx context.Context, updateRequests
}
if allSQL != "" {
results, err := tx.Tx.ExecContext(ctx, allSQL)
results, err := tx.Tx.Exec(allSQL)
if err != nil {
return err
}
@ -895,7 +893,7 @@ func (cache *overlaycache) UpdatePieceCounts(ctx context.Context, pieceCounts ma
func (cache *overlaycache) GetExitingNodes(ctx context.Context) (exitingNodes []*overlay.ExitStatus, err error) {
defer mon.Task()(&ctx)(&err)
rows, err := cache.db.QueryContext(ctx, cache.db.Rebind(`
rows, err := cache.db.Query(cache.db.Rebind(`
SELECT id, exit_initiated_at, exit_loop_completed_at, exit_finished_at, exit_success FROM nodes
WHERE exit_initiated_at IS NOT NULL
AND exit_finished_at IS NULL
@ -924,7 +922,7 @@ func (cache *overlaycache) GetExitingNodes(ctx context.Context) (exitingNodes []
func (cache *overlaycache) GetExitStatus(ctx context.Context, nodeID storj.NodeID) (_ *overlay.ExitStatus, err error) {
defer mon.Task()(&ctx)(&err)
rows, err := cache.db.QueryContext(ctx, cache.db.Rebind("select id, exit_initiated_at, exit_loop_completed_at, exit_finished_at, exit_success from nodes where id = ?"), nodeID)
rows, err := cache.db.Query(cache.db.Rebind("select id, exit_initiated_at, exit_loop_completed_at, exit_finished_at, exit_success from nodes where id = ?"), nodeID)
if err != nil {
return nil, Error.Wrap(err)
}
@ -943,7 +941,7 @@ func (cache *overlaycache) GetExitStatus(ctx context.Context, nodeID storj.NodeI
func (cache *overlaycache) GetGracefulExitCompletedByTimeFrame(ctx context.Context, begin, end time.Time) (exitedNodes storj.NodeIDList, err error) {
defer mon.Task()(&ctx)(&err)
rows, err := cache.db.QueryContext(ctx, cache.db.Rebind(`
rows, err := cache.db.Query(cache.db.Rebind(`
SELECT id FROM nodes
WHERE exit_initiated_at IS NOT NULL
AND exit_finished_at IS NOT NULL
@ -973,7 +971,7 @@ func (cache *overlaycache) GetGracefulExitCompletedByTimeFrame(ctx context.Conte
func (cache *overlaycache) GetGracefulExitIncompleteByTimeFrame(ctx context.Context, begin, end time.Time) (exitingNodes storj.NodeIDList, err error) {
defer mon.Task()(&ctx)(&err)
rows, err := cache.db.QueryContext(ctx, cache.db.Rebind(`
rows, err := cache.db.Query(cache.db.Rebind(`
SELECT id FROM nodes
WHERE exit_initiated_at IS NOT NULL
AND exit_finished_at IS NULL

View File

@ -83,7 +83,7 @@ func (idents *peerIdentities) BatchGet(ctx context.Context, nodeIDs storj.NodeID
// TODO: optimize using arrays like overlay
rows, err := idents.db.QueryContext(ctx, idents.db.Rebind(`
rows, err := idents.db.Query(idents.db.Rebind(`
SELECT chain FROM peer_identities WHERE node_id IN (?`+strings.Repeat(", ?", len(nodeIDs)-1)+`)`), args...)
if err != nil {
return nil, Error.Wrap(err)

View File

@ -110,7 +110,7 @@ func (db *ProjectAccounting) GetAllocatedBandwidthTotal(ctx context.Context, pro
defer mon.Task()(&ctx)(&err)
var sum *int64
query := `SELECT SUM(allocated) FROM bucket_bandwidth_rollups WHERE project_id = ? AND action = ? AND interval_start > ?;`
err = db.db.QueryRowContext(ctx, db.db.Rebind(query), projectID[:], pb.PieceAction_GET, from).Scan(&sum)
err = db.db.QueryRow(db.db.Rebind(query), projectID[:], pb.PieceAction_GET, from).Scan(&sum)
if err == sql.ErrNoRows || sum == nil {
return 0, nil
}
@ -133,7 +133,7 @@ func (db *ProjectAccounting) GetStorageTotals(ctx context.Context, projectID uui
GROUP BY interval_start
ORDER BY interval_start DESC LIMIT 1;`
err = db.db.QueryRowContext(ctx, db.db.Rebind(query), projectID[:]).Scan(&intervalStart, &inlineSum, &remoteSum)
err = db.db.QueryRow(db.db.Rebind(query), projectID[:]).Scan(&intervalStart, &inlineSum, &remoteSum)
if err != nil || !inlineSum.Valid || !remoteSum.Valid {
return 0, 0, nil
}
@ -192,17 +192,17 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid
storageQuery := db.db.Rebind(`
SELECT
bucket_storage_tallies.interval_start,
bucket_storage_tallies.interval_start,
bucket_storage_tallies.inline,
bucket_storage_tallies.remote,
bucket_storage_tallies.object_count
FROM
bucket_storage_tallies
WHERE
bucket_storage_tallies.project_id = ? AND
bucket_storage_tallies.bucket_name = ? AND
bucket_storage_tallies.interval_start >= ? AND
bucket_storage_tallies.interval_start <= ?
FROM
bucket_storage_tallies
WHERE
bucket_storage_tallies.project_id = ? AND
bucket_storage_tallies.bucket_name = ? AND
bucket_storage_tallies.interval_start >= ? AND
bucket_storage_tallies.interval_start <= ?
ORDER BY bucket_storage_tallies.interval_start DESC
`)
@ -266,14 +266,14 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid
// only process PieceAction_GET, PieceAction_GET_AUDIT, PieceAction_GET_REPAIR actions.
func (db *ProjectAccounting) getTotalEgress(ctx context.Context, projectID uuid.UUID, since, before time.Time) (totalEgress int64, err error) {
totalEgressQuery := db.db.Rebind(fmt.Sprintf(`
SELECT
COALESCE(SUM(settled) + SUM(inline), 0)
FROM
bucket_bandwidth_rollups
WHERE
project_id = ? AND
interval_start >= ? AND
interval_start <= ? AND
SELECT
COALESCE(SUM(settled) + SUM(inline), 0)
FROM
bucket_bandwidth_rollups
WHERE
project_id = ? AND
interval_start >= ? AND
interval_start <= ? AND
action IN (%d, %d, %d);
`, pb.PieceAction_GET, pb.PieceAction_GET_AUDIT, pb.PieceAction_GET_REPAIR))

View File

@ -14,7 +14,6 @@ import (
"gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/private/dbutil"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/pgutil"
"storj.io/storj/storage"
"storj.io/storj/storage/cockroachkv/schema"
@ -43,7 +42,7 @@ func New(dbURL string) (*Client, error) {
return nil, err
}
dbutil.Configure(dbwrap.SQLDB(db), mon)
dbutil.Configure(db, mon)
// TODO: new shouldn't be taking ctx as argument
err = schema.PrepareDB(context.TODO(), db)

View File

@ -71,7 +71,7 @@ func AltNew(dbURL string) (*AlternateClient, error) {
if err != nil {
return nil, err
}
_, err = client.pgConn.ExecContext(context.TODO(), alternateSQLSetup)
_, err = client.pgConn.Exec(alternateSQLSetup)
if err != nil {
return nil, errs.Combine(err, client.Close())
}
@ -80,7 +80,7 @@ func AltNew(dbURL string) (*AlternateClient, error) {
// Close closes an AlternateClient and frees its resources.
func (altClient *AlternateClient) Close() error {
_, err := altClient.pgConn.ExecContext(context.TODO(), alternateSQLTeardown)
_, err := altClient.pgConn.Exec(alternateSQLTeardown)
return errs.Combine(err, altClient.Client.Close())
}
@ -97,7 +97,7 @@ func (opi *alternateOrderedPostgresIterator) doNextQuery(ctx context.Context) (_
if start == nil {
start = opi.opts.First
}
return opi.client.pgConn.QueryContext(ctx, alternateForwardQuery, []byte(opi.bucket), []byte(opi.opts.Prefix), []byte(start), opi.batchSize+1)
return opi.client.pgConn.Query(alternateForwardQuery, []byte(opi.bucket), []byte(opi.opts.Prefix), []byte(start), opi.batchSize+1)
}
func newAlternateOrderedPostgresIterator(ctx context.Context, altClient *AlternateClient, opts storage.IterateOptions, batchSize int) (_ *alternateOrderedPostgresIterator, err error) {

View File

@ -12,7 +12,6 @@ import (
monkit "gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/private/dbutil"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/storage"
"storj.io/storj/storage/postgreskv/schema"
)
@ -29,7 +28,7 @@ var (
// Client is the entrypoint into a postgreskv data store
type Client struct {
URL string
pgConn dbwrap.DB
pgConn *sql.DB
}
// New instantiates a new postgreskv client given db URL
@ -39,7 +38,7 @@ func New(dbURL string) (*Client, error) {
return nil, err
}
dbutil.Configure(dbwrap.SQLDB(pgConn), mon)
dbutil.Configure(pgConn, mon)
// TODO: this probably should not happen in constructor
err = schema.PrepareDB(context.TODO(), pgConn, dbURL)
@ -49,7 +48,7 @@ func New(dbURL string) (*Client, error) {
return &Client{
URL: dbURL,
pgConn: dbwrap.SQLDB(pgConn),
pgConn: pgConn,
}, nil
}
@ -70,7 +69,7 @@ func (client *Client) PutPath(ctx context.Context, bucket, key storage.Key, valu
VALUES ($1::BYTEA, $2::BYTEA, $3::BYTEA)
ON CONFLICT (bucket, fullpath) DO UPDATE SET metadata = EXCLUDED.metadata
`
_, err = client.pgConn.ExecContext(ctx, q, []byte(bucket), []byte(key), []byte(value))
_, err = client.pgConn.Exec(q, []byte(bucket), []byte(key), []byte(value))
return err
}
@ -88,7 +87,7 @@ func (client *Client) GetPath(ctx context.Context, bucket, key storage.Key) (_ s
}
q := "SELECT metadata FROM pathdata WHERE bucket = $1::BYTEA AND fullpath = $2::BYTEA"
row := client.pgConn.QueryRowContext(ctx, q, []byte(bucket), []byte(key))
row := client.pgConn.QueryRow(q, []byte(bucket), []byte(key))
var val []byte
err = row.Scan(&val)
@ -113,7 +112,7 @@ func (client *Client) DeletePath(ctx context.Context, bucket, key storage.Key) (
}
q := "DELETE FROM pathdata WHERE bucket = $1::BYTEA AND fullpath = $2::BYTEA"
result, err := client.pgConn.ExecContext(ctx, q, []byte(bucket), []byte(key))
result, err := client.pgConn.Exec(q, []byte(bucket), []byte(key))
if err != nil {
return err
}
@ -162,7 +161,7 @@ func (client *Client) GetAllPath(ctx context.Context, bucket storage.Key, keys s
ON (pd.fullpath = pk.request AND pd.bucket = $1::BYTEA)
ORDER BY pk.ord
`
rows, err := client.pgConn.QueryContext(ctx, q, []byte(bucket), pq.ByteaArray(keys.ByteSlices()))
rows, err := client.pgConn.Query(q, []byte(bucket), pq.ByteaArray(keys.ByteSlices()))
if err != nil {
return nil, errs.Wrap(err)
}
@ -263,7 +262,7 @@ func (opi *orderedPostgresIterator) doNextQuery(ctx context.Context) (_ *sql.Row
LIMIT $4
`
}
return opi.client.pgConn.QueryContext(ctx, query, []byte(opi.bucket), []byte(opi.opts.Prefix), []byte(start), opi.batchSize+1)
return opi.client.pgConn.Query(query, []byte(opi.bucket), []byte(opi.opts.Prefix), []byte(start), opi.batchSize+1)
}
func (opi *orderedPostgresIterator) Close() error {
@ -324,7 +323,7 @@ func (client *Client) CompareAndSwapPath(ctx context.Context, bucket, key storag
if oldValue == nil && newValue == nil {
q := "SELECT metadata FROM pathdata WHERE bucket = $1::BYTEA AND fullpath = $2::BYTEA"
row := client.pgConn.QueryRowContext(ctx, q, []byte(bucket), []byte(key))
row := client.pgConn.QueryRow(q, []byte(bucket), []byte(key))
var val []byte
err = row.Scan(&val)
if err == sql.ErrNoRows {
@ -342,7 +341,7 @@ func (client *Client) CompareAndSwapPath(ctx context.Context, bucket, key storag
ON CONFLICT DO NOTHING
RETURNING 1
`
row := client.pgConn.QueryRowContext(ctx, q, []byte(bucket), []byte(key), []byte(newValue))
row := client.pgConn.QueryRow(q, []byte(bucket), []byte(key), []byte(newValue))
var val []byte
err = row.Scan(&val)
if err == sql.ErrNoRows {
@ -366,7 +365,7 @@ func (client *Client) CompareAndSwapPath(ctx context.Context, bucket, key storag
)
SELECT EXISTS(SELECT 1 FROM matching_key) AS key_present, EXISTS(SELECT 1 FROM updated) AS value_updated
`
row = client.pgConn.QueryRowContext(ctx, q, []byte(bucket), []byte(key), []byte(oldValue))
row = client.pgConn.QueryRow(q, []byte(bucket), []byte(key), []byte(oldValue))
} else {
q := `
WITH matching_key AS (
@ -382,7 +381,7 @@ func (client *Client) CompareAndSwapPath(ctx context.Context, bucket, key storag
)
SELECT EXISTS(SELECT 1 FROM matching_key) AS key_present, EXISTS(SELECT 1 FROM updated) AS value_updated;
`
row = client.pgConn.QueryRowContext(ctx, q, []byte(bucket), []byte(key), []byte(oldValue), []byte(newValue))
row = client.pgConn.QueryRow(q, []byte(bucket), []byte(key), []byte(oldValue), []byte(newValue))
}
var keyPresent, valueUpdated bool

View File

@ -5,14 +5,13 @@ package postgreskv
import (
"context"
"database/sql"
"strings"
"testing"
"github.com/lib/pq"
"github.com/zeebo/errs"
"storj.io/common/testcontext"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/pgutil/pgtest"
"storj.io/storj/private/dbutil/txutil"
"storj.io/storj/storage"
@ -48,10 +47,8 @@ func TestSuite(t *testing.T) {
func TestThatMigrationActuallyHappened(t *testing.T) {
store, cleanup := newTestPostgres(t)
defer cleanup()
ctx := testcontext.New(t)
defer ctx.Cleanup()
rows, err := store.pgConn.QueryContext(ctx, `
rows, err := store.pgConn.Query(`
SELECT prosrc
FROM pg_catalog.pg_proc p,
pg_catalog.pg_namespace n
@ -92,9 +89,9 @@ func BenchmarkSuite(b *testing.B) {
testsuite.RunBenchmarks(b, store)
}
func bulkImport(ctx context.Context, db dbwrap.DB, iter storage.Iterator) error {
return txutil.WithTx(ctx, db, nil, func(ctx context.Context, txn dbwrap.Tx) (err error) {
stmt, err := txn.PrepareContext(ctx, pq.CopyIn("pathdata", "bucket", "fullpath", "metadata"))
func bulkImport(ctx context.Context, db *sql.DB, iter storage.Iterator) error {
return txutil.WithTx(ctx, db, nil, func(ctx context.Context, txn *sql.Tx) (err error) {
stmt, err := txn.Prepare(pq.CopyIn("pathdata", "bucket", "fullpath", "metadata"))
if err != nil {
return errs.New("Failed to initialize COPY FROM: %v", err)
}
@ -107,19 +104,19 @@ func bulkImport(ctx context.Context, db dbwrap.DB, iter storage.Iterator) error
var item storage.ListItem
for iter.Next(ctx, &item) {
if _, err := stmt.ExecContext(ctx, []byte(""), []byte(item.Key), []byte(item.Value)); err != nil {
if _, err := stmt.Exec([]byte(""), []byte(item.Key), []byte(item.Value)); err != nil {
return err
}
}
if _, err = stmt.ExecContext(ctx); err != nil {
if _, err = stmt.Exec(); err != nil {
return errs.New("Failed to complete COPY FROM: %v", err)
}
return nil
})
}
func bulkDeleteAll(db dbwrap.DB) error {
_, err := db.ExecContext(context.TODO(), "TRUNCATE pathdata")
func bulkDeleteAll(db *sql.DB) error {
_, err := db.Exec("TRUNCATE pathdata")
if err != nil {
return errs.New("Failed to TRUNCATE pathdata table: %v", err)
}

View File

@ -13,7 +13,6 @@ import (
"gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/private/dbutil"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/pgutil"
"storj.io/storj/private/dbutil/txutil"
"storj.io/storj/storage"
@ -31,8 +30,7 @@ var (
// Client is the entrypoint into a postgreskv2 data store
type Client struct {
db *sql.DB
pgConn dbwrap.DB
db *sql.DB
}
// New instantiates a new postgreskv2 client given db URL
@ -44,7 +42,7 @@ func New(dbURL string) (*Client, error) {
return nil, err
}
dbutil.Configure(dbwrap.SQLDB(db), mon)
dbutil.Configure(db, mon)
err = schema.PrepareDB(db)
if err != nil {
@ -56,10 +54,7 @@ func New(dbURL string) (*Client, error) {
// NewWith instantiates a new postgreskv client given db.
func NewWith(db *sql.DB) *Client {
return &Client{
db: db,
pgConn: dbwrap.SQLDB(db),
}
return &Client{db: db}
}
// Close closes the client
@ -230,7 +225,7 @@ func (client *Client) CompareAndSwap(ctx context.Context, key storage.Key, oldVa
return Error.Wrap(err)
}
return txutil.WithTx(ctx, client.pgConn, nil, func(_ context.Context, txn dbwrap.Tx) error {
return txutil.WithTx(ctx, client.db, nil, func(_ context.Context, txn *sql.Tx) error {
q := "SELECT metadata FROM pathdata WHERE fullpath = $1::BYTEA;"
row := txn.QueryRowContext(ctx, q, []byte(key))

View File

@ -16,7 +16,6 @@ import (
"gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/private/dbutil"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/private/dbutil/sqliteutil"
"storj.io/storj/private/migrate"
"storj.io/storj/storage"
@ -51,14 +50,18 @@ var _ storagenode.DB = (*DB)(nil)
type SQLDB interface {
Close() error
BeginTx(ctx context.Context, opts *sql.TxOptions) (dbwrap.Tx, error)
Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
Conn(ctx context.Context) (dbwrap.Conn, error)
DriverContext(context.Context) driver.Driver
Conn(ctx context.Context) (*sql.Conn, error)
Driver() driver.Driver
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
@ -69,7 +72,7 @@ type DBContainer interface {
}
// withTx is a helper method which executes callback in transaction scope
func withTx(ctx context.Context, db SQLDB, cb func(tx dbwrap.Tx) error) error {
func withTx(ctx context.Context, db SQLDB, cb func(tx *sql.Tx) error) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
@ -266,9 +269,9 @@ func (db *DB) openDatabase(dbName string) error {
}
mDB := db.SQLDBs[dbName]
mDB.Configure(dbwrap.SQLDB(sqlDB))
mDB.Configure(sqlDB)
dbutil.Configure(dbwrap.SQLDB(sqlDB), mon)
dbutil.Configure(sqlDB, mon)
return nil
}
@ -663,7 +666,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB,
Description: "Free Storagenodes from trash data",
Version: 13,
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, mgdb migrate.DB, tx dbwrap.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, mgdb migrate.DB, tx *sql.Tx) error {
err := os.RemoveAll(filepath.Join(db.dbDirectory, "blob/ukfu6bhbboxilvt7jrwlqk7y2tapb5d2r2tsmj2sjxvw5qaaaaaa")) // us-central1
if err != nil {
log.Sugar().Debug(err)
@ -688,7 +691,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB,
Description: "Free Storagenodes from orphaned tmp data",
Version: 14,
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, mgdb migrate.DB, tx dbwrap.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, mgdb migrate.DB, tx *sql.Tx) error {
err := os.RemoveAll(filepath.Join(db.dbDirectory, "tmp"))
if err != nil {
log.Sugar().Debug(err)
@ -829,7 +832,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB,
Description: "Vacuum info db",
Version: 22,
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx dbwrap.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
_, err := db.deprecatedInfoDB.GetDB().ExecContext(ctx, "VACUUM;")
return err
}),
@ -838,7 +841,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB,
Description: "Split into multiple sqlite databases",
Version: 23,
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx dbwrap.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
// Migrate all the tables to new database files.
if err := db.migrateToDB(ctx, BandwidthDBName, "bandwidth_usage", "bandwidth_usage_rollups"); err != nil {
return ErrDatabase.Wrap(err)
@ -875,7 +878,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB,
Description: "Drop unneeded tables in deprecatedInfoDB",
Version: 24,
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx dbwrap.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
// We drop the migrated tables from the deprecated database and VACUUM SQLite3
// in migration step 23 because if we were to keep that as part of step 22
// and an error occurred it would replay the entire migration but some tables

View File

@ -13,7 +13,6 @@ import (
"storj.io/common/pb"
"storj.io/common/storj"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/storagenode/orders"
)
@ -179,7 +178,7 @@ func (db *ordersDB) ListUnsentBySatellite(ctx context.Context) (_ map[storj.Node
func (db *ordersDB) Archive(ctx context.Context, archivedAt time.Time, requests ...orders.ArchiveRequest) (err error) {
defer mon.Task()(&ctx)(&err)
txn, err := db.BeginTx(ctx, nil)
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return ErrOrders.Wrap(err)
}
@ -187,7 +186,7 @@ func (db *ordersDB) Archive(ctx context.Context, archivedAt time.Time, requests
var notFoundErrs errs.Group
defer func() {
if err == nil {
err = txn.Commit()
err = tx.Commit()
if err == nil {
if len(notFoundErrs) > 0 {
// Return a class error to allow the caller to identify this case
@ -195,12 +194,12 @@ func (db *ordersDB) Archive(ctx context.Context, archivedAt time.Time, requests
}
}
} else {
err = errs.Combine(err, txn.Rollback())
err = errs.Combine(err, tx.Rollback())
}
}()
for _, req := range requests {
err := db.archiveOne(ctx, txn, archivedAt, req)
err := db.archiveOne(ctx, tx, archivedAt, req)
if err != nil {
if orders.OrderNotFoundError.Has(err) {
notFoundErrs.Add(err)
@ -215,10 +214,10 @@ func (db *ordersDB) Archive(ctx context.Context, archivedAt time.Time, requests
}
// archiveOne marks order as being handled.
func (db *ordersDB) archiveOne(ctx context.Context, txn dbwrap.Tx, archivedAt time.Time, req orders.ArchiveRequest) (err error) {
func (db *ordersDB) archiveOne(ctx context.Context, tx *sql.Tx, archivedAt time.Time, req orders.ArchiveRequest) (err error) {
defer mon.Task()(&ctx)(&err)
result, err := txn.ExecContext(ctx, `
result, err := tx.ExecContext(ctx, `
INSERT INTO order_archive_ (
satellite_id, serial_number,
order_limit_serialized, order_serialized,

View File

@ -33,7 +33,6 @@ type pieceSpaceUsedDB struct {
// Init creates the total pieces and total trash records if they don't already exist
func (db *pieceSpaceUsedDB) Init(ctx context.Context) (err error) {
defer mon.Task()(&ctx)(&err)
totalPiecesRow := db.QueryRowContext(ctx, `
SELECT total
FROM piece_space_used
@ -73,7 +72,6 @@ func (db *pieceSpaceUsedDB) Init(ctx context.Context) (err error) {
}
func (db *pieceSpaceUsedDB) createInitTotalPieces(ctx context.Context) (err error) {
defer mon.Task()(&ctx)(&err)
_, err = db.ExecContext(ctx, `
INSERT INTO piece_space_used (total) VALUES (0)
`)
@ -81,7 +79,6 @@ func (db *pieceSpaceUsedDB) createInitTotalPieces(ctx context.Context) (err erro
}
func (db *pieceSpaceUsedDB) createInitTotalTrash(ctx context.Context) (err error) {
defer mon.Task()(&ctx)(&err)
_, err = db.ExecContext(ctx, `
INSERT INTO piece_space_used (total, satellite_id) VALUES (0, ?)
`, trashTotalRowName)

View File

@ -5,12 +5,12 @@ package storagenodedb
import (
"context"
"database/sql"
"time"
"github.com/zeebo/errs"
"storj.io/common/storj"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/storagenode/satellites"
)
@ -47,7 +47,7 @@ func (db *satellitesDB) GetSatellite(ctx context.Context, satelliteID storj.Node
// InitiateGracefulExit updates the database to reflect the beginning of a graceful exit
func (db *satellitesDB) InitiateGracefulExit(ctx context.Context, satelliteID storj.NodeID, intitiatedAt time.Time, startingDiskUsage int64) (err error) {
defer mon.Task()(&ctx)(&err)
return ErrSatellitesDB.Wrap(withTx(ctx, db.GetDB(), func(tx dbwrap.Tx) error {
return ErrSatellitesDB.Wrap(withTx(ctx, db.GetDB(), func(tx *sql.Tx) error {
query := `INSERT OR REPLACE INTO satellites (node_id, status, added_at) VALUES (?,?, COALESCE((SELECT added_at FROM satellites WHERE node_id = ?), ?))`
_, err = tx.ExecContext(ctx, query, satelliteID, satellites.Exiting, satelliteID, intitiatedAt.UTC()) // assume intitiatedAt < time.Now()
if err != nil {
@ -78,7 +78,7 @@ func (db *satellitesDB) UpdateGracefulExit(ctx context.Context, satelliteID stor
// CompleteGracefulExit updates the database when a graceful exit is completed or failed
func (db *satellitesDB) CompleteGracefulExit(ctx context.Context, satelliteID storj.NodeID, finishedAt time.Time, exitStatus satellites.Status, completionReceipt []byte) (err error) {
defer mon.Task()(&ctx)(&err)
return ErrSatellitesDB.Wrap(withTx(ctx, db.GetDB(), func(tx dbwrap.Tx) error {
return ErrSatellitesDB.Wrap(withTx(ctx, db.GetDB(), func(tx *sql.Tx) error {
query := `UPDATE satellites SET status = ? WHERE node_id = ?`
_, err = tx.ExecContext(ctx, query, exitStatus, satelliteID)
if err != nil {

View File

@ -11,7 +11,6 @@ import (
"github.com/zeebo/errs"
"storj.io/common/storj"
"storj.io/storj/private/dbutil/dbwrap"
"storj.io/storj/storagenode/storageusage"
)
@ -34,7 +33,7 @@ func (db *storageUsageDB) Store(ctx context.Context, stamps []storageusage.Stamp
query := `INSERT OR REPLACE INTO storage_usage(satellite_id, at_rest_total, interval_start)
VALUES(?,?,?)`
return withTx(ctx, db.GetDB(), func(tx dbwrap.Tx) error {
return withTx(ctx, db.GetDB(), func(tx *sql.Tx) error {
for _, stamp := range stamps {
_, err = tx.ExecContext(ctx, query, stamp.SatelliteID, stamp.AtRestTotal, stamp.IntervalStart.UTC())