private/dbutil/sqliteutil: add ctx argument
Change-Id: If1caa9cde746817e62cae32a152eeec81959129c
This commit is contained in:
parent
bcc23f6869
commit
c7b846589e
@ -4,6 +4,7 @@
|
||||
package sqliteutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
|
||||
@ -14,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
// LoadSchemaFromSQL inserts script into connstr and loads schema.
|
||||
func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) {
|
||||
func LoadSchemaFromSQL(ctx context.Context, script string) (_ *dbschema.Schema, err error) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
@ -26,11 +27,11 @@ func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
|
||||
return QuerySchema(db)
|
||||
return QuerySchema(ctx, db)
|
||||
}
|
||||
|
||||
// LoadSnapshotFromSQL inserts script into connstr and loads schema.
|
||||
func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
|
||||
func LoadSnapshotFromSQL(ctx context.Context, script string) (_ *dbschema.Snapshot, err error) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
@ -42,7 +43,7 @@ func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
|
||||
snapshot, err := QuerySnapshot(db)
|
||||
snapshot, err := QuerySnapshot(ctx, db)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
@ -52,13 +53,13 @@ func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
|
||||
}
|
||||
|
||||
// QuerySnapshot loads snapshot from database
|
||||
func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) {
|
||||
schema, err := QuerySchema(db)
|
||||
func QuerySnapshot(ctx context.Context, db dbschema.Queryer) (*dbschema.Snapshot, error) {
|
||||
schema, err := QuerySchema(ctx, db)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
|
||||
data, err := QueryData(db, schema)
|
||||
data, err := QueryData(ctx, db, schema)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
@ -71,7 +72,7 @@ func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) {
|
||||
}
|
||||
|
||||
// QueryData loads all data from tables
|
||||
func QueryData(db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) {
|
||||
func QueryData(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) {
|
||||
return dbschema.QueryData(db, schema, func(columnName string) string {
|
||||
quoted := strconv.Quote(columnName)
|
||||
return `quote(` + quoted + `) as ` + quoted
|
||||
|
@ -35,13 +35,13 @@ func TestMigrateTablesToDatabase(t *testing.T) {
|
||||
err := sqliteutil.MigrateTablesToDatabase(ctx, srcDB, destDB, "bobby_jones")
|
||||
require.NoError(t, err)
|
||||
|
||||
destSchema, err := sqliteutil.QuerySchema(destDB)
|
||||
destSchema, err := sqliteutil.QuerySchema(ctx, destDB)
|
||||
require.NoError(t, err)
|
||||
|
||||
destData, err := sqliteutil.QueryData(destDB, destSchema)
|
||||
destData, err := sqliteutil.QueryData(ctx, destDB, destSchema)
|
||||
require.NoError(t, err)
|
||||
|
||||
snapshot, err := sqliteutil.LoadSnapshotFromSQL(query)
|
||||
snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, query)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, snapshot.Schema, destSchema)
|
||||
@ -71,13 +71,13 @@ func TestKeepTables(t *testing.T) {
|
||||
err := sqliteutil.KeepTables(ctx, db, "table_one")
|
||||
require.NoError(t, err)
|
||||
|
||||
schema, err := sqliteutil.QuerySchema(db)
|
||||
schema, err := sqliteutil.QuerySchema(ctx, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := sqliteutil.QueryData(db, schema)
|
||||
data, err := sqliteutil.QueryData(ctx, db, schema)
|
||||
require.NoError(t, err)
|
||||
|
||||
snapshot, err := sqliteutil.LoadSnapshotFromSQL(table1SQL)
|
||||
snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, table1SQL)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, snapshot.Schema, schema)
|
||||
|
@ -4,6 +4,7 @@
|
||||
package sqliteutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"regexp"
|
||||
"strings"
|
||||
@ -19,7 +20,7 @@ type definition struct {
|
||||
}
|
||||
|
||||
// QuerySchema loads the schema from sqlite database.
|
||||
func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
|
||||
func QuerySchema(ctx context.Context, db dbschema.Queryer) (*dbschema.Schema, error) {
|
||||
schema := &dbschema.Schema{}
|
||||
|
||||
tableDefinitions := make([]*definition, 0)
|
||||
@ -54,12 +55,12 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = discoverTables(db, schema, tableDefinitions)
|
||||
err = discoverTables(ctx, db, schema, tableDefinitions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = discoverIndexes(db, schema, indexDefinitions)
|
||||
err = discoverIndexes(ctx, db, schema, indexDefinitions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -68,7 +69,7 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
|
||||
return schema, nil
|
||||
}
|
||||
|
||||
func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) {
|
||||
func discoverTables(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) {
|
||||
for _, definition := range tableDefinitions {
|
||||
table := schema.EnsureTable(definition.name)
|
||||
|
||||
@ -148,7 +149,7 @@ func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitio
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
|
||||
func discoverIndexes(db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) {
|
||||
func discoverIndexes(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) {
|
||||
// TODO improve indexes discovery
|
||||
for _, definition := range indexDefinitions {
|
||||
index := &dbschema.Index{
|
||||
|
@ -26,7 +26,7 @@ func TestQuery(t *testing.T) {
|
||||
|
||||
defer ctx.Check(db.Close)
|
||||
|
||||
emptySchema, err := sqliteutil.QuerySchema(db)
|
||||
emptySchema, err := sqliteutil.QuerySchema(ctx, db)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &dbschema.Schema{}, emptySchema)
|
||||
|
||||
@ -52,7 +52,7 @@ func TestQuery(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
schema, err := sqliteutil.QuerySchema(db)
|
||||
schema, err := sqliteutil.QuerySchema(ctx, db)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expected := &dbschema.Schema{
|
||||
|
@ -4,6 +4,7 @@
|
||||
package storagenodedb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@ -22,7 +23,7 @@ import (
|
||||
// insertNewData will insert any NewData from the MultiDBState into the
|
||||
// appropriate rawDB. This prepares the rawDB for the test comparing schema and
|
||||
// data.
|
||||
func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error {
|
||||
func insertNewData(ctx context.Context, mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error {
|
||||
for dbName, dbState := range mdbs.DBStates {
|
||||
if dbState.NewData == "" {
|
||||
continue
|
||||
@ -42,10 +43,10 @@ func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.
|
||||
|
||||
// getSchemas queries the schema of each rawDB and returns a map of each rawDB's
|
||||
// schema keyed by dbName
|
||||
func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) {
|
||||
func getSchemas(ctx context.Context, rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) {
|
||||
schemas := make(map[string]*dbschema.Schema)
|
||||
for dbName, rawDB := range rawDBs {
|
||||
schema, err := sqliteutil.QuerySchema(rawDB.GetDB())
|
||||
schema, err := sqliteutil.QuerySchema(ctx, rawDB.GetDB())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -60,10 +61,10 @@ func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbsche
|
||||
|
||||
// getSchemas queries the data of each rawDB and returns a map of each rawDB's
|
||||
// data keyed by dbName
|
||||
func getData(rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) {
|
||||
func getData(ctx context.Context, rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) {
|
||||
data := make(map[string]*dbschema.Data)
|
||||
for dbName, rawDB := range rawDBs {
|
||||
datum, err := sqliteutil.QueryData(rawDB.GetDB(), schemas[dbName])
|
||||
datum, err := sqliteutil.QueryData(ctx, rawDB.GetDB(), schemas[dbName])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -109,18 +110,18 @@ func TestMigrate(t *testing.T) {
|
||||
rawDBs := db.RawDatabases()
|
||||
|
||||
// insert data for new tables
|
||||
err = insertNewData(expected, rawDBs)
|
||||
err = insertNewData(ctx, expected, rawDBs)
|
||||
require.NoError(t, err, tag)
|
||||
|
||||
// load schema from database
|
||||
schemas, err := getSchemas(rawDBs)
|
||||
schemas, err := getSchemas(ctx, rawDBs)
|
||||
require.NoError(t, err, tag)
|
||||
|
||||
// load data from database
|
||||
data, err := getData(rawDBs, schemas)
|
||||
data, err := getData(ctx, rawDBs, schemas)
|
||||
require.NoError(t, err, tag)
|
||||
|
||||
multiDBSnapshot, err := testdata.LoadMultiDBSnapshot(expected)
|
||||
multiDBSnapshot, err := testdata.LoadMultiDBSnapshot(ctx, expected)
|
||||
require.NoError(t, err, tag)
|
||||
|
||||
// verify schema and data for each db in the expected snapshot
|
||||
|
@ -4,6 +4,7 @@
|
||||
package testdata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"storj.io/storj/private/dbutil/dbschema"
|
||||
@ -103,10 +104,10 @@ type DBSnapshot struct {
|
||||
|
||||
// LoadMultiDBSnapshot converts a MultiDBState into a MultiDBSnapshot. It
|
||||
// executes the SQL and stores the shema and data.
|
||||
func LoadMultiDBSnapshot(multiDBState *MultiDBState) (*MultiDBSnapshot, error) {
|
||||
func LoadMultiDBSnapshot(ctx context.Context, multiDBState *MultiDBState) (*MultiDBSnapshot, error) {
|
||||
multiDBSnapshot := NewMultiDBSnapshot()
|
||||
for dbName, dbState := range multiDBState.DBStates {
|
||||
snapshot, err := sqliteutil.LoadSnapshotFromSQL(fmt.Sprintf("%s\n%s", dbState.SQL, dbState.NewData))
|
||||
snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, fmt.Sprintf("%s\n%s", dbState.SQL, dbState.NewData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user