private/dbutil/sqliteutil: add ctx argument

Change-Id: If1caa9cde746817e62cae32a152eeec81959129c
This commit is contained in:
Egon Elbre 2020-01-13 15:03:30 +02:00
parent bcc23f6869
commit c7b846589e
6 changed files with 36 additions and 32 deletions

View File

@ -4,6 +4,7 @@
package sqliteutil
import (
"context"
"database/sql"
"strconv"
@ -14,7 +15,7 @@ import (
)
// LoadSchemaFromSQL inserts script into connstr and loads schema.
func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) {
func LoadSchemaFromSQL(ctx context.Context, script string) (_ *dbschema.Schema, err error) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
return nil, errs.Wrap(err)
@ -26,11 +27,11 @@ func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) {
return nil, errs.Wrap(err)
}
return QuerySchema(db)
return QuerySchema(ctx, db)
}
// LoadSnapshotFromSQL inserts script into connstr and loads schema.
func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
func LoadSnapshotFromSQL(ctx context.Context, script string) (_ *dbschema.Snapshot, err error) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
return nil, errs.Wrap(err)
@ -42,7 +43,7 @@ func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
return nil, errs.Wrap(err)
}
snapshot, err := QuerySnapshot(db)
snapshot, err := QuerySnapshot(ctx, db)
if err != nil {
return nil, errs.Wrap(err)
}
@ -52,13 +53,13 @@ func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
}
// QuerySnapshot loads snapshot from database
func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) {
schema, err := QuerySchema(db)
func QuerySnapshot(ctx context.Context, db dbschema.Queryer) (*dbschema.Snapshot, error) {
schema, err := QuerySchema(ctx, db)
if err != nil {
return nil, errs.Wrap(err)
}
data, err := QueryData(db, schema)
data, err := QueryData(ctx, db, schema)
if err != nil {
return nil, errs.Wrap(err)
}
@ -71,7 +72,7 @@ func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) {
}
// QueryData loads all data from tables
func QueryData(db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) {
func QueryData(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) {
return dbschema.QueryData(db, schema, func(columnName string) string {
quoted := strconv.Quote(columnName)
return `quote(` + quoted + `) as ` + quoted

View File

@ -35,13 +35,13 @@ func TestMigrateTablesToDatabase(t *testing.T) {
err := sqliteutil.MigrateTablesToDatabase(ctx, srcDB, destDB, "bobby_jones")
require.NoError(t, err)
destSchema, err := sqliteutil.QuerySchema(destDB)
destSchema, err := sqliteutil.QuerySchema(ctx, destDB)
require.NoError(t, err)
destData, err := sqliteutil.QueryData(destDB, destSchema)
destData, err := sqliteutil.QueryData(ctx, destDB, destSchema)
require.NoError(t, err)
snapshot, err := sqliteutil.LoadSnapshotFromSQL(query)
snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, query)
require.NoError(t, err)
require.Equal(t, snapshot.Schema, destSchema)
@ -71,13 +71,13 @@ func TestKeepTables(t *testing.T) {
err := sqliteutil.KeepTables(ctx, db, "table_one")
require.NoError(t, err)
schema, err := sqliteutil.QuerySchema(db)
schema, err := sqliteutil.QuerySchema(ctx, db)
require.NoError(t, err)
data, err := sqliteutil.QueryData(db, schema)
data, err := sqliteutil.QueryData(ctx, db, schema)
require.NoError(t, err)
snapshot, err := sqliteutil.LoadSnapshotFromSQL(table1SQL)
snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, table1SQL)
require.NoError(t, err)
require.Equal(t, snapshot.Schema, schema)

View File

@ -4,6 +4,7 @@
package sqliteutil
import (
"context"
"database/sql"
"regexp"
"strings"
@ -19,7 +20,7 @@ type definition struct {
}
// QuerySchema loads the schema from sqlite database.
func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
func QuerySchema(ctx context.Context, db dbschema.Queryer) (*dbschema.Schema, error) {
schema := &dbschema.Schema{}
tableDefinitions := make([]*definition, 0)
@ -54,12 +55,12 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
return nil, err
}
err = discoverTables(db, schema, tableDefinitions)
err = discoverTables(ctx, db, schema, tableDefinitions)
if err != nil {
return nil, err
}
err = discoverIndexes(db, schema, indexDefinitions)
err = discoverIndexes(ctx, db, schema, indexDefinitions)
if err != nil {
return nil, err
}
@ -68,7 +69,7 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
return schema, nil
}
func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) {
func discoverTables(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) {
for _, definition := range tableDefinitions {
table := schema.EnsureTable(definition.name)
@ -148,7 +149,7 @@ func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitio
return errs.Wrap(err)
}
func discoverIndexes(db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) {
func discoverIndexes(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) {
// TODO improve indexes discovery
for _, definition := range indexDefinitions {
index := &dbschema.Index{

View File

@ -26,7 +26,7 @@ func TestQuery(t *testing.T) {
defer ctx.Check(db.Close)
emptySchema, err := sqliteutil.QuerySchema(db)
emptySchema, err := sqliteutil.QuerySchema(ctx, db)
assert.NoError(t, err)
assert.Equal(t, &dbschema.Schema{}, emptySchema)
@ -52,7 +52,7 @@ func TestQuery(t *testing.T) {
require.NoError(t, err)
schema, err := sqliteutil.QuerySchema(db)
schema, err := sqliteutil.QuerySchema(ctx, db)
assert.NoError(t, err)
expected := &dbschema.Schema{

View File

@ -4,6 +4,7 @@
package storagenodedb_test
import (
"context"
"fmt"
"path/filepath"
"testing"
@ -22,7 +23,7 @@ import (
// insertNewData will insert any NewData from the MultiDBState into the
// appropriate rawDB. This prepares the rawDB for the test comparing schema and
// data.
func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error {
func insertNewData(ctx context.Context, mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error {
for dbName, dbState := range mdbs.DBStates {
if dbState.NewData == "" {
continue
@ -42,10 +43,10 @@ func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.
// getSchemas queries the schema of each rawDB and returns a map of each rawDB's
// schema keyed by dbName
func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) {
func getSchemas(ctx context.Context, rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) {
schemas := make(map[string]*dbschema.Schema)
for dbName, rawDB := range rawDBs {
schema, err := sqliteutil.QuerySchema(rawDB.GetDB())
schema, err := sqliteutil.QuerySchema(ctx, rawDB.GetDB())
if err != nil {
return nil, err
}
@ -60,10 +61,10 @@ func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbsche
// getSchemas queries the data of each rawDB and returns a map of each rawDB's
// data keyed by dbName
func getData(rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) {
func getData(ctx context.Context, rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) {
data := make(map[string]*dbschema.Data)
for dbName, rawDB := range rawDBs {
datum, err := sqliteutil.QueryData(rawDB.GetDB(), schemas[dbName])
datum, err := sqliteutil.QueryData(ctx, rawDB.GetDB(), schemas[dbName])
if err != nil {
return nil, err
}
@ -109,18 +110,18 @@ func TestMigrate(t *testing.T) {
rawDBs := db.RawDatabases()
// insert data for new tables
err = insertNewData(expected, rawDBs)
err = insertNewData(ctx, expected, rawDBs)
require.NoError(t, err, tag)
// load schema from database
schemas, err := getSchemas(rawDBs)
schemas, err := getSchemas(ctx, rawDBs)
require.NoError(t, err, tag)
// load data from database
data, err := getData(rawDBs, schemas)
data, err := getData(ctx, rawDBs, schemas)
require.NoError(t, err, tag)
multiDBSnapshot, err := testdata.LoadMultiDBSnapshot(expected)
multiDBSnapshot, err := testdata.LoadMultiDBSnapshot(ctx, expected)
require.NoError(t, err, tag)
// verify schema and data for each db in the expected snapshot

View File

@ -4,6 +4,7 @@
package testdata
import (
"context"
"fmt"
"storj.io/storj/private/dbutil/dbschema"
@ -103,10 +104,10 @@ type DBSnapshot struct {
// LoadMultiDBSnapshot converts a MultiDBState into a MultiDBSnapshot. It
// executes the SQL and stores the shema and data.
func LoadMultiDBSnapshot(multiDBState *MultiDBState) (*MultiDBSnapshot, error) {
func LoadMultiDBSnapshot(ctx context.Context, multiDBState *MultiDBState) (*MultiDBSnapshot, error) {
multiDBSnapshot := NewMultiDBSnapshot()
for dbName, dbState := range multiDBState.DBStates {
snapshot, err := sqliteutil.LoadSnapshotFromSQL(fmt.Sprintf("%s\n%s", dbState.SQL, dbState.NewData))
snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, fmt.Sprintf("%s\n%s", dbState.SQL, dbState.NewData))
if err != nil {
return nil, err
}