private/migrate: add ctx argument

Change-Id: I3d65912d89261386413c494c7ed1576fed4dcaf4
This commit is contained in:
Egon Elbre 2020-01-13 15:44:55 +02:00
parent 24958bd7d3
commit ff267168c5
13 changed files with 63 additions and 62 deletions

View File

@ -74,7 +74,7 @@ func cmdAPIRun(cmd *cobra.Command, args []string) (err error) {
zap.S().Warn("Failed to initialize telemetry batcher on satellite api: ", err) zap.S().Warn("Failed to initialize telemetry batcher on satellite api: ", err)
} }
err = db.CheckVersion() err = db.CheckVersion(ctx)
if err != nil { if err != nil {
zap.S().Fatal("failed satellite database version check: ", err) zap.S().Fatal("failed satellite database version check: ", err)
return errs.New("Error checking version for satellitedb: %+v", err) return errs.New("Error checking version for satellitedb: %+v", err)

View File

@ -236,7 +236,7 @@ func cmdRun(cmd *cobra.Command, args []string) (err error) {
zap.S().Warn("Failed to initialize telemetry batcher: ", err) zap.S().Warn("Failed to initialize telemetry batcher: ", err)
} }
err = db.CheckVersion() err = db.CheckVersion(ctx)
if err != nil { if err != nil {
zap.S().Fatal("failed satellite database version check: ", err) zap.S().Fatal("failed satellite database version check: ", err)
return errs.New("Error checking version for satellitedb: %+v", err) return errs.New("Error checking version for satellitedb: %+v", err)

View File

@ -74,7 +74,7 @@ func cmdRepairerRun(cmd *cobra.Command, args []string) (err error) {
zap.S().Warn("Failed to initialize telemetry batcher on repairer: ", err) zap.S().Warn("Failed to initialize telemetry batcher on repairer: ", err)
} }
err = db.CheckVersion() err = db.CheckVersion(ctx)
if err != nil { if err != nil {
zap.S().Fatal("failed satellite database version check: ", err) zap.S().Fatal("failed satellite database version check: ", err)
return errs.New("Error checking version for satellitedb: %+v", err) return errs.New("Error checking version for satellitedb: %+v", err)

View File

@ -14,12 +14,12 @@ import (
var Error = errs.Class("migrate") var Error = errs.Class("migrate")
// Create with a previous schema check // Create with a previous schema check
func Create(identifier string, db DBX) error { func Create(ctx context.Context, identifier string, db DBX) error {
// is this necessary? it's not immediately obvious why we roll back the transaction // is this necessary? it's not immediately obvious why we roll back the transaction
// when the schemas match. // when the schemas match.
justRollbackPlease := errs.Class("only used to tell WithTx to do a rollback") justRollbackPlease := errs.Class("only used to tell WithTx to do a rollback")
err := WithTx(context.Background(), db, func(ctx context.Context, tx *sql.Tx) (err error) { err := WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) (err error) {
schema := db.Schema() schema := db.Schema()
_, err = tx.Exec(db.Rebind(`CREATE TABLE IF NOT EXISTS table_schemas (id text, schemaText text);`)) _, err = tx.Exec(db.Rebind(`CREATE TABLE IF NOT EXISTS table_schemas (id text, schemaText text);`))

View File

@ -30,19 +30,19 @@ func TestCreate_Sqlite(t *testing.T) {
defer func() { assert.NoError(t, db.Close()) }() defer func() { assert.NoError(t, db.Close()) }()
// should create table // should create table
err = migrate.Create("example", &sqliteDB{db, "CREATE TABLE example_table (id text)"}) err = migrate.Create(ctx, "example", &sqliteDB{db, "CREATE TABLE example_table (id text)"})
require.NoError(t, err) require.NoError(t, err)
// shouldn't create a new table // shouldn't create a new table
err = migrate.Create("example", &sqliteDB{db, "CREATE TABLE example_table (id text)"}) err = migrate.Create(ctx, "example", &sqliteDB{db, "CREATE TABLE example_table (id text)"})
require.NoError(t, err) require.NoError(t, err)
// should fail, because schema changed // should fail, because schema changed
err = migrate.Create("example", &sqliteDB{db, "CREATE TABLE example_table (id text, version int)"}) err = migrate.Create(ctx, "example", &sqliteDB{db, "CREATE TABLE example_table (id text, version int)"})
require.Error(t, err) require.Error(t, err)
// should fail, because of trying to CREATE TABLE with same name // should fail, because of trying to CREATE TABLE with same name
err = migrate.Create("conflict", &sqliteDB{db, "CREATE TABLE example_table (id text, version int)"}) err = migrate.Create(ctx, "conflict", &sqliteDB{db, "CREATE TABLE example_table (id text, version int)"})
require.Error(t, err) require.Error(t, err)
} }
@ -74,19 +74,19 @@ func testCreateGeneric(ctx *testcontext.Context, t *testing.T, connStr string) {
defer func() { assert.NoError(t, db.Close()) }() defer func() { assert.NoError(t, db.Close()) }()
// should create table // should create table
err = migrate.Create("example", &postgresDB{db.DB, "CREATE TABLE example_table (id text)"}) err = migrate.Create(ctx, "example", &postgresDB{db.DB, "CREATE TABLE example_table (id text)"})
require.NoError(t, err) require.NoError(t, err)
// shouldn't create a new table // shouldn't create a new table
err = migrate.Create("example", &postgresDB{db.DB, "CREATE TABLE example_table (id text)"}) err = migrate.Create(ctx, "example", &postgresDB{db.DB, "CREATE TABLE example_table (id text)"})
require.NoError(t, err) require.NoError(t, err)
// should fail, because schema changed // should fail, because schema changed
err = migrate.Create("example", &postgresDB{db.DB, "CREATE TABLE example_table (id text, version integer)"}) err = migrate.Create(ctx, "example", &postgresDB{db.DB, "CREATE TABLE example_table (id text, version integer)"})
require.Error(t, err) require.Error(t, err)
// should fail, because of trying to CREATE TABLE with same name // should fail, because of trying to CREATE TABLE with same name
err = migrate.Create("conflict", &postgresDB{db.DB, "CREATE TABLE example_table (id text, version integer)"}) err = migrate.Create(ctx, "conflict", &postgresDB{db.DB, "CREATE TABLE example_table (id text, version integer)"})
require.Error(t, err) require.Error(t, err)
} }

View File

@ -65,7 +65,7 @@ type Step struct {
// Action is something that needs to be done // Action is something that needs to be done
type Action interface { type Action interface {
Run(log *zap.Logger, db DB, tx *sql.Tx) error Run(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) error
} }
// TargetVersion returns migration with steps upto specified version // TargetVersion returns migration with steps upto specified version
@ -101,9 +101,9 @@ func (migration *Migration) ValidateSteps() error {
} }
// ValidateVersions checks that the version of the migration matches the state of the database // ValidateVersions checks that the version of the migration matches the state of the database
func (migration *Migration) ValidateVersions(log *zap.Logger) error { func (migration *Migration) ValidateVersions(ctx context.Context, log *zap.Logger) error {
for _, step := range migration.Steps { for _, step := range migration.Steps {
dbVersion, err := migration.getLatestVersion(log, step.DB) dbVersion, err := migration.getLatestVersion(ctx, log, step.DB)
if err != nil { if err != nil {
return ErrValidateVersionQuery.Wrap(err) return ErrValidateVersionQuery.Wrap(err)
} }
@ -124,7 +124,7 @@ func (migration *Migration) ValidateVersions(log *zap.Logger) error {
} }
// Run runs the migration steps // Run runs the migration steps
func (migration *Migration) Run(log *zap.Logger) error { func (migration *Migration) Run(ctx context.Context, log *zap.Logger) error {
err := migration.ValidTableName() err := migration.ValidTableName()
if err != nil { if err != nil {
return err return err
@ -142,12 +142,12 @@ func (migration *Migration) Run(log *zap.Logger) error {
return Error.New("step.DB is nil for step %d", step.Version) return Error.New("step.DB is nil for step %d", step.Version)
} }
err = migration.ensureVersionTable(log, step.DB) err = migration.ensureVersionTable(ctx, log, step.DB)
if err != nil { if err != nil {
return Error.New("creating version table failed: %v", err) return Error.New("creating version table failed: %v", err)
} }
version, err := migration.getLatestVersion(log, step.DB) version, err := migration.getLatestVersion(ctx, log, step.DB)
if err != nil { if err != nil {
return Error.Wrap(err) return Error.Wrap(err)
} }
@ -164,13 +164,13 @@ func (migration *Migration) Run(log *zap.Logger) error {
stepLog.Info(step.Description) stepLog.Info(step.Description)
} }
err = WithTx(context.Background(), step.DB, func(ctx context.Context, tx *sql.Tx) error { err = WithTx(ctx, step.DB, func(ctx context.Context, tx *sql.Tx) error {
err = step.Action.Run(stepLog, step.DB, tx) err = step.Action.Run(ctx, stepLog, step.DB, tx)
if err != nil { if err != nil {
return err return err
} }
err = migration.addVersion(tx, step.DB, step.Version) err = migration.addVersion(ctx, tx, step.DB, step.Version)
if err != nil { if err != nil {
return err return err
} }
@ -196,8 +196,8 @@ func (migration *Migration) Run(log *zap.Logger) error {
} }
// createVersionTable creates a new version table // createVersionTable creates a new version table
func (migration *Migration) ensureVersionTable(log *zap.Logger, db DB) error { func (migration *Migration) ensureVersionTable(ctx context.Context, log *zap.Logger, db DB) error {
err := WithTx(context.Background(), db, func(ctx context.Context, tx *sql.Tx) error { err := WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(rebind(db, `CREATE TABLE IF NOT EXISTS `+migration.Table+` (version int, commited_at text)`)) //nolint:misspell _, err := tx.Exec(rebind(db, `CREATE TABLE IF NOT EXISTS `+migration.Table+` (version int, commited_at text)`)) //nolint:misspell
return err return err
}) })
@ -205,9 +205,9 @@ func (migration *Migration) ensureVersionTable(log *zap.Logger, db DB) error {
} }
// getLatestVersion finds the latest version table // getLatestVersion finds the latest version table
func (migration *Migration) getLatestVersion(log *zap.Logger, db DB) (int, error) { func (migration *Migration) getLatestVersion(ctx context.Context, log *zap.Logger, db DB) (int, error) {
var version sql.NullInt64 var version sql.NullInt64
err := WithTx(context.Background(), db, func(ctx context.Context, tx *sql.Tx) error { err := WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow(rebind(db, `SELECT MAX(version) FROM `+migration.Table)).Scan(&version) err := tx.QueryRow(rebind(db, `SELECT MAX(version) FROM `+migration.Table)).Scan(&version)
if err == sql.ErrNoRows || !version.Valid { if err == sql.ErrNoRows || !version.Valid {
version.Int64 = -1 version.Int64 = -1
@ -220,7 +220,7 @@ func (migration *Migration) getLatestVersion(log *zap.Logger, db DB) (int, error
} }
// addVersion adds information about a new migration // addVersion adds information about a new migration
func (migration *Migration) addVersion(tx *sql.Tx, db DB, version int) error { func (migration *Migration) addVersion(ctx context.Context, tx *sql.Tx, db DB, version int) error {
_, err := tx.Exec(rebind(db, ` _, err := tx.Exec(rebind(db, `
INSERT INTO `+migration.Table+` (version, commited_at) VALUES (?, ?)`), //nolint:misspell INSERT INTO `+migration.Table+` (version, commited_at) VALUES (?, ?)`), //nolint:misspell
version, time.Now().String(), version, time.Now().String(),
@ -229,19 +229,19 @@ func (migration *Migration) addVersion(tx *sql.Tx, db DB, version int) error {
} }
// CurrentVersion finds the latest version for the db // CurrentVersion finds the latest version for the db
func (migration *Migration) CurrentVersion(log *zap.Logger, db DB) (int, error) { func (migration *Migration) CurrentVersion(ctx context.Context, log *zap.Logger, db DB) (int, error) {
err := migration.ensureVersionTable(log, db) err := migration.ensureVersionTable(ctx, log, db)
if err != nil { if err != nil {
return -1, Error.Wrap(err) return -1, Error.Wrap(err)
} }
return migration.getLatestVersion(log, db) return migration.getLatestVersion(ctx, log, db)
} }
// SQL statements that are executed on the database // SQL statements that are executed on the database
type SQL []string type SQL []string
// Run runs the SQL statements // Run runs the SQL statements
func (sql SQL) Run(log *zap.Logger, db DB, tx *sql.Tx) (err error) { func (sql SQL) Run(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) (err error) {
for _, query := range sql { for _, query := range sql {
_, err := tx.Exec(rebind(db, query)) _, err := tx.Exec(rebind(db, query))
if err != nil { if err != nil {
@ -252,9 +252,9 @@ func (sql SQL) Run(log *zap.Logger, db DB, tx *sql.Tx) (err error) {
} }
// Func is an arbitrary operation // Func is an arbitrary operation
type Func func(log *zap.Logger, db DB, tx *sql.Tx) error type Func func(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) error
// Run runs the migration // Run runs the migration
func (fn Func) Run(log *zap.Logger, db DB, tx *sql.Tx) error { func (fn Func) Run(ctx context.Context, log *zap.Logger, db DB, tx *sql.Tx) error {
return fn(log, db, tx) return fn(ctx, log, db, tx)
} }

View File

@ -4,6 +4,7 @@
package migrate_test package migrate_test
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -93,21 +94,21 @@ func basicMigration(ctx *testcontext.Context, t *testing.T, db *sql.DB, testDB m
DB: testDB, DB: testDB,
Description: "Move files", Description: "Move files",
Version: 2, Version: 2,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error { Action: migrate.Func(func(_ context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
return os.Rename(ctx.File("alpha.txt"), ctx.File("beta.txt")) return os.Rename(ctx.File("alpha.txt"), ctx.File("beta.txt"))
}), }),
}, },
}, },
} }
dbVersion, err := m.CurrentVersion(nil, testDB) dbVersion, err := m.CurrentVersion(ctx, nil, testDB)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, dbVersion, -1) assert.Equal(t, dbVersion, -1)
err = m.Run(zap.NewNop()) err = m.Run(ctx, zap.NewNop())
assert.NoError(t, err) assert.NoError(t, err)
dbVersion, err = m.CurrentVersion(nil, testDB) dbVersion, err = m.CurrentVersion(ctx, nil, testDB)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, dbVersion, 2) assert.Equal(t, dbVersion, 2)
@ -120,7 +121,7 @@ func basicMigration(ctx *testcontext.Context, t *testing.T, db *sql.DB, testDB m
}, },
}, },
} }
dbVersion, err = m2.CurrentVersion(nil, testDB) dbVersion, err = m2.CurrentVersion(ctx, nil, testDB)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, dbVersion, 2) assert.Equal(t, dbVersion, 2)
@ -181,7 +182,7 @@ func multipleMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
DB: testDB, DB: testDB,
Description: "Step 1", Description: "Step 1",
Version: 1, Version: 1,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error { Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
steps++ steps++
return nil return nil
}), }),
@ -190,7 +191,7 @@ func multipleMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
DB: testDB, DB: testDB,
Description: "Step 2", Description: "Step 2",
Version: 2, Version: 2,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error { Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
steps++ steps++
return nil return nil
}), }),
@ -198,7 +199,7 @@ func multipleMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
}, },
} }
err := m.Run(zap.NewNop()) err := m.Run(ctx, zap.NewNop())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 2, steps) assert.Equal(t, 2, steps)
@ -206,12 +207,12 @@ func multipleMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
DB: testDB, DB: testDB,
Description: "Step 3", Description: "Step 3",
Version: 3, Version: 3,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error { Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
steps++ steps++
return nil return nil
}), }),
}) })
err = m.Run(zap.NewNop()) err = m.Run(ctx, zap.NewNop())
assert.NoError(t, err) assert.NoError(t, err)
var version int var version int
@ -256,14 +257,14 @@ func failedMigration(t *testing.T, db *sql.DB, testDB migrate.DB) {
DB: testDB, DB: testDB,
Description: "Step 1", Description: "Step 1",
Version: 1, Version: 1,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error { Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
return fmt.Errorf("migration failed") return fmt.Errorf("migration failed")
}), }),
}, },
}, },
} }
err := m.Run(zap.NewNop()) err := m.Run(ctx, zap.NewNop())
require.Error(t, err, "migration failed") require.Error(t, err, "migration failed")
var version sql.NullInt64 var version sql.NullInt64

View File

@ -49,7 +49,7 @@ type DB interface {
// CreateTables initializes the database // CreateTables initializes the database
CreateTables(ctx context.Context) error CreateTables(ctx context.Context) error
// CheckVersion checks the database is the correct version // CheckVersion checks the database is the correct version
CheckVersion() error CheckVersion(ctx context.Context) error
// Close closes the database // Close closes the database
Close() error Close() error

View File

@ -62,7 +62,7 @@ func (db *satelliteDB) CreateTables(ctx context.Context) error {
// since we merged migration steps 0-69, the current db version should never be // since we merged migration steps 0-69, the current db version should never be
// less than 69 unless the migration hasn't run yet // less than 69 unless the migration hasn't run yet
const minDBVersion = 69 const minDBVersion = 69
dbVersion, err := migration.CurrentVersion(db.log, db.DB) dbVersion, err := migration.CurrentVersion(ctx, db.log, db.DB)
if err != nil { if err != nil {
return errs.New("error current version: %+v", err) return errs.New("error current version: %+v", err)
} }
@ -72,18 +72,18 @@ func (db *satelliteDB) CreateTables(ctx context.Context) error {
) )
} }
return migration.Run(db.log.Named("migrate")) return migration.Run(ctx, db.log.Named("migrate"))
default: default:
return migrate.Create("database", db.DB) return migrate.Create(ctx, "database", db.DB)
} }
} }
// CheckVersion confirms the database is at the desired version // CheckVersion confirms the database is at the desired version
func (db *satelliteDB) CheckVersion() error { func (db *satelliteDB) CheckVersion(ctx context.Context) error {
switch db.implementation { switch db.implementation {
case dbutil.Postgres, dbutil.Cockroach: case dbutil.Postgres, dbutil.Cockroach:
migration := db.PostgresMigration() migration := db.PostgresMigration()
return migration.ValidateVersions(db.log) return migration.ValidateVersions(ctx, db.log)
default: default:
return nil return nil

View File

@ -175,7 +175,7 @@ func pgMigrateTest(t *testing.T, connStr string) {
tag := fmt.Sprintf("#%d - v%d", i, step.Version) tag := fmt.Sprintf("#%d - v%d", i, step.Version)
// run migration up to a specific version // run migration up to a specific version
err := migrations.TargetVersion(step.Version).Run(log.Named("migrate")) err := migrations.TargetVersion(step.Version).Run(ctx, log.Named("migrate"))
require.NoError(t, err, tag) require.NoError(t, err, tag)
// find the matching expected version // find the matching expected version

View File

@ -210,13 +210,13 @@ func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.Proces
}) })
err = db.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error { err = db.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
responses, err = db.processOrdersInTx(requests, storageNodeID, time.Now(), tx.Tx) responses, err = db.processOrdersInTx(ctx, requests, storageNodeID, time.Now(), tx.Tx)
return err return err
}) })
return responses, errs.Wrap(err) return responses, errs.Wrap(err)
} }
func (db *ordersDB) processOrdersInTx(requests []*orders.ProcessOrderRequest, storageNodeID storj.NodeID, now time.Time, tx *sql.Tx) (responses []*orders.ProcessOrderResponse, err error) { func (db *ordersDB) processOrdersInTx(ctx context.Context, requests []*orders.ProcessOrderRequest, storageNodeID storj.NodeID, now time.Time, tx *sql.Tx) (responses []*orders.ProcessOrderResponse, err error) {
now = now.UTC() now = now.UTC()
intervalStart := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location()) intervalStart := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())

View File

@ -288,7 +288,7 @@ func (db *DB) filepathFromDBName(dbName string) string {
// CreateTables creates any necessary tables. // CreateTables creates any necessary tables.
func (db *DB) CreateTables(ctx context.Context) error { func (db *DB) CreateTables(ctx context.Context) error {
migration := db.Migration(ctx) migration := db.Migration(ctx)
return migration.Run(db.log.Named("migration")) return migration.Run(ctx, db.log.Named("migration"))
} }
// Close closes any resources. // Close closes any resources.
@ -666,7 +666,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB, DB: db.deprecatedInfoDB,
Description: "Free Storagenodes from trash data", Description: "Free Storagenodes from trash data",
Version: 13, Version: 13,
Action: migrate.Func(func(log *zap.Logger, mgdb migrate.DB, tx *sql.Tx) error { Action: migrate.Func(func(ctx context.Context, log *zap.Logger, mgdb migrate.DB, tx *sql.Tx) error {
err := os.RemoveAll(filepath.Join(db.dbDirectory, "blob/ukfu6bhbboxilvt7jrwlqk7y2tapb5d2r2tsmj2sjxvw5qaaaaaa")) // us-central1 err := os.RemoveAll(filepath.Join(db.dbDirectory, "blob/ukfu6bhbboxilvt7jrwlqk7y2tapb5d2r2tsmj2sjxvw5qaaaaaa")) // us-central1
if err != nil { if err != nil {
log.Sugar().Debug(err) log.Sugar().Debug(err)
@ -691,7 +691,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB, DB: db.deprecatedInfoDB,
Description: "Free Storagenodes from orphaned tmp data", Description: "Free Storagenodes from orphaned tmp data",
Version: 14, Version: 14,
Action: migrate.Func(func(log *zap.Logger, mgdb migrate.DB, tx *sql.Tx) error { Action: migrate.Func(func(ctx context.Context, log *zap.Logger, mgdb migrate.DB, tx *sql.Tx) error {
err := os.RemoveAll(filepath.Join(db.dbDirectory, "tmp")) err := os.RemoveAll(filepath.Join(db.dbDirectory, "tmp"))
if err != nil { if err != nil {
log.Sugar().Debug(err) log.Sugar().Debug(err)
@ -832,7 +832,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB, DB: db.deprecatedInfoDB,
Description: "Vacuum info db", Description: "Vacuum info db",
Version: 22, Version: 22,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error { Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
_, err := db.deprecatedInfoDB.GetDB().Exec("VACUUM;") _, err := db.deprecatedInfoDB.GetDB().Exec("VACUUM;")
return err return err
}), }),
@ -841,7 +841,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB, DB: db.deprecatedInfoDB,
Description: "Split into multiple sqlite databases", Description: "Split into multiple sqlite databases",
Version: 23, Version: 23,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error { Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
// Migrate all the tables to new database files. // Migrate all the tables to new database files.
if err := db.migrateToDB(ctx, BandwidthDBName, "bandwidth_usage", "bandwidth_usage_rollups"); err != nil { if err := db.migrateToDB(ctx, BandwidthDBName, "bandwidth_usage", "bandwidth_usage_rollups"); err != nil {
return ErrDatabase.Wrap(err) return ErrDatabase.Wrap(err)
@ -878,7 +878,7 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
DB: db.deprecatedInfoDB, DB: db.deprecatedInfoDB,
Description: "Drop unneeded tables in deprecatedInfoDB", Description: "Drop unneeded tables in deprecatedInfoDB",
Version: 24, Version: 24,
Action: migrate.Func(func(log *zap.Logger, _ migrate.DB, tx *sql.Tx) error { Action: migrate.Func(func(ctx context.Context, log *zap.Logger, _ migrate.DB, tx *sql.Tx) error {
// We drop the migrated tables from the deprecated database and VACUUM SQLite3 // We drop the migrated tables from the deprecated database and VACUUM SQLite3
// in migration step 23 because if we were to keep that as part of step 22 // in migration step 23 because if we were to keep that as part of step 22
// and an error occurred it would replay the entire migration but some tables // and an error occurred it would replay the entire migration but some tables

View File

@ -100,7 +100,7 @@ func TestMigrate(t *testing.T) {
tag := fmt.Sprintf("#%d - v%d", i, step.Version) tag := fmt.Sprintf("#%d - v%d", i, step.Version)
// run migration up to a specific version // run migration up to a specific version
err := migrations.TargetVersion(step.Version).Run(log.Named("migrate")) err := migrations.TargetVersion(step.Version).Run(ctx, log.Named("migrate"))
require.NoError(t, err, tag) require.NoError(t, err, tag)
// find the matching expected version // find the matching expected version