private/dbutil/sqliteutil: add ctx argument

Change-Id: If1caa9cde746817e62cae32a152eeec81959129c
This commit is contained in:
Egon Elbre 2020-01-13 15:03:30 +02:00
parent bcc23f6869
commit c7b846589e
6 changed files with 36 additions and 32 deletions

View File

@ -4,6 +4,7 @@
package sqliteutil package sqliteutil
import ( import (
"context"
"database/sql" "database/sql"
"strconv" "strconv"
@ -14,7 +15,7 @@ import (
) )
// LoadSchemaFromSQL inserts script into connstr and loads schema. // LoadSchemaFromSQL inserts script into connstr and loads schema.
func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) { func LoadSchemaFromSQL(ctx context.Context, script string) (_ *dbschema.Schema, err error) {
db, err := sql.Open("sqlite3", ":memory:") db, err := sql.Open("sqlite3", ":memory:")
if err != nil { if err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err)
@ -26,11 +27,11 @@ func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) {
return nil, errs.Wrap(err) return nil, errs.Wrap(err)
} }
return QuerySchema(db) return QuerySchema(ctx, db)
} }
// LoadSnapshotFromSQL inserts script into connstr and loads schema. // LoadSnapshotFromSQL inserts script into connstr and loads schema.
func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) { func LoadSnapshotFromSQL(ctx context.Context, script string) (_ *dbschema.Snapshot, err error) {
db, err := sql.Open("sqlite3", ":memory:") db, err := sql.Open("sqlite3", ":memory:")
if err != nil { if err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err)
@ -42,7 +43,7 @@ func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
return nil, errs.Wrap(err) return nil, errs.Wrap(err)
} }
snapshot, err := QuerySnapshot(db) snapshot, err := QuerySnapshot(ctx, db)
if err != nil { if err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err)
} }
@ -52,13 +53,13 @@ func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
} }
// QuerySnapshot loads snapshot from database // QuerySnapshot loads snapshot from database
func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) { func QuerySnapshot(ctx context.Context, db dbschema.Queryer) (*dbschema.Snapshot, error) {
schema, err := QuerySchema(db) schema, err := QuerySchema(ctx, db)
if err != nil { if err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err)
} }
data, err := QueryData(db, schema) data, err := QueryData(ctx, db, schema)
if err != nil { if err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err)
} }
@ -71,7 +72,7 @@ func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) {
} }
// QueryData loads all data from tables // QueryData loads all data from tables
func QueryData(db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) { func QueryData(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) {
return dbschema.QueryData(db, schema, func(columnName string) string { return dbschema.QueryData(db, schema, func(columnName string) string {
quoted := strconv.Quote(columnName) quoted := strconv.Quote(columnName)
return `quote(` + quoted + `) as ` + quoted return `quote(` + quoted + `) as ` + quoted

View File

@ -35,13 +35,13 @@ func TestMigrateTablesToDatabase(t *testing.T) {
err := sqliteutil.MigrateTablesToDatabase(ctx, srcDB, destDB, "bobby_jones") err := sqliteutil.MigrateTablesToDatabase(ctx, srcDB, destDB, "bobby_jones")
require.NoError(t, err) require.NoError(t, err)
destSchema, err := sqliteutil.QuerySchema(destDB) destSchema, err := sqliteutil.QuerySchema(ctx, destDB)
require.NoError(t, err) require.NoError(t, err)
destData, err := sqliteutil.QueryData(destDB, destSchema) destData, err := sqliteutil.QueryData(ctx, destDB, destSchema)
require.NoError(t, err) require.NoError(t, err)
snapshot, err := sqliteutil.LoadSnapshotFromSQL(query) snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, snapshot.Schema, destSchema) require.Equal(t, snapshot.Schema, destSchema)
@ -71,13 +71,13 @@ func TestKeepTables(t *testing.T) {
err := sqliteutil.KeepTables(ctx, db, "table_one") err := sqliteutil.KeepTables(ctx, db, "table_one")
require.NoError(t, err) require.NoError(t, err)
schema, err := sqliteutil.QuerySchema(db) schema, err := sqliteutil.QuerySchema(ctx, db)
require.NoError(t, err) require.NoError(t, err)
data, err := sqliteutil.QueryData(db, schema) data, err := sqliteutil.QueryData(ctx, db, schema)
require.NoError(t, err) require.NoError(t, err)
snapshot, err := sqliteutil.LoadSnapshotFromSQL(table1SQL) snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, table1SQL)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, snapshot.Schema, schema) require.Equal(t, snapshot.Schema, schema)

View File

@ -4,6 +4,7 @@
package sqliteutil package sqliteutil
import ( import (
"context"
"database/sql" "database/sql"
"regexp" "regexp"
"strings" "strings"
@ -19,7 +20,7 @@ type definition struct {
} }
// QuerySchema loads the schema from sqlite database. // QuerySchema loads the schema from sqlite database.
func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) { func QuerySchema(ctx context.Context, db dbschema.Queryer) (*dbschema.Schema, error) {
schema := &dbschema.Schema{} schema := &dbschema.Schema{}
tableDefinitions := make([]*definition, 0) tableDefinitions := make([]*definition, 0)
@ -54,12 +55,12 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
return nil, err return nil, err
} }
err = discoverTables(db, schema, tableDefinitions) err = discoverTables(ctx, db, schema, tableDefinitions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = discoverIndexes(db, schema, indexDefinitions) err = discoverIndexes(ctx, db, schema, indexDefinitions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -68,7 +69,7 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
return schema, nil return schema, nil
} }
func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) { func discoverTables(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) {
for _, definition := range tableDefinitions { for _, definition := range tableDefinitions {
table := schema.EnsureTable(definition.name) table := schema.EnsureTable(definition.name)
@ -148,7 +149,7 @@ func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitio
return errs.Wrap(err) return errs.Wrap(err)
} }
func discoverIndexes(db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) { func discoverIndexes(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) {
// TODO improve indexes discovery // TODO improve indexes discovery
for _, definition := range indexDefinitions { for _, definition := range indexDefinitions {
index := &dbschema.Index{ index := &dbschema.Index{

View File

@ -26,7 +26,7 @@ func TestQuery(t *testing.T) {
defer ctx.Check(db.Close) defer ctx.Check(db.Close)
emptySchema, err := sqliteutil.QuerySchema(db) emptySchema, err := sqliteutil.QuerySchema(ctx, db)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, &dbschema.Schema{}, emptySchema) assert.Equal(t, &dbschema.Schema{}, emptySchema)
@ -52,7 +52,7 @@ func TestQuery(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
schema, err := sqliteutil.QuerySchema(db) schema, err := sqliteutil.QuerySchema(ctx, db)
assert.NoError(t, err) assert.NoError(t, err)
expected := &dbschema.Schema{ expected := &dbschema.Schema{

View File

@ -4,6 +4,7 @@
package storagenodedb_test package storagenodedb_test
import ( import (
"context"
"fmt" "fmt"
"path/filepath" "path/filepath"
"testing" "testing"
@ -22,7 +23,7 @@ import (
// insertNewData will insert any NewData from the MultiDBState into the // insertNewData will insert any NewData from the MultiDBState into the
// appropriate rawDB. This prepares the rawDB for the test comparing schema and // appropriate rawDB. This prepares the rawDB for the test comparing schema and
// data. // data.
func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error { func insertNewData(ctx context.Context, mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error {
for dbName, dbState := range mdbs.DBStates { for dbName, dbState := range mdbs.DBStates {
if dbState.NewData == "" { if dbState.NewData == "" {
continue continue
@ -42,10 +43,10 @@ func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.
// getSchemas queries the schema of each rawDB and returns a map of each rawDB's // getSchemas queries the schema of each rawDB and returns a map of each rawDB's
// schema keyed by dbName // schema keyed by dbName
func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) { func getSchemas(ctx context.Context, rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) {
schemas := make(map[string]*dbschema.Schema) schemas := make(map[string]*dbschema.Schema)
for dbName, rawDB := range rawDBs { for dbName, rawDB := range rawDBs {
schema, err := sqliteutil.QuerySchema(rawDB.GetDB()) schema, err := sqliteutil.QuerySchema(ctx, rawDB.GetDB())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -60,10 +61,10 @@ func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbsche
// getSchemas queries the data of each rawDB and returns a map of each rawDB's // getSchemas queries the data of each rawDB and returns a map of each rawDB's
// data keyed by dbName // data keyed by dbName
func getData(rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) { func getData(ctx context.Context, rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) {
data := make(map[string]*dbschema.Data) data := make(map[string]*dbschema.Data)
for dbName, rawDB := range rawDBs { for dbName, rawDB := range rawDBs {
datum, err := sqliteutil.QueryData(rawDB.GetDB(), schemas[dbName]) datum, err := sqliteutil.QueryData(ctx, rawDB.GetDB(), schemas[dbName])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -109,18 +110,18 @@ func TestMigrate(t *testing.T) {
rawDBs := db.RawDatabases() rawDBs := db.RawDatabases()
// insert data for new tables // insert data for new tables
err = insertNewData(expected, rawDBs) err = insertNewData(ctx, expected, rawDBs)
require.NoError(t, err, tag) require.NoError(t, err, tag)
// load schema from database // load schema from database
schemas, err := getSchemas(rawDBs) schemas, err := getSchemas(ctx, rawDBs)
require.NoError(t, err, tag) require.NoError(t, err, tag)
// load data from database // load data from database
data, err := getData(rawDBs, schemas) data, err := getData(ctx, rawDBs, schemas)
require.NoError(t, err, tag) require.NoError(t, err, tag)
multiDBSnapshot, err := testdata.LoadMultiDBSnapshot(expected) multiDBSnapshot, err := testdata.LoadMultiDBSnapshot(ctx, expected)
require.NoError(t, err, tag) require.NoError(t, err, tag)
// verify schema and data for each db in the expected snapshot // verify schema and data for each db in the expected snapshot

View File

@ -4,6 +4,7 @@
package testdata package testdata
import ( import (
"context"
"fmt" "fmt"
"storj.io/storj/private/dbutil/dbschema" "storj.io/storj/private/dbutil/dbschema"
@ -103,10 +104,10 @@ type DBSnapshot struct {
// LoadMultiDBSnapshot converts a MultiDBState into a MultiDBSnapshot. It // LoadMultiDBSnapshot converts a MultiDBState into a MultiDBSnapshot. It
// executes the SQL and stores the shema and data. // executes the SQL and stores the shema and data.
func LoadMultiDBSnapshot(multiDBState *MultiDBState) (*MultiDBSnapshot, error) { func LoadMultiDBSnapshot(ctx context.Context, multiDBState *MultiDBState) (*MultiDBSnapshot, error) {
multiDBSnapshot := NewMultiDBSnapshot() multiDBSnapshot := NewMultiDBSnapshot()
for dbName, dbState := range multiDBState.DBStates { for dbName, dbState := range multiDBState.DBStates {
snapshot, err := sqliteutil.LoadSnapshotFromSQL(fmt.Sprintf("%s\n%s", dbState.SQL, dbState.NewData)) snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, fmt.Sprintf("%s\n%s", dbState.SQL, dbState.NewData))
if err != nil { if err != nil {
return nil, err return nil, err
} }