From c7b846589e95dd8ca3d253306599325bbd104979 Mon Sep 17 00:00:00 2001 From: Egon Elbre Date: Mon, 13 Jan 2020 15:03:30 +0200 Subject: [PATCH] private/dbutil/sqliteutil: add ctx argument Change-Id: If1caa9cde746817e62cae32a152eeec81959129c --- private/dbutil/sqliteutil/db.go | 17 +++++++++-------- private/dbutil/sqliteutil/migrator_test.go | 12 ++++++------ private/dbutil/sqliteutil/query.go | 11 ++++++----- private/dbutil/sqliteutil/query_test.go | 4 ++-- storagenode/storagenodedb/migrations_test.go | 19 ++++++++++--------- .../storagenodedb/testdata/multidbsnapshot.go | 5 +++-- 6 files changed, 36 insertions(+), 32 deletions(-) diff --git a/private/dbutil/sqliteutil/db.go b/private/dbutil/sqliteutil/db.go index e422ea2a8..18f8564df 100644 --- a/private/dbutil/sqliteutil/db.go +++ b/private/dbutil/sqliteutil/db.go @@ -4,6 +4,7 @@ package sqliteutil import ( + "context" "database/sql" "strconv" @@ -14,7 +15,7 @@ import ( ) // LoadSchemaFromSQL inserts script into connstr and loads schema. -func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) { +func LoadSchemaFromSQL(ctx context.Context, script string) (_ *dbschema.Schema, err error) { db, err := sql.Open("sqlite3", ":memory:") if err != nil { return nil, errs.Wrap(err) @@ -26,11 +27,11 @@ func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) { return nil, errs.Wrap(err) } - return QuerySchema(db) + return QuerySchema(ctx, db) } // LoadSnapshotFromSQL inserts script into connstr and loads schema. -func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) { +func LoadSnapshotFromSQL(ctx context.Context, script string) (_ *dbschema.Snapshot, err error) { db, err := sql.Open("sqlite3", ":memory:") if err != nil { return nil, errs.Wrap(err) @@ -42,7 +43,7 @@ func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) { return nil, errs.Wrap(err) } - snapshot, err := QuerySnapshot(db) + snapshot, err := QuerySnapshot(ctx, db) if err != nil { return nil, errs.Wrap(err) } @@ -52,13 +53,13 @@ func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) { } // QuerySnapshot loads snapshot from database -func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) { - schema, err := QuerySchema(db) +func QuerySnapshot(ctx context.Context, db dbschema.Queryer) (*dbschema.Snapshot, error) { + schema, err := QuerySchema(ctx, db) if err != nil { return nil, errs.Wrap(err) } - data, err := QueryData(db, schema) + data, err := QueryData(ctx, db, schema) if err != nil { return nil, errs.Wrap(err) } @@ -71,7 +72,7 @@ func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) { } // QueryData loads all data from tables -func QueryData(db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) { +func QueryData(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) { return dbschema.QueryData(db, schema, func(columnName string) string { quoted := strconv.Quote(columnName) return `quote(` + quoted + `) as ` + quoted diff --git a/private/dbutil/sqliteutil/migrator_test.go b/private/dbutil/sqliteutil/migrator_test.go index 6d08b768e..d7bb1c7a9 100644 --- a/private/dbutil/sqliteutil/migrator_test.go +++ b/private/dbutil/sqliteutil/migrator_test.go @@ -35,13 +35,13 @@ func TestMigrateTablesToDatabase(t *testing.T) { err := sqliteutil.MigrateTablesToDatabase(ctx, srcDB, destDB, "bobby_jones") require.NoError(t, err) - destSchema, err := sqliteutil.QuerySchema(destDB) + destSchema, err := sqliteutil.QuerySchema(ctx, destDB) require.NoError(t, err) - destData, err := sqliteutil.QueryData(destDB, destSchema) + destData, err := sqliteutil.QueryData(ctx, destDB, destSchema) require.NoError(t, err) - snapshot, err := sqliteutil.LoadSnapshotFromSQL(query) + snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, query) require.NoError(t, err) require.Equal(t, snapshot.Schema, destSchema) @@ -71,13 +71,13 @@ func TestKeepTables(t *testing.T) { err := sqliteutil.KeepTables(ctx, db, "table_one") require.NoError(t, err) - schema, err := sqliteutil.QuerySchema(db) + schema, err := sqliteutil.QuerySchema(ctx, db) require.NoError(t, err) - data, err := sqliteutil.QueryData(db, schema) + data, err := sqliteutil.QueryData(ctx, db, schema) require.NoError(t, err) - snapshot, err := sqliteutil.LoadSnapshotFromSQL(table1SQL) + snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, table1SQL) require.NoError(t, err) require.Equal(t, snapshot.Schema, schema) diff --git a/private/dbutil/sqliteutil/query.go b/private/dbutil/sqliteutil/query.go index 08760b737..2bf4f23b6 100644 --- a/private/dbutil/sqliteutil/query.go +++ b/private/dbutil/sqliteutil/query.go @@ -4,6 +4,7 @@ package sqliteutil import ( + "context" "database/sql" "regexp" "strings" @@ -19,7 +20,7 @@ type definition struct { } // QuerySchema loads the schema from sqlite database. -func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) { +func QuerySchema(ctx context.Context, db dbschema.Queryer) (*dbschema.Schema, error) { schema := &dbschema.Schema{} tableDefinitions := make([]*definition, 0) @@ -54,12 +55,12 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) { return nil, err } - err = discoverTables(db, schema, tableDefinitions) + err = discoverTables(ctx, db, schema, tableDefinitions) if err != nil { return nil, err } - err = discoverIndexes(db, schema, indexDefinitions) + err = discoverIndexes(ctx, db, schema, indexDefinitions) if err != nil { return nil, err } @@ -68,7 +69,7 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) { return schema, nil } -func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) { +func discoverTables(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) { for _, definition := range tableDefinitions { table := schema.EnsureTable(definition.name) @@ -148,7 +149,7 @@ func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitio return errs.Wrap(err) } -func discoverIndexes(db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) { +func discoverIndexes(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) { // TODO improve indexes discovery for _, definition := range indexDefinitions { index := &dbschema.Index{ diff --git a/private/dbutil/sqliteutil/query_test.go b/private/dbutil/sqliteutil/query_test.go index d57a55af0..be1482262 100644 --- a/private/dbutil/sqliteutil/query_test.go +++ b/private/dbutil/sqliteutil/query_test.go @@ -26,7 +26,7 @@ func TestQuery(t *testing.T) { defer ctx.Check(db.Close) - emptySchema, err := sqliteutil.QuerySchema(db) + emptySchema, err := sqliteutil.QuerySchema(ctx, db) assert.NoError(t, err) assert.Equal(t, &dbschema.Schema{}, emptySchema) @@ -52,7 +52,7 @@ func TestQuery(t *testing.T) { require.NoError(t, err) - schema, err := sqliteutil.QuerySchema(db) + schema, err := sqliteutil.QuerySchema(ctx, db) assert.NoError(t, err) expected := &dbschema.Schema{ diff --git a/storagenode/storagenodedb/migrations_test.go b/storagenode/storagenodedb/migrations_test.go index 3a6271296..c74e6fc63 100644 --- a/storagenode/storagenodedb/migrations_test.go +++ b/storagenode/storagenodedb/migrations_test.go @@ -4,6 +4,7 @@ package storagenodedb_test import ( + "context" "fmt" "path/filepath" "testing" @@ -22,7 +23,7 @@ import ( // insertNewData will insert any NewData from the MultiDBState into the // appropriate rawDB. This prepares the rawDB for the test comparing schema and // data. -func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error { +func insertNewData(ctx context.Context, mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error { for dbName, dbState := range mdbs.DBStates { if dbState.NewData == "" { continue @@ -42,10 +43,10 @@ func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb. // getSchemas queries the schema of each rawDB and returns a map of each rawDB's // schema keyed by dbName -func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) { +func getSchemas(ctx context.Context, rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) { schemas := make(map[string]*dbschema.Schema) for dbName, rawDB := range rawDBs { - schema, err := sqliteutil.QuerySchema(rawDB.GetDB()) + schema, err := sqliteutil.QuerySchema(ctx, rawDB.GetDB()) if err != nil { return nil, err } @@ -60,10 +61,10 @@ func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbsche // getSchemas queries the data of each rawDB and returns a map of each rawDB's // data keyed by dbName -func getData(rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) { +func getData(ctx context.Context, rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) { data := make(map[string]*dbschema.Data) for dbName, rawDB := range rawDBs { - datum, err := sqliteutil.QueryData(rawDB.GetDB(), schemas[dbName]) + datum, err := sqliteutil.QueryData(ctx, rawDB.GetDB(), schemas[dbName]) if err != nil { return nil, err } @@ -109,18 +110,18 @@ func TestMigrate(t *testing.T) { rawDBs := db.RawDatabases() // insert data for new tables - err = insertNewData(expected, rawDBs) + err = insertNewData(ctx, expected, rawDBs) require.NoError(t, err, tag) // load schema from database - schemas, err := getSchemas(rawDBs) + schemas, err := getSchemas(ctx, rawDBs) require.NoError(t, err, tag) // load data from database - data, err := getData(rawDBs, schemas) + data, err := getData(ctx, rawDBs, schemas) require.NoError(t, err, tag) - multiDBSnapshot, err := testdata.LoadMultiDBSnapshot(expected) + multiDBSnapshot, err := testdata.LoadMultiDBSnapshot(ctx, expected) require.NoError(t, err, tag) // verify schema and data for each db in the expected snapshot diff --git a/storagenode/storagenodedb/testdata/multidbsnapshot.go b/storagenode/storagenodedb/testdata/multidbsnapshot.go index adcb6194f..7b7574019 100644 --- a/storagenode/storagenodedb/testdata/multidbsnapshot.go +++ b/storagenode/storagenodedb/testdata/multidbsnapshot.go @@ -4,6 +4,7 @@ package testdata import ( + "context" "fmt" "storj.io/storj/private/dbutil/dbschema" @@ -103,10 +104,10 @@ type DBSnapshot struct { // LoadMultiDBSnapshot converts a MultiDBState into a MultiDBSnapshot. It // executes the SQL and stores the shema and data. -func LoadMultiDBSnapshot(multiDBState *MultiDBState) (*MultiDBSnapshot, error) { +func LoadMultiDBSnapshot(ctx context.Context, multiDBState *MultiDBState) (*MultiDBSnapshot, error) { multiDBSnapshot := NewMultiDBSnapshot() for dbName, dbState := range multiDBState.DBStates { - snapshot, err := sqliteutil.LoadSnapshotFromSQL(fmt.Sprintf("%s\n%s", dbState.SQL, dbState.NewData)) + snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, fmt.Sprintf("%s\n%s", dbState.SQL, dbState.NewData)) if err != nil { return nil, err }