internal/dbutil/sqliteutil: add MigrateTablesToDatabase (#3064)

* internal/dbutil/sqliteutil: add migrator

* internal/dbutil/sqliteutil: Fix errors and tablename
This commit is contained in:
Isaac Hess 2019-09-17 15:42:40 -06:00 committed by GitHub
parent 7c203b4884
commit fd20fa38c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 279 additions and 0 deletions

View File

@ -0,0 +1,185 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package sqliteutil
import (
"context"
"database/sql"
"fmt"
"github.com/mattn/go-sqlite3"
"github.com/zeebo/errs"
)
// MigrateTablesToDatabase copies the specified tables from srcDB into destDB.
// All tables in destDB will be dropped other than those specified in
// tablesToKeep.
func MigrateTablesToDatabase(ctx context.Context, srcDB, destDB *sql.DB, tablesToKeep ...string) error {
// Retrieve the raw Sqlite3 driver connections for the src and dest so that
// we can execute the backup API for a corruption safe clone.
srcConn, err := srcDB.Conn(ctx)
if err != nil {
return errs.Wrap(err)
}
destConn, err := destDB.Conn(ctx)
if err != nil {
return errs.Wrap(err)
}
// The references to the driver connections are only guaranteed to be valid
// for the life of the callback so we must do the work within both callbacks.
err = srcConn.Raw(func(srcDriverConn interface{}) error {
srcSqliteConn, ok := srcDriverConn.(*sqlite3.SQLiteConn)
if !ok {
return errs.New("unable to get database driver")
}
err = destConn.Raw(func(destDriverConn interface{}) error {
destSqliteConn, ok := destDriverConn.(*sqlite3.SQLiteConn)
if !ok {
return errs.New("unable to get database driver")
}
err = backup(ctx, srcSqliteConn, destSqliteConn)
if err != nil {
return errs.New("unable to backup database")
}
return nil
})
if err != nil {
return errs.Wrap(err)
}
return nil
})
if err != nil {
return errs.Wrap(err)
}
if err := srcConn.Close(); err != nil {
return errs.Wrap(err)
}
if err := destConn.Close(); err != nil {
return errs.Wrap(err)
}
// Remove tables we don't want to keep from the cloned destination database.
err = KeepTables(ctx, destDB, tablesToKeep...)
if err != nil {
return errs.Wrap(err)
}
return nil
}
// backup executes the sqlite3 backup process that safely ensures that no other
// connections to the database accidentally corrupt the source or destination.
func backup(ctx context.Context, sourceDB *sqlite3.SQLiteConn, destDB *sqlite3.SQLiteConn) error {
// "main" represents the main (ie not "temp") database in sqlite3, which is
// the database we want to backup, and the appropriate dest in the destDB
backup, err := destDB.Backup("main", sourceDB, "main")
if err != nil {
return errs.Wrap(err)
}
isDone, err := backup.Step(0)
if err != nil {
return errs.Wrap(err)
}
if isDone {
return errs.New("Backup is done")
}
// Check that the page count and remaining values are reasonable.
initialPageCount := backup.PageCount()
if initialPageCount <= 0 {
return errs.New("initialPageCount invalid")
}
initialRemaining := backup.Remaining()
if initialRemaining <= 0 {
return errs.New("initialRemaining invalid")
}
if initialRemaining != initialPageCount {
return errs.New("initialRemaining != initialPageCount")
}
// Step -1 is used to copy the entire source database to the destination.
isDone, err = backup.Step(-1)
if err != nil {
return errs.Wrap(err)
}
if !isDone {
return errs.New("Backup not done")
}
// Check that the page count and remaining values are reasonable.
finalPageCount := backup.PageCount()
if finalPageCount != initialPageCount {
return errs.New("finalPageCount != initialPageCount")
}
finalRemaining := backup.Remaining()
if finalRemaining != 0 {
return errs.New("finalRemaining invalid")
}
// Finish the backup.
err = backup.Finish()
if err != nil {
return errs.Wrap(err)
}
return nil
}
// KeepTables drops all the tables except the specified tables to keep.
func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) error {
// Get a list of tables excluding sqlite3 system tables.
rows, err := db.Query("SELECT name FROM sqlite_master WHERE type ='table' AND name NOT LIKE 'sqlite_%';")
if err != nil {
return errs.Wrap(err)
}
// Collect a list of the tables. We must do this because we can't do DDL
// statements like drop tables while a query result is open.
var tables []string
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return errs.Combine(err, rows.Close())
}
tables = append(tables, tableName)
}
err = rows.Close()
if err != nil {
return errs.Wrap(err)
}
// Loop over the list of tables and decide which ones to keep and which to drop.
for _, tableName := range tables {
if !tableToKeep(tableName, tablesToKeep) {
// Drop tables we aren't told to keep in the destination database.
_, err = db.Exec(fmt.Sprintf("DROP TABLE %s;", tableName))
if err != nil {
return errs.Wrap(err)
}
}
}
// VACUUM the database to reclaim the space used by the dropped tables. The
// data will not actually be reclaimed until the db has been closed.
_, err = db.Exec("VACUUM;")
if err != nil {
return errs.Wrap(err)
}
return nil
}
func tableToKeep(table string, tables []string) bool {
for _, t := range tables {
if t == table {
return true
}
}
return false
}

View File

@ -0,0 +1,94 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package sqliteutil_test
import (
"context"
"database/sql"
"testing"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
"storj.io/storj/internal/dbutil/sqliteutil"
)
func TestMigrateTablesToDatabase(t *testing.T) {
ctx := context.Background()
srcDB := newMemDB(t)
destDB := newMemDB(t)
defer func() {
require.NoError(t, srcDB.Close())
require.NoError(t, destDB.Close())
}()
query := `
CREATE TABLE bobby_jones(I Int);
INSERT INTO bobby_jones VALUES (1);
`
execSQL(t, srcDB, query)
// This table should be removed after migration
execSQL(t, srcDB, "CREATE TABLE what(I Int);")
err := sqliteutil.MigrateTablesToDatabase(ctx, srcDB, destDB, "bobby_jones")
require.NoError(t, err)
destSchema, err := sqliteutil.QuerySchema(destDB)
require.NoError(t, err)
destData, err := sqliteutil.QueryData(destDB, destSchema)
require.NoError(t, err)
snapshot, err := sqliteutil.LoadSnapshotFromSQL(query)
require.NoError(t, err)
require.Equal(t, snapshot.Schema, destSchema)
require.Equal(t, snapshot.Data, destData)
}
func TestKeepTables(t *testing.T) {
ctx := context.Background()
db := newMemDB(t)
table1SQL := `
CREATE TABLE table_one(I int);
INSERT INTO table_one VALUES(1);
`
table2SQL := `
CREATE TABLE table_two(I int);
INSERT INTO table_two VALUES(2);
`
execSQL(t, db, table1SQL)
execSQL(t, db, table2SQL)
err := sqliteutil.KeepTables(ctx, db, "table_one")
require.NoError(t, err)
schema, err := sqliteutil.QuerySchema(db)
require.NoError(t, err)
data, err := sqliteutil.QueryData(db, schema)
require.NoError(t, err)
snapshot, err := sqliteutil.LoadSnapshotFromSQL(table1SQL)
require.NoError(t, err)
require.Equal(t, snapshot.Schema, schema)
require.Equal(t, snapshot.Data, data)
}
func execSQL(t *testing.T, db *sql.DB, query string, args ...interface{}) {
_, err := db.Exec(query, args...)
require.NoError(t, err)
}
func newMemDB(t *testing.T) *sql.DB {
db, err := sql.Open("sqlite3", ":memory:")
require.NoError(t, err)
return db
}