From fb8e78132d13b975d39ba21e0e6ff59c5ce6a729 Mon Sep 17 00:00:00 2001 From: Jeff Wendling Date: Wed, 13 Nov 2019 09:49:22 -0700 Subject: [PATCH] storagenodedb: reenable utccheck in tests Change-Id: If7d64dd4ae58e4b656ff9122ae3195b2a5173cb3 --- private/dbutil/sqliteutil/migrator.go | 44 ++- private/dbutil/utccheck/db.go | 319 ++++++++++++++---- private/dbutil/utccheck/db_test.go | 24 +- storagenode/storagenodedb/bandwidthdb.go | 2 +- storagenode/storagenodedb/database.go | 61 +++- .../{migratableDB.go => db_container.go} | 24 +- storagenode/storagenodedb/deprecatedinfo.go | 2 +- storagenode/storagenodedb/migrations_test.go | 6 +- storagenode/storagenodedb/orders.go | 2 +- storagenode/storagenodedb/pieceexpiration.go | 2 +- storagenode/storagenodedb/pieceinfo.go | 2 +- storagenode/storagenodedb/piecespaceused.go | 2 +- storagenode/storagenodedb/reputation.go | 6 +- storagenode/storagenodedb/satellites.go | 2 +- .../storagenodedb/storagenodedbtest/run.go | 8 + storagenode/storagenodedb/storageusage.go | 10 +- storagenode/storagenodedb/usedserials.go | 2 +- 17 files changed, 392 insertions(+), 126 deletions(-) rename storagenode/storagenodedb/{migratableDB.go => db_container.go} (54%) diff --git a/private/dbutil/sqliteutil/migrator.go b/private/dbutil/sqliteutil/migrator.go index b4a47de06..7d8027833 100644 --- a/private/dbutil/sqliteutil/migrator.go +++ b/private/dbutil/sqliteutil/migrator.go @@ -6,12 +6,20 @@ package sqliteutil import ( "context" "database/sql" + "database/sql/driver" "fmt" "github.com/mattn/go-sqlite3" "github.com/zeebo/errs" ) +// DB is the minimal interface required to perform migrations. +type DB interface { + Conn(ctx context.Context) (*sql.Conn, error) + Exec(query string, args ...interface{}) (sql.Result, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + var ( // ErrMigrateTables is error class for MigrateTables ErrMigrateTables = errs.Class("migrate tables:") @@ -20,10 +28,24 @@ var ( ErrKeepTables = errs.Class("keep tables:") ) +// getSqlite3Conn attempts to get a *sqlite3.SQLiteConn from the connection. +func getSqlite3Conn(conn interface{}) (*sqlite3.SQLiteConn, error) { + for { + switch c := conn.(type) { + case *sqlite3.SQLiteConn: + return c, nil + case interface{ Unwrap() driver.Conn }: + conn = c.Unwrap() + default: + return nil, ErrMigrateTables.New("unable to get raw database connection") + } + } +} + // MigrateTablesToDatabase copies the specified tables from srcDB into destDB. // All tables in destDB will be dropped other than those specified in // tablesToKeep. -func MigrateTablesToDatabase(ctx context.Context, srcDB, destDB *sql.DB, tablesToKeep ...string) error { +func MigrateTablesToDatabase(ctx context.Context, srcDB, destDB DB, tablesToKeep ...string) error { err := backupDBs(ctx, srcDB, destDB) if err != nil { return ErrMigrateTables.Wrap(err) @@ -33,7 +55,7 @@ func MigrateTablesToDatabase(ctx context.Context, srcDB, destDB *sql.DB, tablesT return ErrMigrateTables.Wrap(KeepTables(ctx, destDB, tablesToKeep...)) } -func backupDBs(ctx context.Context, srcDB, destDB *sql.DB) error { +func backupDBs(ctx context.Context, srcDB, destDB DB) error { // Retrieve the raw Sqlite3 driver connections for the src and dest so that // we can execute the backup API for a corruption safe clone. srcConn, err := srcDB.Conn(ctx) @@ -57,15 +79,15 @@ func backupDBs(ctx context.Context, srcDB, destDB *sql.DB) error { // The references to the driver connections are only guaranteed to be valid // for the life of the callback so we must do the work within both callbacks. err = srcConn.Raw(func(srcDriverConn interface{}) error { - srcSqliteConn, ok := srcDriverConn.(*sqlite3.SQLiteConn) - if !ok { - return ErrMigrateTables.New("unable to get database driver") + srcSqliteConn, err := getSqlite3Conn(srcDriverConn) + if err != nil { + return err } - err := destConn.Raw(func(destDriverConn interface{}) error { - destSqliteConn, ok := destDriverConn.(*sqlite3.SQLiteConn) - if !ok { - return ErrMigrateTables.New("unable to get database driver") + err = destConn.Raw(func(destDriverConn interface{}) error { + destSqliteConn, err := getSqlite3Conn(destDriverConn) + if err != nil { + return err } return ErrMigrateTables.Wrap(backupConns(ctx, srcSqliteConn, destSqliteConn)) @@ -138,7 +160,7 @@ func backupConns(ctx context.Context, sourceDB *sqlite3.SQLiteConn, destDB *sqli } // KeepTables drops all the tables except the specified tables to keep. -func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) (err error) { +func KeepTables(ctx context.Context, db DB, tablesToKeep ...string) (err error) { err = dropTables(ctx, db, tablesToKeep...) if err != nil { return ErrKeepTables.Wrap(err) @@ -156,7 +178,7 @@ func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) (err er } // dropTables performs the table drops in a single transaction -func dropTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) (err error) { +func dropTables(ctx context.Context, db DB, tablesToKeep ...string) (err error) { tx, err := db.BeginTx(ctx, nil) if err != nil { return ErrKeepTables.Wrap(err) diff --git a/private/dbutil/utccheck/db.go b/private/dbutil/utccheck/db.go index ffb523fc9..bee99373d 100644 --- a/private/dbutil/utccheck/db.go +++ b/private/dbutil/utccheck/db.go @@ -5,99 +5,294 @@ package utccheck import ( "context" - "database/sql" + "database/sql/driver" "time" "github.com/zeebo/errs" ) -// TODO: implement this in terms of a driver rather than as a wrapper for DB. - -// DB wraps a sql.DB and checks all of the arguments to queries to ensure they are in UTC. -type DB struct { - *sql.DB +// Connector wraps a driver.Connector with utc checks. +type Connector struct { + connector driver.Connector } -// New creates a new database that checks that all time arguments are UTC. -func New(db *sql.DB) *DB { - return &DB{DB: db} +// WrapConnector wraps a driver.Connector with utc checks. +func WrapConnector(connector driver.Connector) *Connector { + return &Connector{connector: connector} } -// Close closes the database. -func (db DB) Close() error { return db.DB.Close() } +// Unwrap returns the underlying driver.Connector. +func (c *Connector) Unwrap() driver.Connector { return c.connector } -// Query executes Query after checking all of the arguments. -func (db DB) Query(sql string, args ...interface{}) (*sql.Rows, error) { +// Connect returns a wrapped driver.Conn with utc checks. +func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { + conn, err := c.connector.Connect(ctx) + if err != nil { + return nil, errs.Wrap(err) + } + return WrapConn(conn), nil +} + +// Driver returns a wrapped driver.Driver with utc checks. +func (c *Connector) Driver() driver.Driver { + return WrapDriver(c.connector.Driver()) +} + +// +// driver +// + +// Driver wraps a driver.Driver with utc checks. +type Driver struct { + driver driver.Driver +} + +// WrapDriver wraps a driver.Driver with utc checks. +func WrapDriver(driver driver.Driver) *Driver { + return &Driver{driver: driver} +} + +// Unwrap returns the underlying driver.Driver. +func (d *Driver) Unwrap() driver.Driver { return d.driver } + +// Open returns a wrapped driver.Conn with utc checks. +func (d *Driver) Open(name string) (driver.Conn, error) { + conn, err := d.driver.Open(name) + if err != nil { + return nil, errs.Wrap(err) + } + return WrapConn(conn), nil +} + +// +// conn +// + +// Conn wraps a driver.Conn with utc checks. +type Conn struct { + conn driver.Conn +} + +// WrapConn wraps a driver.Conn with utc checks. +func WrapConn(conn driver.Conn) *Conn { + return &Conn{conn: conn} +} + +// Unwrap returns the underlying driver.Conn. +func (c *Conn) Unwrap() driver.Conn { return c.conn } + +// Close closes the conn. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// Ping implements driver.Pinger. +func (c *Conn) Ping(ctx context.Context) error { + // sqlite3 implements this + return c.conn.(driver.Pinger).Ping(ctx) +} + +// Begin returns a wrapped driver.Tx with utc checks. +func (c *Conn) Begin() (driver.Tx, error) { + //lint:ignore SA1019 deprecated is fine. this is a wrapper. + //nolint + tx, err := c.conn.Begin() + if err != nil { + return nil, errs.Wrap(err) + } + return WrapTx(tx), nil +} + +// BeginTx returns a wrapped driver.Tx with utc checks. +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + // sqlite3 implements this + tx, err := c.conn.(driver.ConnBeginTx).BeginTx(ctx, opts) + if err != nil { + return nil, errs.Wrap(err) + } + return WrapTx(tx), nil +} + +// Query checks the arguments for non-utc timestamps and returns the result. +func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) { if err := utcCheckArgs(args); err != nil { return nil, err } - return db.DB.Query(sql, args...) + + // sqlite3 implements this + // + //lint:ignore SA1019 deprecated is fine. this is a wrapper. + //nolint + return c.conn.(driver.Queryer).Query(query, args) } -// QueryRow executes QueryRow after checking all of the arguments. -func (db DB) QueryRow(sql string, args ...interface{}) *sql.Row { - // TODO(jeff): figure out a way to return an errored *sql.Row so we can consider - // enabling all of these checks in production. - if err := utcCheckArgs(args); err != nil { - panic(err) +// QueryContext checks the arguments for non-utc timestamps and returns the result. +func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if err := utcCheckNamedArgs(args); err != nil { + return nil, err } - return db.DB.QueryRow(sql, args...) + + // sqlite3 implements this + return c.conn.(driver.QueryerContext).QueryContext(ctx, query, args) } -// QueryContext executes QueryContext after checking all of the arguments. -func (db DB) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { +// Exec checks the arguments for non-utc timestamps and returns the result. +func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { if err := utcCheckArgs(args); err != nil { return nil, err } - return db.DB.QueryContext(ctx, sql, args...) + + // sqlite3 implements this + // + //lint:ignore SA1019 deprecated is fine. this is a wrapper. + //nolint + return c.conn.(driver.Execer).Exec(query, args) } -// QueryRowContext executes QueryRowContext after checking all of the arguments. -func (db DB) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *sql.Row { - // TODO(jeff): figure out a way to return an errored *sql.Row so we can consider - // enabling all of these checks in production. - if err := utcCheckArgs(args); err != nil { - panic(err) - } - return db.DB.QueryRowContext(ctx, sql, args...) -} - -// Exec executes Exec after checking all of the arguments. -func (db DB) Exec(sql string, args ...interface{}) (sql.Result, error) { - if err := utcCheckArgs(args); err != nil { +// ExecContext checks the arguments for non-utc timestamps and returns the result. +func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if err := utcCheckNamedArgs(args); err != nil { return nil, err } - return db.DB.Exec(sql, args...) + + // sqlite3 implements this + return c.conn.(driver.ExecerContext).ExecContext(ctx, query, args) } -// ExecContext executes ExecContext after checking all of the arguments. -func (db DB) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { - if err := utcCheckArgs(args); err != nil { - return nil, err +// Prepare returns a wrapped driver.Stmt with utc checks. +func (c *Conn) Prepare(query string) (driver.Stmt, error) { + stmt, err := c.conn.Prepare(query) + if err != nil { + return nil, errs.Wrap(err) } - return db.DB.ExecContext(ctx, sql, args...) + return WrapStmt(stmt), nil } -// utcCheckArgs checks the arguments for time.Time values that are not in the UTC location. -func utcCheckArgs(args []interface{}) error { +// PrepareContext checks the arguments for non-utc timestamps and returns the result. +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + // sqlite3 implements this + stmt, err := c.conn.(driver.ConnPrepareContext).PrepareContext(ctx, query) + if err != nil { + return nil, errs.Wrap(err) + } + return WrapStmt(stmt), nil +} + +// +// stmt +// + +// Stmt wraps a driver.Stmt with utc checks. +type Stmt struct { + stmt driver.Stmt +} + +// WrapStmt wraps a driver.Stmt with utc checks. +func WrapStmt(stmt driver.Stmt) *Stmt { + return &Stmt{stmt: stmt} +} + +// Unwrap returns the underlying driver.Stmt. +func (s *Stmt) Unwrap() driver.Stmt { return s.stmt } + +// Close closes the stmt. +func (s *Stmt) Close() error { + return s.stmt.Close() +} + +// NumInput returns the number of inputs to the stmt. +func (s *Stmt) NumInput() int { + return s.stmt.NumInput() +} + +// Exec checks the arguments for non-utc timestamps and returns the result. +func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { + if err := utcCheckArgs(args); err != nil { + return nil, errs.Wrap(err) + } + + //lint:ignore SA1019 deprecated is fine. this is a wrapper. + //nolint + return s.stmt.Exec(args) +} + +// Query checks the arguments for non-utc timestamps and returns the result. +func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { + if err := utcCheckArgs(args); err != nil { + return nil, errs.Wrap(err) + } + + //lint:ignore SA1019 deprecated is fine. this is a wrapper. + //nolint + return s.stmt.Query(args) +} + +// +// tx +// + +// Tx wraps a driver.Tx with utc checks. +type Tx struct { + tx driver.Tx +} + +// WrapTx wraps a driver.Tx with utc checks. +func WrapTx(tx driver.Tx) *Tx { + return &Tx{tx: tx} +} + +// Unwrap returns the underlying driver.Tx. +func (t *Tx) Unwrap() driver.Tx { return t.tx } + +// Commit commits the tx. +func (t *Tx) Commit() error { + return t.tx.Commit() +} + +// Rollback rolls the tx back. +func (t *Tx) Rollback() error { + return t.tx.Rollback() +} + +// +// helpers +// + +func utcCheckArg(n int, arg interface{}) error { + var t time.Time + var ok bool + + switch a := arg.(type) { + case time.Time: + t, ok = a, true + case *time.Time: + if a != nil { + t, ok = *a, true + } + } + + if !ok { + return nil + } else if loc := t.Location(); loc != time.UTC { + return errs.New("invalid timezone on argument %d: %v", n, loc) + } else { + return nil + } +} + +func utcCheckNamedArgs(args []driver.NamedValue) error { for n, arg := range args { - var t time.Time - var ok bool - - switch a := arg.(type) { - case time.Time: - t, ok = a, true - case *time.Time: - if a != nil { - t, ok = *a, true - } - } - if !ok { - continue - } - - if loc := t.Location(); loc != time.UTC { - return errs.New("invalid timezone on argument %d: %v", n, loc) + if err := utcCheckArg(n, arg.Value); err != nil { + return err + } + } + return nil +} + +func utcCheckArgs(args []driver.Value) error { + for n, arg := range args { + if err := utcCheckArg(n, arg); err != nil { + return err } } return nil diff --git a/private/dbutil/utccheck/db_test.go b/private/dbutil/utccheck/db_test.go index 3d0adf154..6f24b73f9 100644 --- a/private/dbutil/utccheck/db_test.go +++ b/private/dbutil/utccheck/db_test.go @@ -17,7 +17,7 @@ import ( func TestUTCDB(t *testing.T) { notUTC := time.FixedZone("not utc", -1) - db := utccheck.New(sql.OpenDB(emptyConnector{})) + db := sql.OpenDB(utccheck.WrapConnector(emptyConnector{})) { // time.Time not in UTC _, err := db.Exec("", time.Now().In(notUTC)) @@ -58,9 +58,27 @@ func (emptyConnector) Driver() driver.Driver { return nil type emptyConn struct{} +func (emptyConn) Close() error { return nil } + func (emptyConn) Prepare(query string) (driver.Stmt, error) { return emptyStmt{}, nil } -func (emptyConn) Close() error { return nil } -func (emptyConn) Begin() (driver.Tx, error) { return emptyTx{}, nil } +func (emptyConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return emptyStmt{}, nil +} + +func (emptyConn) Begin() (driver.Tx, error) { return emptyTx{}, nil } +func (emptyConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return emptyTx{}, nil +} + +func (emptyConn) Query(query string, args []driver.Value) (driver.Rows, error) { return nil, nil } +func (emptyConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return nil, nil +} + +func (emptyConn) Exec(query string, args []driver.Value) (driver.Result, error) { return nil, nil } +func (emptyConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return nil, nil +} type emptyTx struct{} diff --git a/storagenode/storagenodedb/bandwidthdb.go b/storagenode/storagenodedb/bandwidthdb.go index 4dbf51630..cfafeb8d6 100644 --- a/storagenode/storagenodedb/bandwidthdb.go +++ b/storagenode/storagenodedb/bandwidthdb.go @@ -30,7 +30,7 @@ type bandwidthDB struct { usedMu sync.RWMutex usedSince time.Time - migratableDB + dbContainerImpl } // Add adds bandwidth usage to the table diff --git a/storagenode/storagenodedb/database.go b/storagenode/storagenodedb/database.go index f5c05555f..5275d3be8 100644 --- a/storagenode/storagenodedb/database.go +++ b/storagenode/storagenodedb/database.go @@ -41,14 +41,33 @@ var ( var _ storagenode.DB = (*DB)(nil) -// SQLDB defines an interface to allow accessing and setting an sql.DB +// SQLDB is an abstract database so that we can mock out what database +// implementation we're using. type SQLDB interface { - Configure(sqlDB *sql.DB) - GetDB() *sql.DB + Close() error + + Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) + + Conn(ctx context.Context) (*sql.Conn, error) + + Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +// DBContainer defines an interface to allow accessing and setting a SQLDB +type DBContainer interface { + Configure(sqlDB SQLDB) + GetDB() SQLDB } // withTx is a helper method which executes callback in transaction scope -func withTx(ctx context.Context, db *sql.DB, cb func(tx *sql.Tx) error) error { +func withTx(ctx context.Context, db SQLDB, cb func(tx *sql.Tx) error) error { tx, err := db.Begin() if err != nil { return err @@ -70,13 +89,14 @@ type Config struct { Storage string Info string Info2 string - - Pieces string + Driver string // if unset, uses sqlite3 + Pieces string } // DB contains access to different database tables type DB struct { - log *zap.Logger + log *zap.Logger + config Config pieces storage.Blobs @@ -93,7 +113,7 @@ type DB struct { usedSerialsDB *usedSerialsDB satellitesDB *satellitesDB - sqlDatabases map[string]SQLDB + SQLDBs map[string]DBContainer } // New creates a new master database for storage node @@ -117,6 +137,8 @@ func New(log *zap.Logger, config Config) (*DB, error) { db := &DB{ log: log, + config: config, + pieces: pieces, dbDirectory: filepath.Dir(config.Info2), @@ -132,7 +154,7 @@ func New(log *zap.Logger, config Config) (*DB, error) { usedSerialsDB: usedSerialsDB, satellitesDB: satellitesDB, - sqlDatabases: map[string]SQLDB{ + SQLDBs: map[string]DBContainer{ DeprecatedInfoDBName: deprecatedInfoDB, PieceInfoDBName: v0PieceInfoDB, BandwidthDBName: bandwidthDB, @@ -211,8 +233,8 @@ func (db *DB) openDatabases() error { return nil } -func (db *DB) rawDatabaseFromName(dbName string) *sql.DB { - return db.sqlDatabases[dbName].GetDB() +func (db *DB) rawDatabaseFromName(dbName string) SQLDB { + return db.SQLDBs[dbName].GetDB() } // openDatabase opens or creates a database at the specified path. @@ -222,12 +244,17 @@ func (db *DB) openDatabase(dbName string) error { return ErrDatabase.Wrap(err) } - sqlDB, err := sql.Open("sqlite3", "file:"+path+"?_journal=WAL&_busy_timeout=10000") + driver := db.config.Driver + if driver == "" { + driver = "sqlite3" + } + + sqlDB, err := sql.Open(driver, "file:"+path+"?_journal=WAL&_busy_timeout=10000") if err != nil { return ErrDatabase.Wrap(err) } - mDB := db.sqlDatabases[dbName] + mDB := db.SQLDBs[dbName] mDB.Configure(sqlDB) dbutil.Configure(sqlDB, mon) @@ -259,7 +286,7 @@ func (db *DB) Close() error { func (db *DB) closeDatabases() error { var errlist errs.Group - for k := range db.sqlDatabases { + for k := range db.SQLDBs { errlist.Add(db.closeDatabase(k)) } return errlist.Err() @@ -267,7 +294,7 @@ func (db *DB) closeDatabases() error { // closeDatabase closes the specified SQLite database connections and removes them from the associated maps. func (db *DB) closeDatabase(dbName string) (err error) { - mdb, ok := db.sqlDatabases[dbName] + mdb, ok := db.SQLDBs[dbName] if !ok { return ErrDatabase.New("no database with name %s found. database was never opened or already closed.", dbName) } @@ -325,8 +352,8 @@ func (db *DB) Satellites() satellites.DB { } // RawDatabases are required for testing purposes -func (db *DB) RawDatabases() map[string]SQLDB { - return db.sqlDatabases +func (db *DB) RawDatabases() map[string]DBContainer { + return db.SQLDBs } // migrateToDB is a helper method that performs the migration from the diff --git a/storagenode/storagenodedb/migratableDB.go b/storagenode/storagenodedb/db_container.go similarity index 54% rename from storagenode/storagenodedb/migratableDB.go rename to storagenode/storagenodedb/db_container.go index a54932f7c..8000c1cce 100644 --- a/storagenode/storagenodedb/migratableDB.go +++ b/storagenode/storagenodedb/db_container.go @@ -3,35 +3,31 @@ package storagenodedb -import ( - "database/sql" -) - -// migratableDB fulfills the migrate.DB interface and the SQLDB interface -type migratableDB struct { - *sql.DB +// dbContainerImpl fulfills the migrate.DB interface and the SQLDB interface +type dbContainerImpl struct { + SQLDB } // Schema returns schema // These are implemented because the migrate.DB interface requires them. // Maybe in the future we should untangle those. -func (db *migratableDB) Schema() string { +func (db *dbContainerImpl) Schema() string { return "" } // Rebind rebind parameters // These are implemented because the migrate.DB interface requires them. // Maybe in the future we should untangle those. -func (db *migratableDB) Rebind(s string) string { +func (db *dbContainerImpl) Rebind(s string) string { return s } // Configure sets the underlining SQLDB connection. -func (db *migratableDB) Configure(sqlDB *sql.DB) { - db.DB = sqlDB +func (db *dbContainerImpl) Configure(sqlDB SQLDB) { + db.SQLDB = sqlDB } -// GetDB returns the raw *sql.DB underlying this migratableDB -func (db *migratableDB) GetDB() *sql.DB { - return db.DB +// GetDB returns the raw *sql.DB underlying this dbContainerImpl +func (db *dbContainerImpl) GetDB() SQLDB { + return db.SQLDB } diff --git a/storagenode/storagenodedb/deprecatedinfo.go b/storagenode/storagenodedb/deprecatedinfo.go index 84f4c9b10..871c692b7 100644 --- a/storagenode/storagenodedb/deprecatedinfo.go +++ b/storagenode/storagenodedb/deprecatedinfo.go @@ -8,5 +8,5 @@ const DeprecatedInfoDBName = "info" // deprecatedInfoDB represents the database that contains the original legacy sqlite3 database. type deprecatedInfoDB struct { - migratableDB + dbContainerImpl } diff --git a/storagenode/storagenodedb/migrations_test.go b/storagenode/storagenodedb/migrations_test.go index 85397db19..16e2add63 100644 --- a/storagenode/storagenodedb/migrations_test.go +++ b/storagenode/storagenodedb/migrations_test.go @@ -22,7 +22,7 @@ import ( // insertNewData will insert any NewData from the MultiDBState into the // appropriate rawDB. This prepares the rawDB for the test comparing schema and // data. -func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.SQLDB) error { +func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error { for dbName, dbState := range mdbs.DBStates { if dbState.NewData == "" { continue @@ -42,7 +42,7 @@ func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb. // getSchemas queries the schema of each rawDB and returns a map of each rawDB's // schema keyed by dbName -func getSchemas(rawDBs map[string]storagenodedb.SQLDB) (map[string]*dbschema.Schema, error) { +func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) { schemas := make(map[string]*dbschema.Schema) for dbName, rawDB := range rawDBs { schema, err := sqliteutil.QuerySchema(rawDB.GetDB()) @@ -60,7 +60,7 @@ func getSchemas(rawDBs map[string]storagenodedb.SQLDB) (map[string]*dbschema.Sch // getSchemas queries the data of each rawDB and returns a map of each rawDB's // data keyed by dbName -func getData(rawDBs map[string]storagenodedb.SQLDB, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) { +func getData(rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) { data := make(map[string]*dbschema.Data) for dbName, rawDB := range rawDBs { datum, err := sqliteutil.QueryData(rawDB.GetDB(), schemas[dbName]) diff --git a/storagenode/storagenodedb/orders.go b/storagenode/storagenodedb/orders.go index 8e357a746..9b3e255b3 100644 --- a/storagenode/storagenodedb/orders.go +++ b/storagenode/storagenodedb/orders.go @@ -23,7 +23,7 @@ var ErrOrders = errs.Class("ordersdb error") const OrdersDBName = "orders" type ordersDB struct { - migratableDB + dbContainerImpl } // Enqueue inserts order to the unsent list diff --git a/storagenode/storagenodedb/pieceexpiration.go b/storagenode/storagenodedb/pieceexpiration.go index 2df6936e9..0a93d5f39 100644 --- a/storagenode/storagenodedb/pieceexpiration.go +++ b/storagenode/storagenodedb/pieceexpiration.go @@ -20,7 +20,7 @@ var ErrPieceExpiration = errs.Class("piece expiration error") const PieceExpirationDBName = "piece_expiration" type pieceExpirationDB struct { - migratableDB + dbContainerImpl } // GetExpired gets piece IDs that expire or have expired before the given time diff --git a/storagenode/storagenodedb/pieceinfo.go b/storagenode/storagenodedb/pieceinfo.go index 3e1dadaf5..fc378b03b 100644 --- a/storagenode/storagenodedb/pieceinfo.go +++ b/storagenode/storagenodedb/pieceinfo.go @@ -25,7 +25,7 @@ var ErrPieceInfo = errs.Class("v0pieceinfodb error") const PieceInfoDBName = "pieceinfo" type v0PieceInfoDB struct { - migratableDB + dbContainerImpl } // Add inserts piece information into the database. diff --git a/storagenode/storagenodedb/piecespaceused.go b/storagenode/storagenodedb/piecespaceused.go index e547cb43e..189d0a0ce 100644 --- a/storagenode/storagenodedb/piecespaceused.go +++ b/storagenode/storagenodedb/piecespaceused.go @@ -19,7 +19,7 @@ var ErrPieceSpaceUsed = errs.Class("piece space used error") const PieceSpaceUsedDBName = "piece_spaced_used" type pieceSpaceUsedDB struct { - migratableDB + dbContainerImpl } // Init creates the one total record if it doesn't already exist diff --git a/storagenode/storagenodedb/reputation.go b/storagenode/storagenodedb/reputation.go index 4585b4953..1d811814b 100644 --- a/storagenode/storagenodedb/reputation.go +++ b/storagenode/storagenodedb/reputation.go @@ -21,7 +21,7 @@ const ReputationDBName = "reputation" // reputation works with node reputation DB type reputationDB struct { - migratableDB + dbContainerImpl } // Store inserts or updates reputation stats into the db. @@ -29,7 +29,7 @@ func (db *reputationDB) Store(ctx context.Context, stats reputation.Stats) (err defer mon.Task()(&ctx)(&err) query := `INSERT OR REPLACE INTO reputation ( - satellite_id, + satellite_id, uptime_success_count, uptime_total_count, uptime_reputation_alpha, @@ -120,7 +120,7 @@ func (db *reputationDB) Get(ctx context.Context, satelliteID storj.NodeID) (_ *r func (db *reputationDB) All(ctx context.Context) (_ []reputation.Stats, err error) { defer mon.Task()(&ctx)(&err) - query := `SELECT satellite_id, + query := `SELECT satellite_id, uptime_success_count, uptime_total_count, uptime_reputation_alpha, diff --git a/storagenode/storagenodedb/satellites.go b/storagenode/storagenodedb/satellites.go index 7c62e9010..0f03467b8 100644 --- a/storagenode/storagenodedb/satellites.go +++ b/storagenode/storagenodedb/satellites.go @@ -22,7 +22,7 @@ const SatellitesDBName = "satellites" // reputation works with node reputation DB type satellitesDB struct { - migratableDB + dbContainerImpl } // GetSatellite retrieves that satellite by ID diff --git a/storagenode/storagenodedb/storagenodedbtest/run.go b/storagenode/storagenodedb/storagenodedbtest/run.go index c3da7b34c..50dd0743e 100644 --- a/storagenode/storagenodedb/storagenodedbtest/run.go +++ b/storagenode/storagenodedb/storagenodedbtest/run.go @@ -6,16 +6,23 @@ package storagenodedbtest // This package should be referenced only in test files! import ( + "database/sql" "path/filepath" "testing" + "github.com/mattn/go-sqlite3" "go.uber.org/zap/zaptest" + "storj.io/storj/private/dbutil/utccheck" "storj.io/storj/private/testcontext" "storj.io/storj/storagenode" "storj.io/storj/storagenode/storagenodedb" ) +func init() { + sql.Register("sqlite3+utccheck", utccheck.WrapDriver(&sqlite3.SQLiteDriver{})) +} + // Run method will iterate over all supported databases. Will establish // connection and will create tables for each DB. func Run(t *testing.T, test func(t *testing.T, db storagenode.DB)) { @@ -31,6 +38,7 @@ func Run(t *testing.T, test func(t *testing.T, db storagenode.DB)) { Storage: storageDir, Info: filepath.Join(storageDir, "piecestore.db"), Info2: filepath.Join(storageDir, "info.db"), + Driver: "sqlite3+utccheck", Pieces: storageDir, } diff --git a/storagenode/storagenodedb/storageusage.go b/storagenode/storagenodedb/storageusage.go index 7216f2a21..c9a759572 100644 --- a/storagenode/storagenodedb/storageusage.go +++ b/storagenode/storagenodedb/storageusage.go @@ -19,7 +19,7 @@ const StorageUsageDBName = "storage_usage" // storageUsageDB storage usage DB type storageUsageDB struct { - migratableDB + dbContainerImpl } // Store stores storage usage stamps to db replacing conflicting entries @@ -30,7 +30,7 @@ func (db *storageUsageDB) Store(ctx context.Context, stamps []storageusage.Stamp return nil } - query := `INSERT OR REPLACE INTO storage_usage(satellite_id, at_rest_total, interval_start) + query := `INSERT OR REPLACE INTO storage_usage(satellite_id, at_rest_total, interval_start) VALUES(?,?,?)` return withTx(ctx, db.GetDB(), func(tx *sql.Tx) error { @@ -95,7 +95,7 @@ func (db *storageUsageDB) GetDaily(ctx context.Context, satelliteID storj.NodeID func (db *storageUsageDB) GetDailyTotal(ctx context.Context, from, to time.Time) (_ []storageusage.Stamp, err error) { defer mon.Task()(&ctx)(&err) - query := `SELECT SUM(at_rest_total), interval_start + query := `SELECT SUM(at_rest_total), interval_start FROM storage_usage WHERE ? <= interval_start AND interval_start <= ? GROUP BY DATE(interval_start) @@ -134,7 +134,7 @@ func (db *storageUsageDB) Summary(ctx context.Context, from, to time.Time) (_ fl defer mon.Task()(&ctx, from, to)(&err) var summary sql.NullFloat64 - query := `SELECT SUM(at_rest_total) + query := `SELECT SUM(at_rest_total) FROM storage_usage WHERE ? <= interval_start AND interval_start <= ?` @@ -147,7 +147,7 @@ func (db *storageUsageDB) SatelliteSummary(ctx context.Context, satelliteID stor defer mon.Task()(&ctx, satelliteID, from, to)(&err) var summary sql.NullFloat64 - query := `SELECT SUM(at_rest_total) + query := `SELECT SUM(at_rest_total) FROM storage_usage WHERE satellite_id = ? AND ? <= interval_start AND interval_start <= ?` diff --git a/storagenode/storagenodedb/usedserials.go b/storagenode/storagenodedb/usedserials.go index 4ddef7d81..1efed8b4a 100644 --- a/storagenode/storagenodedb/usedserials.go +++ b/storagenode/storagenodedb/usedserials.go @@ -20,7 +20,7 @@ var ErrUsedSerials = errs.Class("usedserialsdb error") const UsedSerialsDBName = "used_serial" type usedSerialsDB struct { - migratableDB + dbContainerImpl } // Add adds a serial to the database.