private/dbutil/sqliteutil: add ctx argument
Change-Id: If1caa9cde746817e62cae32a152eeec81959129c
This commit is contained in:
parent
bcc23f6869
commit
c7b846589e
@ -4,6 +4,7 @@
|
|||||||
package sqliteutil
|
package sqliteutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@ -14,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// LoadSchemaFromSQL inserts script into connstr and loads schema.
|
// LoadSchemaFromSQL inserts script into connstr and loads schema.
|
||||||
func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) {
|
func LoadSchemaFromSQL(ctx context.Context, script string) (_ *dbschema.Schema, err error) {
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(err)
|
return nil, errs.Wrap(err)
|
||||||
@ -26,11 +27,11 @@ func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) {
|
|||||||
return nil, errs.Wrap(err)
|
return nil, errs.Wrap(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return QuerySchema(db)
|
return QuerySchema(ctx, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadSnapshotFromSQL inserts script into connstr and loads schema.
|
// LoadSnapshotFromSQL inserts script into connstr and loads schema.
|
||||||
func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
|
func LoadSnapshotFromSQL(ctx context.Context, script string) (_ *dbschema.Snapshot, err error) {
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(err)
|
return nil, errs.Wrap(err)
|
||||||
@ -42,7 +43,7 @@ func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
|
|||||||
return nil, errs.Wrap(err)
|
return nil, errs.Wrap(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
snapshot, err := QuerySnapshot(db)
|
snapshot, err := QuerySnapshot(ctx, db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(err)
|
return nil, errs.Wrap(err)
|
||||||
}
|
}
|
||||||
@ -52,13 +53,13 @@ func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// QuerySnapshot loads snapshot from database
|
// QuerySnapshot loads snapshot from database
|
||||||
func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) {
|
func QuerySnapshot(ctx context.Context, db dbschema.Queryer) (*dbschema.Snapshot, error) {
|
||||||
schema, err := QuerySchema(db)
|
schema, err := QuerySchema(ctx, db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(err)
|
return nil, errs.Wrap(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := QueryData(db, schema)
|
data, err := QueryData(ctx, db, schema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(err)
|
return nil, errs.Wrap(err)
|
||||||
}
|
}
|
||||||
@ -71,7 +72,7 @@ func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// QueryData loads all data from tables
|
// QueryData loads all data from tables
|
||||||
func QueryData(db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) {
|
func QueryData(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) {
|
||||||
return dbschema.QueryData(db, schema, func(columnName string) string {
|
return dbschema.QueryData(db, schema, func(columnName string) string {
|
||||||
quoted := strconv.Quote(columnName)
|
quoted := strconv.Quote(columnName)
|
||||||
return `quote(` + quoted + `) as ` + quoted
|
return `quote(` + quoted + `) as ` + quoted
|
||||||
|
@ -35,13 +35,13 @@ func TestMigrateTablesToDatabase(t *testing.T) {
|
|||||||
err := sqliteutil.MigrateTablesToDatabase(ctx, srcDB, destDB, "bobby_jones")
|
err := sqliteutil.MigrateTablesToDatabase(ctx, srcDB, destDB, "bobby_jones")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
destSchema, err := sqliteutil.QuerySchema(destDB)
|
destSchema, err := sqliteutil.QuerySchema(ctx, destDB)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
destData, err := sqliteutil.QueryData(destDB, destSchema)
|
destData, err := sqliteutil.QueryData(ctx, destDB, destSchema)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
snapshot, err := sqliteutil.LoadSnapshotFromSQL(query)
|
snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, query)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, snapshot.Schema, destSchema)
|
require.Equal(t, snapshot.Schema, destSchema)
|
||||||
@ -71,13 +71,13 @@ func TestKeepTables(t *testing.T) {
|
|||||||
err := sqliteutil.KeepTables(ctx, db, "table_one")
|
err := sqliteutil.KeepTables(ctx, db, "table_one")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
schema, err := sqliteutil.QuerySchema(db)
|
schema, err := sqliteutil.QuerySchema(ctx, db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
data, err := sqliteutil.QueryData(db, schema)
|
data, err := sqliteutil.QueryData(ctx, db, schema)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
snapshot, err := sqliteutil.LoadSnapshotFromSQL(table1SQL)
|
snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, table1SQL)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, snapshot.Schema, schema)
|
require.Equal(t, snapshot.Schema, schema)
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
package sqliteutil
|
package sqliteutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@ -19,7 +20,7 @@ type definition struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// QuerySchema loads the schema from sqlite database.
|
// QuerySchema loads the schema from sqlite database.
|
||||||
func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
|
func QuerySchema(ctx context.Context, db dbschema.Queryer) (*dbschema.Schema, error) {
|
||||||
schema := &dbschema.Schema{}
|
schema := &dbschema.Schema{}
|
||||||
|
|
||||||
tableDefinitions := make([]*definition, 0)
|
tableDefinitions := make([]*definition, 0)
|
||||||
@ -54,12 +55,12 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = discoverTables(db, schema, tableDefinitions)
|
err = discoverTables(ctx, db, schema, tableDefinitions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = discoverIndexes(db, schema, indexDefinitions)
|
err = discoverIndexes(ctx, db, schema, indexDefinitions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -68,7 +69,7 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
|
|||||||
return schema, nil
|
return schema, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) {
|
func discoverTables(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) {
|
||||||
for _, definition := range tableDefinitions {
|
for _, definition := range tableDefinitions {
|
||||||
table := schema.EnsureTable(definition.name)
|
table := schema.EnsureTable(definition.name)
|
||||||
|
|
||||||
@ -148,7 +149,7 @@ func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitio
|
|||||||
return errs.Wrap(err)
|
return errs.Wrap(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func discoverIndexes(db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) {
|
func discoverIndexes(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) {
|
||||||
// TODO improve indexes discovery
|
// TODO improve indexes discovery
|
||||||
for _, definition := range indexDefinitions {
|
for _, definition := range indexDefinitions {
|
||||||
index := &dbschema.Index{
|
index := &dbschema.Index{
|
||||||
|
@ -26,7 +26,7 @@ func TestQuery(t *testing.T) {
|
|||||||
|
|
||||||
defer ctx.Check(db.Close)
|
defer ctx.Check(db.Close)
|
||||||
|
|
||||||
emptySchema, err := sqliteutil.QuerySchema(db)
|
emptySchema, err := sqliteutil.QuerySchema(ctx, db)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, &dbschema.Schema{}, emptySchema)
|
assert.Equal(t, &dbschema.Schema{}, emptySchema)
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ func TestQuery(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
schema, err := sqliteutil.QuerySchema(db)
|
schema, err := sqliteutil.QuerySchema(ctx, db)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expected := &dbschema.Schema{
|
expected := &dbschema.Schema{
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
package storagenodedb_test
|
package storagenodedb_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@ -22,7 +23,7 @@ import (
|
|||||||
// insertNewData will insert any NewData from the MultiDBState into the
|
// insertNewData will insert any NewData from the MultiDBState into the
|
||||||
// appropriate rawDB. This prepares the rawDB for the test comparing schema and
|
// appropriate rawDB. This prepares the rawDB for the test comparing schema and
|
||||||
// data.
|
// data.
|
||||||
func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error {
|
func insertNewData(ctx context.Context, mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.DBContainer) error {
|
||||||
for dbName, dbState := range mdbs.DBStates {
|
for dbName, dbState := range mdbs.DBStates {
|
||||||
if dbState.NewData == "" {
|
if dbState.NewData == "" {
|
||||||
continue
|
continue
|
||||||
@ -42,10 +43,10 @@ func insertNewData(mdbs *testdata.MultiDBState, rawDBs map[string]storagenodedb.
|
|||||||
|
|
||||||
// getSchemas queries the schema of each rawDB and returns a map of each rawDB's
|
// getSchemas queries the schema of each rawDB and returns a map of each rawDB's
|
||||||
// schema keyed by dbName
|
// schema keyed by dbName
|
||||||
func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) {
|
func getSchemas(ctx context.Context, rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbschema.Schema, error) {
|
||||||
schemas := make(map[string]*dbschema.Schema)
|
schemas := make(map[string]*dbschema.Schema)
|
||||||
for dbName, rawDB := range rawDBs {
|
for dbName, rawDB := range rawDBs {
|
||||||
schema, err := sqliteutil.QuerySchema(rawDB.GetDB())
|
schema, err := sqliteutil.QuerySchema(ctx, rawDB.GetDB())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -60,10 +61,10 @@ func getSchemas(rawDBs map[string]storagenodedb.DBContainer) (map[string]*dbsche
|
|||||||
|
|
||||||
// getSchemas queries the data of each rawDB and returns a map of each rawDB's
|
// getSchemas queries the data of each rawDB and returns a map of each rawDB's
|
||||||
// data keyed by dbName
|
// data keyed by dbName
|
||||||
func getData(rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) {
|
func getData(ctx context.Context, rawDBs map[string]storagenodedb.DBContainer, schemas map[string]*dbschema.Schema) (map[string]*dbschema.Data, error) {
|
||||||
data := make(map[string]*dbschema.Data)
|
data := make(map[string]*dbschema.Data)
|
||||||
for dbName, rawDB := range rawDBs {
|
for dbName, rawDB := range rawDBs {
|
||||||
datum, err := sqliteutil.QueryData(rawDB.GetDB(), schemas[dbName])
|
datum, err := sqliteutil.QueryData(ctx, rawDB.GetDB(), schemas[dbName])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -109,18 +110,18 @@ func TestMigrate(t *testing.T) {
|
|||||||
rawDBs := db.RawDatabases()
|
rawDBs := db.RawDatabases()
|
||||||
|
|
||||||
// insert data for new tables
|
// insert data for new tables
|
||||||
err = insertNewData(expected, rawDBs)
|
err = insertNewData(ctx, expected, rawDBs)
|
||||||
require.NoError(t, err, tag)
|
require.NoError(t, err, tag)
|
||||||
|
|
||||||
// load schema from database
|
// load schema from database
|
||||||
schemas, err := getSchemas(rawDBs)
|
schemas, err := getSchemas(ctx, rawDBs)
|
||||||
require.NoError(t, err, tag)
|
require.NoError(t, err, tag)
|
||||||
|
|
||||||
// load data from database
|
// load data from database
|
||||||
data, err := getData(rawDBs, schemas)
|
data, err := getData(ctx, rawDBs, schemas)
|
||||||
require.NoError(t, err, tag)
|
require.NoError(t, err, tag)
|
||||||
|
|
||||||
multiDBSnapshot, err := testdata.LoadMultiDBSnapshot(expected)
|
multiDBSnapshot, err := testdata.LoadMultiDBSnapshot(ctx, expected)
|
||||||
require.NoError(t, err, tag)
|
require.NoError(t, err, tag)
|
||||||
|
|
||||||
// verify schema and data for each db in the expected snapshot
|
// verify schema and data for each db in the expected snapshot
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
package testdata
|
package testdata
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"storj.io/storj/private/dbutil/dbschema"
|
"storj.io/storj/private/dbutil/dbschema"
|
||||||
@ -103,10 +104,10 @@ type DBSnapshot struct {
|
|||||||
|
|
||||||
// LoadMultiDBSnapshot converts a MultiDBState into a MultiDBSnapshot. It
|
// LoadMultiDBSnapshot converts a MultiDBState into a MultiDBSnapshot. It
|
||||||
// executes the SQL and stores the shema and data.
|
// executes the SQL and stores the shema and data.
|
||||||
func LoadMultiDBSnapshot(multiDBState *MultiDBState) (*MultiDBSnapshot, error) {
|
func LoadMultiDBSnapshot(ctx context.Context, multiDBState *MultiDBState) (*MultiDBSnapshot, error) {
|
||||||
multiDBSnapshot := NewMultiDBSnapshot()
|
multiDBSnapshot := NewMultiDBSnapshot()
|
||||||
for dbName, dbState := range multiDBState.DBStates {
|
for dbName, dbState := range multiDBState.DBStates {
|
||||||
snapshot, err := sqliteutil.LoadSnapshotFromSQL(fmt.Sprintf("%s\n%s", dbState.SQL, dbState.NewData))
|
snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, fmt.Sprintf("%s\n%s", dbState.SQL, dbState.NewData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user