private/dbutil/sqlutil: use context in queries

Change-Id: Icb92daa483d13e6d57013f3917571d476126bfd2
This commit is contained in:
Egon Elbre 2020-01-14 13:29:25 +02:00 committed by Jennifer Li Johnson
parent df9e53ea0b
commit 64f056bee4
6 changed files with 19 additions and 14 deletions

View File

@ -5,6 +5,7 @@
package dbschema
import (
"context"
"database/sql"
"sort"
)
@ -13,6 +14,8 @@ import (
type Queryer interface {
// Query executes a query that returns rows, typically a SELECT.
Query(query string, args ...interface{}) (*sql.Rows, error)
// QueryContext executes a query that returns rows, typically a SELECT.
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// Schema is the database structure.

View File

@ -22,7 +22,7 @@ func LoadSchemaFromSQL(ctx context.Context, script string) (_ *dbschema.Schema,
}
defer func() { err = errs.Combine(err, db.Close()) }()
_, err = db.Exec(script)
_, err = db.ExecContext(ctx, script)
if err != nil {
return nil, errs.Wrap(err)
}
@ -38,7 +38,7 @@ func LoadSnapshotFromSQL(ctx context.Context, script string) (_ *dbschema.Snapsh
}
defer func() { err = errs.Combine(err, db.Close()) }()
_, err = db.Exec(script)
_, err = db.ExecContext(ctx, script)
if err != nil {
return nil, errs.Wrap(err)
}

View File

@ -21,6 +21,7 @@ type DB interface {
migrate.DB
Conn(ctx context.Context) (*sql.Conn, error)
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
var (
@ -173,7 +174,7 @@ func KeepTables(ctx context.Context, db DB, tablesToKeep ...string) (err error)
// data will not actually be reclaimed until the db has been closed.
// We don't include this in the above transaction because
// you can't VACUUM within a transaction with SQLite3.
_, err = db.Exec("VACUUM;")
_, err = db.ExecContext(ctx, "VACUUM;")
if err != nil {
return ErrKeepTables.Wrap(err)
}

View File

@ -4,6 +4,7 @@
package sqliteutil_test
import (
"context"
"database/sql"
"testing"
@ -28,9 +29,9 @@ func TestMigrateTablesToDatabase(t *testing.T) {
INSERT INTO bobby_jones VALUES (1);
`
execSQL(t, srcDB, query)
execSQL(ctx, t, srcDB, query)
// This table should be removed after migration
execSQL(t, srcDB, "CREATE TABLE what(I Int);")
execSQL(ctx, t, srcDB, "CREATE TABLE what(I Int);")
err := sqliteutil.MigrateTablesToDatabase(ctx, srcDB, destDB, "bobby_jones")
require.NoError(t, err)
@ -65,8 +66,8 @@ func TestKeepTables(t *testing.T) {
INSERT INTO table_two VALUES(2);
`
execSQL(t, db, table1SQL)
execSQL(t, db, table2SQL)
execSQL(ctx, t, db, table1SQL)
execSQL(ctx, t, db, table2SQL)
err := sqliteutil.KeepTables(ctx, db, "table_one")
require.NoError(t, err)
@ -84,8 +85,8 @@ func TestKeepTables(t *testing.T) {
require.Equal(t, snapshot.Data, data)
}
func execSQL(t *testing.T, db *sql.DB, query string, args ...interface{}) {
_, err := db.Exec(query, args...)
func execSQL(ctx context.Context, t *testing.T, db *sql.DB, query string, args ...interface{}) {
_, err := db.ExecContext(ctx, query, args...)
require.NoError(t, err)
}

View File

@ -28,7 +28,7 @@ func QuerySchema(ctx context.Context, db dbschema.Queryer) (*dbschema.Schema, er
// find tables and indexes
err := func() error {
rows, err := db.Query(`
rows, err := db.QueryContext(ctx, `
SELECT name, type, sql FROM sqlite_master WHERE sql NOT NULL AND name NOT LIKE 'sqlite_%'
`)
if err != nil {
@ -73,7 +73,7 @@ func discoverTables(ctx context.Context, db dbschema.Queryer, schema *dbschema.S
for _, definition := range tableDefinitions {
table := schema.EnsureTable(definition.name)
tableRows, err := db.Query(`PRAGMA table_info(` + definition.name + `)`)
tableRows, err := db.QueryContext(ctx, `PRAGMA table_info(`+definition.name+`)`)
if err != nil {
return errs.Wrap(err)
}
@ -115,7 +115,7 @@ func discoverTables(ctx context.Context, db dbschema.Queryer, schema *dbschema.S
table.Unique = append(table.Unique, columns)
}
keysRows, err := db.Query(`PRAGMA foreign_key_list(` + definition.name + `)`)
keysRows, err := db.QueryContext(ctx, `PRAGMA foreign_key_list(`+definition.name+`)`)
if err != nil {
return errs.Wrap(err)
}
@ -157,7 +157,7 @@ func discoverIndexes(ctx context.Context, db dbschema.Queryer, schema *dbschema.
}
schema.Indexes = append(schema.Indexes, index)
indexRows, err := db.Query(`PRAGMA index_info(` + definition.name + `)`)
indexRows, err := db.QueryContext(ctx, `PRAGMA index_info(`+definition.name+`)`)
if err != nil {
return errs.Wrap(err)
}

View File

@ -30,7 +30,7 @@ func TestQuery(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, &dbschema.Schema{}, emptySchema)
_, err = db.Exec(`
_, err = db.ExecContext(ctx, `
CREATE TABLE users (
a integer NOT NULL,
b integer NOT NULL,