private/migrate: add ctx argument

Change-Id: I3d65912d89261386413c494c7ed1576fed4dcaf4
This commit is contained in:
Egon Elbre 2020-01-13 15:44:55 +02:00
parent 24958bd7d3
commit ff267168c5
13 changed files with 63 additions and 62 deletions

View File

@ -74,7 +74,7 @@ func cmdAPIRun(cmd *cobra.Command, args []string) (err error) {
zap.S().Warn("Failed to initialize telemetry batcher on satellite api: ", err)
}
err = db.CheckVersion()
err = db.CheckVersion(ctx)
if err != nil {
zap.S().Fatal("failed satellite database version check: ", err)
return errs.New("Error checking version for satellitedb: %+v", err)

View File

@ -236,7 +236,7 @@ func cmdRun(cmd *cobra.Command, args []string) (err error) {
zap.S().Warn("Failed to initialize telemetry batcher: ", err)
}
err = db.CheckVersion()
err = db.CheckVersion(ctx)
if err != nil {
zap.S().Fatal("failed satellite database version check: ", err)
return errs.New("Error checking version for satellitedb: %+v", err)

View File

@ -74,7 +74,7 @@ func cmdRepairerRun(cmd *cobra.Command, args []string) (err error) {
zap.S().Warn("Failed to initialize telemetry batcher on repairer: ", err)
}
err = db.CheckVersion()
err = db.CheckVersion(ctx)
if err != nil {
zap.S().Fatal("failed satellite database version check: ", err)
return errs.New("Error checking version for satellitedb: %+v", err)

View File

@ -14,12 +14,12 @@ import (
var Error = errs.Class("migrate")
// Create with a previous schema check
func Create(identifier string, db DBX) error {
func Create(ctx context.Context, identifier string, db DBX) error {
// is this necessary? it's not immediately obvious why we roll back the transaction
// when the schemas match.
justRollbackPlease := errs.Class("only used to tell WithTx to do a rollback")
err := WithTx(context.Background(), db, func(ctx context.Context, tx *sql.Tx) (err error) {
err := WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) (err error) {
schema := db.Schema()
_, err = tx.Exec(db.Rebind(`CREATE TABLE IF NOT EXISTS table_schemas (id text, schemaText text);`))

View File

@ -30,19 +30,19 @@ func TestCreate_Sqlite(t *testing.T) {
defer func() { assert.NoError(t, db.Close()) }()
// should create table
err = migrate.Create("example", &sqliteDB{db, "CREATE TABLE example_table (id text)"})
err = migrate.Create(ctx, "example", &sqliteDB{db, "CREATE TABLE example_table (id text)"})
require.NoError(t, err)
// shouldn't create a new table
err = migrate.Create("example", &sqliteDB{db, "CREATE TABLE example_table (id text)"})
err = migrate.Create(ctx, "example", &sqliteDB{db, "CREATE TABLE example_table (id text)"})
require.NoError(t, err)
// should fail, because schema changed
err = migrate.Create("example", &sqliteDB{db, "CREATE TABLE example_table (id text, version int)"})
err = migrate.Create(ctx, "example", &sqliteDB{db, "CREATE TABLE example_table (id text, version int)"})
require.Error(t, err)
// should fail, because of trying to CREATE TABLE with same name
err = migrate.Create("conflict", &sqliteDB{db, "CREATE TABLE example_table (id text, version int)"})
err = migrate.Create(ctx, "conflict", &sqliteDB{db, "CREATE TABLE example_table (id text, version int)"})
require.Error(t, err)
}
@ -74,19 +74,19 @@ func testCreateGeneric(ctx *testcontext.Context, t *testing.T, connStr string) {
defer func() { assert.NoError(t, db.Close()) }()
// should create table
err = migrate.Create("example", &postgresDB{db.DB, "CREATE TABLE example_table (id text)"})
err = migrate.Create(ctx, "example", &postgresDB{db.DB, "CREATE TABLE example_table (id text)"})
require.NoError(t, err)
// shouldn't create a new table
err = migrate.Create("example", &postgresDB{db.DB, "CREATE TABLE example_table (id text)"})
err = migrate.Create(ctx, "example", &postgresDB{db.DB, "CREATE TABLE example_table (id text)"})
require.NoError(t, err)
// should fail, because schema changed
err = migrate.Create("example", &postgresDB{db.DB, "CREATE TABLE example_table (id text, version integer)"})
err = migrate.Create(ctx, "example", &postgresDB{db.DB, "CREATE TABLE example_table (id text, version integer)"})
require.Error(t, err)
// should fail, because of trying to CREATE TABLE with same name
err = migrate.Create("conflict", &postgresDB{db.DB, "CREATE TABLE example_table (id text, version integer)"})
err = migrate.Create(ctx, "conflict", &postgresDB{db.DB, "CREATE TABLE example_table (id text, version integer)"})
require.Error(t, err)
}

View File

@ -65,7 +65,7 @@ type Step struct {
// Action is something that needs to be done
type Action interface {
Run(log *zap.Logger, db DB, tx *sql.Tx) error
Run(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) error
}
// TargetVersion returns migration with steps upto specified version
@ -101,9 +101,9 @@ func (migration *Migration) ValidateSteps() error {
}
// ValidateVersions checks that the version of the migration matches the state of the database
func (migration *Migration) ValidateVersions(log *zap.Logger) error {
func (migration *Migration) ValidateVersions(ctx context.Context, log *zap.Logger) error {
for _, step := range migration.Steps {
dbVersion, err := migration.getLatestVersion(log, step.DB)
dbVersion, err := migration.getLatestVersion(ctx, log, step.DB)
if err != nil {
return ErrValidateVersionQuery.Wrap(err)
}
@ -124,7 +124,7 @@ func (migration *Migration) ValidateVersions(log *zap.Logger) error {
}
// Run runs the migration steps
func (migration *Migration) Run(log *zap.Logger) error {
func (migration *Migration) Run(ctx context.Context, log *zap.Logger) error {
err := migration.ValidTableName()
if err != nil {
return err
@ -142,12 +142,12 @@ func (migration *Migration) Run(log *zap.Logger) error {
return Error.New("step.DB is nil for step %d", step.Version)
}
err = migration.ensureVersionTable(log, step.DB)
err = migration.ensureVersionTable(ctx, log, step.DB)
if err != nil {
return Error.New("creating version table failed: %v", err)
}
version, err := migration.getLatestVersion(log, step.DB)
version, err := migration.getLatestVersion(ctx, log, step.DB)
if err != nil {
return Error.Wrap(err)
}
@ -164,13 +164,13 @@ func (migration *Migration) Run(log *zap.Logger) error {
stepLog.Info(step.Description)
}
err = WithTx(context.Background(), step.DB, func(ctx context.Context, tx *sql.Tx) error {
err = step.Action.Run(stepLog, step.DB, tx)
err = WithTx(ctx, step.DB, func(ctx context.Context, tx *sql.Tx) error {
err = step.Action.Run(ctx, stepLog, step.DB, tx)
if err != nil {
return err
}
err = migration.addVersion(tx, step.DB, step.Version)
err = migration.addVersion(ctx, tx, step.DB, step.Version)
if err != nil {
return err
}
@ -196,8 +196,8 @@ func (migration *Migration) Run(log *zap.Logger) error {
}
// createVersionTable creates a new version table
func (migration *Migration) ensureVersionTable(log *zap.Logger, db DB) error {
err := WithTx(context.Background(), db, func(ctx context.Context, tx *sql.Tx) error {
func (migration *Migration) ensureVersionTable(ctx context.Context, log *zap.Logger, db DB) error {
err := WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(rebind(db, `CREATE TABLE IF NOT EXISTS `+migration.Table+` (version int, commited_at text)`)) //nolint:misspell
return err
})
@ -205,9 +205,9 @@ func (migration *Migration) ensureVersionTable(log *zap.Logger, db DB) error {
}
// getLatestVersion finds the latest version table
func (migration *Migration) getLatestVersion(log *zap.Logger, db DB) (int, error) {
func (migration *Migration) getLatestVersion(ctx context.Context, log *zap.Logger, db DB) (int, error) {
var version sql.NullInt64
err := WithTx(context.Background(), db, func(ctx context.Context, tx *sql.Tx) error {
err := WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow(rebind(db, `SELECT MAX(version) FROM `+migration.Table)).Scan(&version)
if err == sql.ErrNoRows || !version.Valid {
version.Int64 = -1
@ -220,7 +220,7 @@ func (migration *Migration) getLatestVersion(log *zap.Logger, db DB) (int, error
}
// addVersion adds information about a new migration
func (migration *Migration) addVersion(tx *sql.Tx, db DB, version int) error {
func (migration *Migration) addVersion(ctx context.Context, tx *sql.Tx, db DB, version int) error {
_, err := tx.Exec(rebind(db, `
INSERT INTO `+migration.Table+` (version, commited_at) VALUES (?, ?)`), //nolint:misspell
version, time.Now().String(),
@ -229,19 +229,19 @@ func (migration *Migration) addVersion(tx *sql.Tx, db DB, version int) error {
}
// CurrentVersion finds the latest version for the db
func (migration *Migration) CurrentVersion(log *zap.Logger, db DB) (int, error) {
err := migration.ensureVersionTable(log, db)
func (migration *Migration) CurrentVersion(ctx context.Context, log *zap.Logger, db DB) (int, error) {
err := migration.ensureVersionTable(ctx, log, db)
if err != nil {
return -1, Error.Wrap(err)
}
return migration.getLatestVersion(log, db)
return migration.getLatestVersion(ctx, log, db)
}
// SQL statements that are executed on the database
type SQL []string
// Run runs the SQL statements
func (sql SQL) Run(log *zap.Logger, db DB, tx *sql.Tx) (err error) {
func (sql SQL) Run(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) (err error) {
for _, query := range sql {
_, err := tx.Exec(rebind(db, query))
if err != nil {
@ -252,9 +252,9 @@ func (sql SQL) Run(log *zap.Logger, db DB, tx *sql.Tx) (err error) {
}
// Func is an arbitrary operation
type Func func(log *zap.Logger, db DB, tx *sql.Tx) error
type Func func(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) error
// Run runs the migration
func (fn Func) Run(log *zap.Logger, db DB, tx *sql.Tx) error {
return fn(log, db, tx)
func (fn Func) Run(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) error {
return fn(ctx, log, db, tx)
}

View File

@ -4,6 +4,7 @@
package migrate_test
import (
"context"
"database/sql"
"fmt"
"io/ioutil"
@ -93,21 +94,21 @@ func basicMigration(ctx *testcontext.Context, t *testing.T, db *sql.DB, testDB m
DB: testDB,
Description: "Move files",
Version: 2,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
Action: migrate.Func(func(_ context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
return os.Rename(ctx.File("alpha.txt"), ctx.File("beta.txt"))
}),
},
},
}
dbVersion, err := m.CurrentVersion(nil, testDB)
dbVersion, err := m.CurrentVersion(ctx, nil, testDB)
assert.NoError(t, err)
assert.Equal(t, dbVersion, -1)
err = m.Run(zap.NewNop())
err = m.Run(ctx, zap.NewNop())
assert.NoError(t, err)
dbVersion, err = m.CurrentVersion(nil, testDB)
dbVersion, err = m.CurrentVersion(ctx, nil, testDB)
assert.NoError(t, err)
assert.Equal(t, dbVersion, 2)
@ -120,7 +121,7 @@ func basicMigration(ctx *testcontext.Context, t *testing.T, db *sql.DB, testDB m
},
},
}
dbVersion, err = m2.CurrentVersion(nil, testDB)
dbVersion, err = m2.CurrentVersion(ctx, nil, testDB)
assert.NoError(t, err)
assert.Equal(t, dbVersion, 2)
@ -181,7 +182,7 @@ func multipleMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
DB: testDB,
Description: "Step 1",
Version: 1,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
steps++
return nil
}),
@ -190,7 +191,7 @@ func multipleMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
DB: testDB,
Description: "Step 2",
Version: 2,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
steps++
return nil
}),
@ -198,7 +199,7 @@ func multipleMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
},
}
err := m.Run(zap.NewNop())
err := m.Run(ctx, zap.NewNop())
assert.NoError(t, err)
assert.Equal(t, 2, steps)
@ -206,12 +207,12 @@ func multipleMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
DB: testDB,
Description: "Step 3",
Version: 3,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
steps++
return nil
}),
})
err = m.Run(zap.NewNop())
err = m.Run(ctx, zap.NewNop())
assert.NoError(t, err)
var version int
@ -256,14 +257,14 @@ func failedMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
DB: testDB,
Description: "Step 1",
Version: 1,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
return fmt.Errorf("migration failed")
}),
},
},
}
err := m.Run(zap.NewNop())
err := m.Run(ctx, zap.NewNop())
require.Error(t, err, "migration failed")
var version sql.NullInt64

View File

@ -49,7 +49,7 @@ type DB interface {
// CreateTables initializes the database
CreateTables(ctx context.Context) error
// CheckVersion checks the database is the correct version
CheckVersion() error
CheckVersion(ctx context.Context) error
// Close closes the database
Close() error

View File

@ -62,7 +62,7 @@ func (db *satelliteDB) CreateTables(ctx context.Context) error {
// since we merged migration steps 0-69, the current db version should never be
// less than 69 unless the migration hasn't run yet
const minDBVersion = 69
dbVersion, err := migration.CurrentVersion(db.log, db.DB)
dbVersion, err := migration.CurrentVersion(ctx, db.log, db.DB)
if err != nil {
return errs.New("error current version: %+v", err)
}
@ -72,18 +72,18 @@ func (db *satelliteDB) CreateTables(ctx context.Context) error {
)
}
return migration.Run(db.log.Named("migrate"))
return migration.Run(ctx, db.log.Named("migrate"))
default:
return migrate.Create("database", db.DB)
return migrate.Create(ctx, "database", db.DB)
}
}
// CheckVersion confirms the database is at the desired version
func (db *satelliteDB) CheckVersion() error {
func (db *satelliteDB) CheckVersion(ctx context.Context) error {
switch db.implementation {
case dbutil.Postgres, dbutil.Cockroach:
migration := db.PostgresMigration()
return migration.ValidateVersions(db.log)
return migration.ValidateVersions(ctx, db.log)
default:
return nil

View File

@ -175,7 +175,7 @@ func pgMigrateTest(t *testing.T, connStr string) {
tag := fmt.Sprintf("#%d - v%d", i, step.Version)
// run migration up to a specific version
err := migrations.TargetVersion(step.Version).Run(log.Named("migrate"))
err := migrations.TargetVersion(step.Version).Run(ctx, log.Named("migrate"))
require.NoError(t, err, tag)
// find the matching expected version

View File

@ -210,13 +210,13 @@ func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.Proces
})
err = db.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
responses, err = db.processOrdersInTx(requests, storageNodeID, time.Now(), tx.Tx)
responses, err = db.processOrdersInTx(ctx, requests, storageNodeID, time.Now(), tx.Tx)
return err
})
return responses, errs.Wrap(err)
}
func (db *ordersDB) processOrdersInTx(requests []*orders.ProcessOrderRequest, storageNodeID storj.NodeID, now time.Time, tx *sql.Tx) (responses []*orders.ProcessOrderResponse, err error) {
func (db *ordersDB) processOrdersInTx(ctx context.Context, requests []*orders.ProcessOrderRequest, storageNodeID storj.NodeID, now time.Time, tx *sql.Tx) (responses []*orders.ProcessOrderResponse, err error) {
now = now.UTC()
intervalStart := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())

View File

@ -288,7 +288,7 @@ func (db *DB) filepathFromDBName(dbName string) string {
// CreateTables creates any necessary tables.
func (db *DB) CreateTables(ctx context.Context) error {
migration := db.Migration(ctx)
return migration.Run(db.log.Named("migration"))
return migration.Run(ctx, db.log.Named("migration"))
}
// Close closes any resources.
@ -666,7 +666,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB,
Description: "Free Storagenodes from trash data",
Version: 13,
Action: migrate.Func(func(log *zap.Logger, mgdb migrate.DB, tx *sql.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, mgdb migrate.DB, tx *sql.Tx) error {
err := os.RemoveAll(filepath.Join(db.dbDirectory, "blob/ukfu6bhbboxilvt7jrwlqk7y2tapb5d2r2tsmj2sjxvw5qaaaaaa")) // us-central1
if err != nil {
log.Sugar().Debug(err)
@ -691,7 +691,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB,
Description: "Free Storagenodes from orphaned tmp data",
Version: 14,
Action: migrate.Func(func(log *zap.Logger, mgdb migrate.DB, tx *sql.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, mgdb migrate.DB, tx *sql.Tx) error {
err := os.RemoveAll(filepath.Join(db.dbDirectory, "tmp"))
if err != nil {
log.Sugar().Debug(err)
@ -832,7 +832,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB,
Description: "Vacuum info db",
Version: 22,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
_, err := db.deprecatedInfoDB.GetDB().Exec("VACUUM;")
return err
}),
@ -841,7 +841,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB,
Description: "Split into multiple sqlite databases",
Version: 23,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
// Migrate all the tables to new database files.
if err := db.migrateToDB(ctx, BandwidthDBName, "bandwidth_usage", "bandwidth_usage_rollups"); err != nil {
return ErrDatabase.Wrap(err)
@ -878,7 +878,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB,
Description: "Drop unneeded tables in deprecatedInfoDB",
Version: 24,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
// We drop the migrated tables from the deprecated database and VACUUM SQLite3
// in migration step 23 because if we were to keep that as part of step 22
// and an error occurred it would replay the entire migration but some tables

View File

@ -100,7 +100,7 @@ func TestMigrate(t *testing.T) {
tag := fmt.Sprintf("#%d - v%d", i, step.Version)
// run migration up to a specific version
err := migrations.TargetVersion(step.Version).Run(log.Named("migrate"))
err := migrations.TargetVersion(step.Version).Run(ctx, log.Named("migrate"))
require.NoError(t, err, tag)
// find the matching expected version