storagenodedb: reenable utccheck in tests
Change-Id: If7d64dd4ae58e4b656ff9122ae3195b2a5173cb3
This commit is contained in:
parent
5ed9373dba
commit
fb8e78132d
@ -6,12 +6,20 @@ package sqliteutil
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"github.com/zeebo/errs"
|
||||
)
|
||||
|
||||
// DB is the minimal interface required to perform migrations.
|
||||
type DB interface {
|
||||
Conn(ctx context.Context) (*sql.Conn, error)
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrMigrateTables is error class for MigrateTables
|
||||
ErrMigrateTables = errs.Class("migrate tables:")
|
||||
@ -20,10 +28,24 @@ var (
|
||||
ErrKeepTables = errs.Class("keep tables:")
|
||||
)
|
||||
|
||||
// getSqlite3Conn attempts to get a *sqlite3.SQLiteConn from the connection.
|
||||
func getSqlite3Conn(conn interface{}) (*sqlite3.SQLiteConn, error) {
|
||||
for {
|
||||
switch c := conn.(type) {
|
||||
case *sqlite3.SQLiteConn:
|
||||
return c, nil
|
||||
case interface{ Unwrap() driver.Conn }:
|
||||
conn = c.Unwrap()
|
||||
default:
|
||||
return nil, ErrMigrateTables.New("unable to get raw database connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MigrateTablesToDatabase copies the specified tables from srcDB into destDB.
|
||||
// All tables in destDB will be dropped other than those specified in
|
||||
// tablesToKeep.
|
||||
func MigrateTablesToDatabase(ctx context.Context, srcDB, destDB *sql.DB, tablesToKeep ...string) error {
|
||||
func MigrateTablesToDatabase(ctx context.Context, srcDB, destDB DB, tablesToKeep ...string) error {
|
||||
err := backupDBs(ctx, srcDB, destDB)
|
||||
if err != nil {
|
||||
return ErrMigrateTables.Wrap(err)
|
||||
@ -33,7 +55,7 @@ func MigrateTablesToDatabase(ctx context.Context, srcDB, destDB *sql.DB, tablesT
|
||||
return ErrMigrateTables.Wrap(KeepTables(ctx, destDB, tablesToKeep...))
|
||||
}
|
||||
|
||||
func backupDBs(ctx context.Context, srcDB, destDB *sql.DB) error {
|
||||
func backupDBs(ctx context.Context, srcDB, destDB DB) error {
|
||||
// Retrieve the raw Sqlite3 driver connections for the src and dest so that
|
||||
// we can execute the backup API for a corruption safe clone.
|
||||
srcConn, err := srcDB.Conn(ctx)
|
||||
@ -57,15 +79,15 @@ func backupDBs(ctx context.Context, srcDB, destDB *sql.DB) error {
|
||||
// The references to the driver connections are only guaranteed to be valid
|
||||
// for the life of the callback so we must do the work within both callbacks.
|
||||
err = srcConn.Raw(func(srcDriverConn interface{}) error {
|
||||
srcSqliteConn, ok := srcDriverConn.(*sqlite3.SQLiteConn)
|
||||
if !ok {
|
||||
return ErrMigrateTables.New("unable to get database driver")
|
||||
srcSqliteConn, err := getSqlite3Conn(srcDriverConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := destConn.Raw(func(destDriverConn interface{}) error {
|
||||
destSqliteConn, ok := destDriverConn.(*sqlite3.SQLiteConn)
|
||||
if !ok {
|
||||
return ErrMigrateTables.New("unable to get database driver")
|
||||
err = destConn.Raw(func(destDriverConn interface{}) error {
|
||||
destSqliteConn, err := getSqlite3Conn(destDriverConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ErrMigrateTables.Wrap(backupConns(ctx, srcSqliteConn, destSqliteConn))
|
||||
@ -138,7 +160,7 @@ func backupConns(ctx context.Context, sourceDB *sqlite3.SQLiteConn, destDB *sqli
|
||||
}
|
||||
|
||||
// KeepTables drops all the tables except the specified tables to keep.
|
||||
func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) (err error) {
|
||||
func KeepTables(ctx context.Context, db DB, tablesToKeep ...string) (err error) {
|
||||
err = dropTables(ctx, db, tablesToKeep...)
|
||||
if err != nil {
|
||||
return ErrKeepTables.Wrap(err)
|
||||
@ -156,7 +178,7 @@ func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) (err er
|
||||
}
|
||||
|
||||
// dropTables performs the table drops in a single transaction
|
||||
func dropTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) (err error) {
|
||||
func dropTables(ctx context.Context, db DB, tablesToKeep ...string) (err error) {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return ErrKeepTables.Wrap(err)
|
||||
|
@ -5,99 +5,294 @@ package utccheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
|
||||
"github.com/zeebo/errs"
|
||||
)
|
||||
|
||||
// TODO: implement this in terms of a driver rather than as a wrapper for DB.
|
||||
|
||||
// DB wraps a sql.DB and checks all of the arguments to queries to ensure they are in UTC.
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
// Connector wraps a driver.Connector with utc checks.
|
||||
type Connector struct {
|
||||
connector driver.Connector
|
||||
}
|
||||
|
||||
// New creates a new database that checks that all time arguments are UTC.
|
||||
func New(db *sql.DB) *DB {
|
||||
return &DB{DB: db}
|
||||
// WrapConnector wraps a driver.Connector with utc checks.
|
||||
func WrapConnector(connector driver.Connector) *Connector {
|
||||
return &Connector{connector: connector}
|
||||
}
|
||||
|
||||
// Close closes the database.
|
||||
func (db DB) Close() error { return db.DB.Close() }
|
||||
// Unwrap returns the underlying driver.Connector.
|
||||
func (c *Connector) Unwrap() driver.Connector { return c.connector }
|
||||
|
||||
// Query executes Query after checking all of the arguments.
|
||||
func (db DB) Query(sql string, args ...interface{}) (*sql.Rows, error) {
|
||||
// Connect returns a wrapped driver.Conn with utc checks.
|
||||
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
conn, err := c.connector.Connect(ctx)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
return WrapConn(conn), nil
|
||||
}
|
||||
|
||||
// Driver returns a wrapped driver.Driver with utc checks.
|
||||
func (c *Connector) Driver() driver.Driver {
|
||||
return WrapDriver(c.connector.Driver())
|
||||
}
|
||||
|
||||
//
|
||||
// driver
|
||||
//
|
||||
|
||||
// Driver wraps a driver.Driver with utc checks.
|
||||
type Driver struct {
|
||||
driver driver.Driver
|
||||
}
|
||||
|
||||
// WrapDriver wraps a driver.Driver with utc checks.
|
||||
func WrapDriver(driver driver.Driver) *Driver {
|
||||
return &Driver{driver: driver}
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying driver.Driver.
|
||||
func (d *Driver) Unwrap() driver.Driver { return d.driver }
|
||||
|
||||
// Open returns a wrapped driver.Conn with utc checks.
|
||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||
conn, err := d.driver.Open(name)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
return WrapConn(conn), nil
|
||||
}
|
||||
|
||||
//
|
||||
// conn
|
||||
//
|
||||
|
||||
// Conn wraps a driver.Conn with utc checks.
|
||||
type Conn struct {
|
||||
conn driver.Conn
|
||||
}
|
||||
|
||||
// WrapConn wraps a driver.Conn with utc checks.
|
||||
func WrapConn(conn driver.Conn) *Conn {
|
||||
return &Conn{conn: conn}
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying driver.Conn.
|
||||
func (c *Conn) Unwrap() driver.Conn { return c.conn }
|
||||
|
||||
// Close closes the conn.
|
||||
func (c *Conn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Ping implements driver.Pinger.
|
||||
func (c *Conn) Ping(ctx context.Context) error {
|
||||
// sqlite3 implements this
|
||||
return c.conn.(driver.Pinger).Ping(ctx)
|
||||
}
|
||||
|
||||
// Begin returns a wrapped driver.Tx with utc checks.
|
||||
func (c *Conn) Begin() (driver.Tx, error) {
|
||||
//lint:ignore SA1019 deprecated is fine. this is a wrapper.
|
||||
//nolint
|
||||
tx, err := c.conn.Begin()
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
return WrapTx(tx), nil
|
||||
}
|
||||
|
||||
// BeginTx returns a wrapped driver.Tx with utc checks.
|
||||
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
// sqlite3 implements this
|
||||
tx, err := c.conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
return WrapTx(tx), nil
|
||||
}
|
||||
|
||||
// Query checks the arguments for non-utc timestamps and returns the result.
|
||||
func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||
if err := utcCheckArgs(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.DB.Query(sql, args...)
|
||||
|
||||
// sqlite3 implements this
|
||||
//
|
||||
//lint:ignore SA1019 deprecated is fine. this is a wrapper.
|
||||
//nolint
|
||||
return c.conn.(driver.Queryer).Query(query, args)
|
||||
}
|
||||
|
||||
// QueryRow executes QueryRow after checking all of the arguments.
|
||||
func (db DB) QueryRow(sql string, args ...interface{}) *sql.Row {
|
||||
// TODO(jeff): figure out a way to return an errored *sql.Row so we can consider
|
||||
// enabling all of these checks in production.
|
||||
if err := utcCheckArgs(args); err != nil {
|
||||
panic(err)
|
||||
// QueryContext checks the arguments for non-utc timestamps and returns the result.
|
||||
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
if err := utcCheckNamedArgs(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.DB.QueryRow(sql, args...)
|
||||
|
||||
// sqlite3 implements this
|
||||
return c.conn.(driver.QueryerContext).QueryContext(ctx, query, args)
|
||||
}
|
||||
|
||||
// QueryContext executes QueryContext after checking all of the arguments.
|
||||
func (db DB) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) {
|
||||
// Exec checks the arguments for non-utc timestamps and returns the result.
|
||||
func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||
if err := utcCheckArgs(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.DB.QueryContext(ctx, sql, args...)
|
||||
|
||||
// sqlite3 implements this
|
||||
//
|
||||
//lint:ignore SA1019 deprecated is fine. this is a wrapper.
|
||||
//nolint
|
||||
return c.conn.(driver.Execer).Exec(query, args)
|
||||
}
|
||||
|
||||
// QueryRowContext executes QueryRowContext after checking all of the arguments.
|
||||
func (db DB) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *sql.Row {
|
||||
// TODO(jeff): figure out a way to return an errored *sql.Row so we can consider
|
||||
// enabling all of these checks in production.
|
||||
if err := utcCheckArgs(args); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return db.DB.QueryRowContext(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec executes Exec after checking all of the arguments.
|
||||
func (db DB) Exec(sql string, args ...interface{}) (sql.Result, error) {
|
||||
if err := utcCheckArgs(args); err != nil {
|
||||
// ExecContext checks the arguments for non-utc timestamps and returns the result.
|
||||
func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
if err := utcCheckNamedArgs(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.DB.Exec(sql, args...)
|
||||
|
||||
// sqlite3 implements this
|
||||
return c.conn.(driver.ExecerContext).ExecContext(ctx, query, args)
|
||||
}
|
||||
|
||||
// ExecContext executes ExecContext after checking all of the arguments.
|
||||
func (db DB) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) {
|
||||
if err := utcCheckArgs(args); err != nil {
|
||||
return nil, err
|
||||
// Prepare returns a wrapped driver.Stmt with utc checks.
|
||||
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
|
||||
stmt, err := c.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
return db.DB.ExecContext(ctx, sql, args...)
|
||||
return WrapStmt(stmt), nil
|
||||
}
|
||||
|
||||
// utcCheckArgs checks the arguments for time.Time values that are not in the UTC location.
|
||||
func utcCheckArgs(args []interface{}) error {
|
||||
// PrepareContext checks the arguments for non-utc timestamps and returns the result.
|
||||
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
// sqlite3 implements this
|
||||
stmt, err := c.conn.(driver.ConnPrepareContext).PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
return WrapStmt(stmt), nil
|
||||
}
|
||||
|
||||
//
|
||||
// stmt
|
||||
//
|
||||
|
||||
// Stmt wraps a driver.Stmt with utc checks.
|
||||
type Stmt struct {
|
||||
stmt driver.Stmt
|
||||
}
|
||||
|
||||
// WrapStmt wraps a driver.Stmt with utc checks.
|
||||
func WrapStmt(stmt driver.Stmt) *Stmt {
|
||||
return &Stmt{stmt: stmt}
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying driver.Stmt.
|
||||
func (s *Stmt) Unwrap() driver.Stmt { return s.stmt }
|
||||
|
||||
// Close closes the stmt.
|
||||
func (s *Stmt) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
// NumInput returns the number of inputs to the stmt.
|
||||
func (s *Stmt) NumInput() int {
|
||||
return s.stmt.NumInput()
|
||||
}
|
||||
|
||||
// Exec checks the arguments for non-utc timestamps and returns the result.
|
||||
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
if err := utcCheckArgs(args); err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
|
||||
//lint:ignore SA1019 deprecated is fine. this is a wrapper.
|
||||
//nolint
|
||||
return s.stmt.Exec(args)
|
||||
}
|
||||
|
||||
// Query checks the arguments for non-utc timestamps and returns the result.
|
||||
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
if err := utcCheckArgs(args); err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
|
||||
//lint:ignore SA1019 deprecated is fine. this is a wrapper.
|
||||
//nolint
|
||||
return s.stmt.Query(args)
|
||||
}
|
||||
|
||||
//
|
||||
// tx
|
||||
//
|
||||
|
||||
// Tx wraps a driver.Tx with utc checks.
|
||||
type Tx struct {
|
||||
tx driver.Tx
|
||||
}
|
||||
|
||||
// WrapTx wraps a driver.Tx with utc checks.
|
||||
func WrapTx(tx driver.Tx) *Tx {
|
||||
return &Tx{tx: tx}
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying driver.Tx.
|
||||
func (t *Tx) Unwrap() driver.Tx { return t.tx }
|
||||
|
||||
// Commit commits the tx.
|
||||
func (t *Tx) Commit() error {
|
||||
return t.tx.Commit()
|
||||
}
|
||||
|
||||
// Rollback rolls the tx back.
|
||||
func (t *Tx) Rollback() error {
|
||||
return t.tx.Rollback()
|
||||
}
|
||||
|
||||
//
|
||||
// helpers
|
||||
//
|
||||
|
||||
func utcCheckArg(n int, arg interface{}) error {
|
||||
var t time.Time
|
||||
var ok bool
|
||||
|
||||
switch a := arg.(type) {
|
||||
case time.Time:
|
||||
t, ok = a, true
|
||||
case *time.Time:
|
||||
if a != nil {
|
||||
t, ok = *a, true
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
} else if loc := t.Location(); loc != time.UTC {
|
||||
return errs.New("invalid timezone on argument %d: %v", n, loc)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func utcCheckNamedArgs(args []driver.NamedValue) error {
|
||||
for n, arg := range args {
|
||||
var t time.Time
|
||||
var ok bool
|
||||
|
||||
switch a := arg.(type) {
|
||||
case time.Time:
|
||||
t, ok = a, true
|
||||
case *time.Time:
|
||||
if a != nil {
|
||||
t, ok = *a, true
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if loc := t.Location(); loc != time.UTC {
|
||||
return errs.New("invalid timezone on argument %d: %v", n, loc)
|
||||
if err := utcCheckArg(n, arg.Value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func utcCheckArgs(args []driver.Value) error {
|
||||
for n, arg := range args {
|
||||
if err := utcCheckArg(n, arg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -17,7 +17,7 @@ import (
|
||||
|
||||
func TestUTCDB(t *testing.T) {
|
||||
notUTC := time.FixedZone("not utc", -1)
|
||||
db := utccheck.New(sql.OpenDB(emptyConnector{}))
|
||||
db := sql.OpenDB(utccheck.WrapConnector(emptyConnector{}))
|
||||
|
||||
{ // time.Time not in UTC
|
||||
_, err := db.Exec("", time.Now().In(notUTC))
|
||||
@ -58,9 +58,27 @@ func (emptyConnector) Driver() driver.Driver { return nil
|
||||
|
||||
type emptyConn struct{}
|
||||
|
||||
func (emptyConn) Close() error { return nil }
|
||||
|
||||
func (emptyConn) Prepare(query string) (driver.Stmt, error) { return emptyStmt{}, nil }
|
||||
func (emptyConn) Close() error { return nil }
|
||||
func (emptyConn) Begin() (driver.Tx, error) { return emptyTx{}, nil }
|
||||
func (emptyConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
return emptyStmt{}, nil
|
||||
}
|
||||
|
||||
func (emptyConn) Begin() (driver.Tx, error) { return emptyTx{}, nil }
|
||||
func (emptyConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
return emptyTx{}, nil
|
||||
}
|
||||
|
||||
func (emptyConn) Query(query string, args []driver.Value) (driver.Rows, error) { return nil, nil }
|
||||
func (emptyConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (emptyConn) Exec(query string, args []driver.Value) (driver.Result, error) { return nil, nil }
|
||||
func (emptyConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type emptyTx struct{}
|
||||
|
||||
|
@ -30,7 +30,7 @@ type bandwidthDB struct {
|
||||
usedMu sync.RWMutex
|
||||
usedSince time.Time
|
||||
|
||||
migratableDB
|
||||
dbContainerImpl
|
||||
}
|
||||
|
||||
// Add adds bandwidth usage to the table
|
||||
|
@ -41,14 +41,33 @@ var (
|
||||
|
||||
var _ storagenode.DB = (*DB)(nil)
|
||||
|
||||
// SQLDB defines an interface to allow accessing and setting an sql.DB
|
||||
// SQLDB is an abstract database so that we can mock out what database
|
||||
// implementation we're using.
|
||||
type SQLDB interface {
|
||||
Configure(sqlDB *sql.DB)
|
||||
GetDB() *sql.DB
|
||||
Close() error
|
||||
|
||||
Begin() (*sql.Tx, error)
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
|
||||
Conn(ctx context.Context) (*sql.Conn, error)
|
||||
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// DBContainer defines an interface to allow accessing and setting a SQLDB
|
||||
type DBContainer interface {
|
||||
Configure(sqlDB SQLDB)
|
||||
GetDB() SQLDB
|
||||
}
|
||||
|
||||
// withTx is a helper method which executes callback in transaction scope
|
||||
func withTx(ctx context.Context, db *sql.DB, cb func(tx *sql.Tx) error) error {
|
||||
func withTx(ctx context.Context, db SQLDB, cb func(tx *sql.Tx) error) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -70,13 +89,14 @@ type Config struct {
|
||||
Storage string
|
||||
Info string
|
||||
Info2 string
|
||||
|
||||
Pieces string
|
||||
Driver string // if unset, uses sqlite3
|
||||
Pieces string
|
||||
}
|
||||
|
||||
// DB contains access to different database tables
|
||||
type DB struct {
|
||||
log *zap.Logger
|
||||
log *zap.Logger
|
||||
config Config
|
||||
|
||||
pieces storage.Blobs
|
||||
|
||||
@ -93,7 +113,7 @@ type DB struct {
|
||||
usedSerialsDB *usedSerialsDB
|
||||
satellitesDB *satellitesDB
|
||||
|
||||
sqlDatabases map[string]SQLDB
|
||||
SQLDBs map[string]DBContainer
|
||||
}
|
||||
|
||||
// New creates a new master database for storage node
|
||||
@ -117,6 +137,8 @@ func New(log *zap.Logger, config Config) (*DB, error) {
|
||||
|
||||
db := &DB{
|
||||
log: log,
|
||||
config: config,
|
||||
|
||||
pieces: pieces,
|
||||
|
||||
dbDirectory: filepath.Dir(config.Info2),
|
||||
@ -132,7 +154,7 @@ func New(log *zap.Logger, config Config) (*DB, error) {
|
||||
usedSerialsDB: usedSerialsDB,
|
||||
satellitesDB: satellitesDB,
|
||||
|
||||
sqlDatabases: map[string]SQLDB{
|
||||
SQLDBs: map[string]DBContainer{
|
||||
DeprecatedInfoDBName: deprecatedInfoDB,
|
||||
PieceInfoDBName: v0PieceInfoDB,
|
||||
BandwidthDBName: bandwidthDB,
|
||||
@ -211,8 +233,8 @@ func (db *DB) openDatabases() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) rawDatabaseFromName(dbName string) *sql.DB {
|
||||
return db.sqlDatabases[dbName].GetDB()
|
||||
func (db *DB) rawDatabaseFromName(dbName string) SQLDB {
|
||||
return db.SQLDBs[dbName].GetDB()
|
||||
}
|
||||
|
||||
// openDatabase opens or creates a database at the specified path.
|
||||
@ -222,12 +244,17 @@ func (db *DB) openDatabase(dbName string) error {
|
||||
return ErrDatabase.Wrap(err)
|
||||
}
|
||||
|
||||
sqlDB, err := sql.Open("sqlite3", "file:"+path+"?_journal=WAL&_busy_timeout=10000")
|
||||
driver := db.config.Driver
|
||||
if driver == "" {
|
||||
driver = "sqlite3"
|
||||
}
|
||||
|
||||
sqlDB, err := sql.Open(driver, "file:"+path+"?_journal=WAL&_busy_timeout=10000")
|
||||
if err != nil {
|
||||
return ErrDatabase.Wrap(err)
|
||||
}
|
||||
|
||||
mDB := db.sqlDatabases[dbName]
|
||||
mDB := db.SQLDBs[dbName]
|
||||
mDB.Configure(sqlDB)
|
||||
|
||||
dbutil.Configure(sqlDB, mon)
|
||||
@ -259,7 +286,7 @@ func (db *DB) Close() error {
|
||||
func (db *DB) closeDatabases() error {
|
||||
var errlist errs.Group
|
||||
|
||||
for k := range db.sqlDatabases {
|
||||
for k := range db.SQLDBs {
|
||||
errlist.Add(db.closeDatabase(k))
|
||||
}
|
||||
return errlist.Err()
|
||||
@ -267,7 +294,7 @@ func (db *DB) closeDatabases() error {
|
||||
|
||||
// closeDatabase closes the specified SQLite database connections and removes them from the associated maps.
|
||||
func (db *DB) closeDatabase(dbName string) (err error) {
|
||||
mdb, ok := db.sqlDatabases[dbName]
|
||||
mdb, ok := db.SQLDBs[dbName]
|
||||
if !ok {
|
||||
return ErrDatabase.New("no database with name %s found. database was never opened or already closed.", dbName)
|
||||
}
|
||||
@ -325,8 +352,8 @@ func (db *DB) Satellites() satellites.DB {
|
||||
}
|
||||
|
||||
// RawDatabases are required for testing purposes
|
||||
func (db *DB) RawDatabases() map[string]SQLDB {
|
||||
return db.sqlDatabases
|
||||
func (db *DB) RawDatabases() map[string]DBContainer {
|
||||
return db.SQLDBs
|
||||
}
|
||||
|
||||
// migrateToDB is a helper method that performs the migration from the
|
||||
|
@ -3,35 +3,31 @@
|
||||
|
||||
package storagenodedb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// migratableDB fulfills the migrate.DB interface and the SQLDB interface
|
||||
type migratableDB struct {
|
||||
*sql.DB
|
||||
// dbContainerImpl fulfills the migrate.DB interface and the SQLDB interface
|
||||
type dbContainerImpl struct {
|
||||
SQLDB
|
||||
}
|
||||
|
||||
// Schema returns schema
|
||||
// These are implemented because the migrate.DB interface requires them.
|
||||
// Maybe in the future we should untangle those.
|
||||
func (db *migratableDB) Schema() string {
|
||||
func (db *dbContainerImpl) Schema() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Rebind rebind parameters
|
||||
// These are implemented because the migrate.DB interface requires them.
|
||||
// Maybe in the future we should untangle those.
|
||||
func (db *migratableDB) Rebind(s string) string {
|
||||
func (db *dbContainerImpl) Rebind(s string) string {
|
||||
return s
|
||||
}
|
||||
|
||||
// Configure sets the underlining SQLDB connection.
|
||||
func (db *migratableDB) Configure(sqlDB *sql.DB) {
|
||||
db.DB = sqlDB
|
||||
func (db *dbContainerImpl) Configure(sqlDB SQLDB) {
|
||||
db.SQLDB = sqlDB
|
||||
}
|
||||
|
||||
// GetDB returns the raw *sql.DB underlying this migratableDB
|
||||
func (db *migratableDB) GetDB() *sql.DB {
|
||||
return db.DB
|
||||
// GetDB returns the raw *sql.DB underlying this dbContainerImpl
|
||||
func (db *dbContainerImpl) GetDB() SQLDB {
|
||||
return db.SQLDB
|
||||
}
|
@ -8,5 +8,5 @@ const DeprecatedInfoDBName = "info"
|
||||
|
||||
// deprecatedInfoDB represents the database that contains the original legacy sqlite3 database.
|
||||
type deprecatedInfoDB struct {
|
||||
migratableDB
|
||||
dbContainerImpl
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ import (
|
||||
// insertNewData will insert any NewData from the MultiDBState into the
|
||||
// appropriate rawDB. This prepares the rawDB for the test comparing schema and
|
||||
// data.
|
||||
func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.SQLDB) error {
|
||||
func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error {
|
||||
for dbName, dbState := range mdbs.DBStates {
|
||||
if dbState.NewData == "" {
|
||||
continue
|
||||
@ -42,7 +42,7 @@ func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.
|
||||
|
||||
// getSchemas queries the schema of each rawDB and returns a map of each rawDB's
|
||||
// schema keyed by dbName
|
||||
func getSchemas(rawDBs map[string]storagenodedb.SQLDB) (map[string]*dbschema.Schema, error) {
|
||||
func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) {
|
||||
schemas := make(map[string]*dbschema.Schema)
|
||||
for dbName, rawDB := range rawDBs {
|
||||
schema, err := sqliteutil.QuerySchema(rawDB.GetDB())
|
||||
@ -60,7 +60,7 @@ func getSchemas(rawDBs map[string]storagenodedb.SQLDB) (map[string]*dbschema.Sch
|
||||
|
||||
// getSchemas queries the data of each rawDB and returns a map of each rawDB's
|
||||
// data keyed by dbName
|
||||
func getData(rawDBs map[string]storagenodedb.SQLDB, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) {
|
||||
func getData(rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) {
|
||||
data := make(map[string]*dbschema.Data)
|
||||
for dbName, rawDB := range rawDBs {
|
||||
datum, err := sqliteutil.QueryData(rawDB.GetDB(), schemas[dbName])
|
||||
|
@ -23,7 +23,7 @@ var ErrOrders = errs.Class("ordersdb error")
|
||||
const OrdersDBName = "orders"
|
||||
|
||||
type ordersDB struct {
|
||||
migratableDB
|
||||
dbContainerImpl
|
||||
}
|
||||
|
||||
// Enqueue inserts order to the unsent list
|
||||
|
@ -20,7 +20,7 @@ var ErrPieceExpiration = errs.Class("piece expiration error")
|
||||
const PieceExpirationDBName = "piece_expiration"
|
||||
|
||||
type pieceExpirationDB struct {
|
||||
migratableDB
|
||||
dbContainerImpl
|
||||
}
|
||||
|
||||
// GetExpired gets piece IDs that expire or have expired before the given time
|
||||
|
@ -25,7 +25,7 @@ var ErrPieceInfo = errs.Class("v0pieceinfodb error")
|
||||
const PieceInfoDBName = "pieceinfo"
|
||||
|
||||
type v0PieceInfoDB struct {
|
||||
migratableDB
|
||||
dbContainerImpl
|
||||
}
|
||||
|
||||
// Add inserts piece information into the database.
|
||||
|
@ -19,7 +19,7 @@ var ErrPieceSpaceUsed = errs.Class("piece space used error")
|
||||
const PieceSpaceUsedDBName = "piece_spaced_used"
|
||||
|
||||
type pieceSpaceUsedDB struct {
|
||||
migratableDB
|
||||
dbContainerImpl
|
||||
}
|
||||
|
||||
// Init creates the one total record if it doesn't already exist
|
||||
|
@ -21,7 +21,7 @@ const ReputationDBName = "reputation"
|
||||
|
||||
// reputation works with node reputation DB
|
||||
type reputationDB struct {
|
||||
migratableDB
|
||||
dbContainerImpl
|
||||
}
|
||||
|
||||
// Store inserts or updates reputation stats into the db.
|
||||
@ -29,7 +29,7 @@ func (db *reputationDB) Store(ctx context.Context, stats reputation.Stats) (err
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
query := `INSERT OR REPLACE INTO reputation (
|
||||
satellite_id,
|
||||
satellite_id,
|
||||
uptime_success_count,
|
||||
uptime_total_count,
|
||||
uptime_reputation_alpha,
|
||||
@ -120,7 +120,7 @@ func (db *reputationDB) Get(ctx context.Context, satelliteID storj.NodeID) (_ *r
|
||||
func (db *reputationDB) All(ctx context.Context) (_ []reputation.Stats, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
query := `SELECT satellite_id,
|
||||
query := `SELECT satellite_id,
|
||||
uptime_success_count,
|
||||
uptime_total_count,
|
||||
uptime_reputation_alpha,
|
||||
|
@ -22,7 +22,7 @@ const SatellitesDBName = "satellites"
|
||||
|
||||
// reputation works with node reputation DB
|
||||
type satellitesDB struct {
|
||||
migratableDB
|
||||
dbContainerImpl
|
||||
}
|
||||
|
||||
// GetSatellite retrieves that satellite by ID
|
||||
|
@ -6,16 +6,23 @@ package storagenodedbtest
|
||||
// This package should be referenced only in test files!
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"storj.io/storj/private/dbutil/utccheck"
|
||||
"storj.io/storj/private/testcontext"
|
||||
"storj.io/storj/storagenode"
|
||||
"storj.io/storj/storagenode/storagenodedb"
|
||||
)
|
||||
|
||||
func init() {
|
||||
sql.Register("sqlite3+utccheck", utccheck.WrapDriver(&sqlite3.SQLiteDriver{}))
|
||||
}
|
||||
|
||||
// Run method will iterate over all supported databases. Will establish
|
||||
// connection and will create tables for each DB.
|
||||
func Run(t *testing.T, test func(t *testing.T, db storagenode.DB)) {
|
||||
@ -31,6 +38,7 @@ func Run(t *testing.T, test func(t *testing.T, db storagenode.DB)) {
|
||||
Storage: storageDir,
|
||||
Info: filepath.Join(storageDir, "piecestore.db"),
|
||||
Info2: filepath.Join(storageDir, "info.db"),
|
||||
Driver: "sqlite3+utccheck",
|
||||
Pieces: storageDir,
|
||||
}
|
||||
|
||||
|
@ -19,7 +19,7 @@ const StorageUsageDBName = "storage_usage"
|
||||
|
||||
// storageUsageDB storage usage DB
|
||||
type storageUsageDB struct {
|
||||
migratableDB
|
||||
dbContainerImpl
|
||||
}
|
||||
|
||||
// Store stores storage usage stamps to db replacing conflicting entries
|
||||
@ -30,7 +30,7 @@ func (db *storageUsageDB) Store(ctx context.Context, stamps []storageusage.Stamp
|
||||
return nil
|
||||
}
|
||||
|
||||
query := `INSERT OR REPLACE INTO storage_usage(satellite_id, at_rest_total, interval_start)
|
||||
query := `INSERT OR REPLACE INTO storage_usage(satellite_id, at_rest_total, interval_start)
|
||||
VALUES(?,?,?)`
|
||||
|
||||
return withTx(ctx, db.GetDB(), func(tx *sql.Tx) error {
|
||||
@ -95,7 +95,7 @@ func (db *storageUsageDB) GetDaily(ctx context.Context, satelliteID storj.NodeID
|
||||
func (db *storageUsageDB) GetDailyTotal(ctx context.Context, from, to time.Time) (_ []storageusage.Stamp, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
query := `SELECT SUM(at_rest_total), interval_start
|
||||
query := `SELECT SUM(at_rest_total), interval_start
|
||||
FROM storage_usage
|
||||
WHERE ? <= interval_start AND interval_start <= ?
|
||||
GROUP BY DATE(interval_start)
|
||||
@ -134,7 +134,7 @@ func (db *storageUsageDB) Summary(ctx context.Context, from, to time.Time) (_ fl
|
||||
defer mon.Task()(&ctx, from, to)(&err)
|
||||
var summary sql.NullFloat64
|
||||
|
||||
query := `SELECT SUM(at_rest_total)
|
||||
query := `SELECT SUM(at_rest_total)
|
||||
FROM storage_usage
|
||||
WHERE ? <= interval_start AND interval_start <= ?`
|
||||
|
||||
@ -147,7 +147,7 @@ func (db *storageUsageDB) SatelliteSummary(ctx context.Context, satelliteID stor
|
||||
defer mon.Task()(&ctx, satelliteID, from, to)(&err)
|
||||
var summary sql.NullFloat64
|
||||
|
||||
query := `SELECT SUM(at_rest_total)
|
||||
query := `SELECT SUM(at_rest_total)
|
||||
FROM storage_usage
|
||||
WHERE satellite_id = ?
|
||||
AND ? <= interval_start AND interval_start <= ?`
|
||||
|
@ -20,7 +20,7 @@ var ErrUsedSerials = errs.Class("usedserialsdb error")
|
||||
const UsedSerialsDBName = "used_serial"
|
||||
|
||||
type usedSerialsDB struct {
|
||||
migratableDB
|
||||
dbContainerImpl
|
||||
}
|
||||
|
||||
// Add adds a serial to the database.
|
||||
|
Loading…
Reference in New Issue
Block a user