From 8b358ef365c45bc9124daf81882146e42966237e Mon Sep 17 00:00:00 2001 From: Isaac Hess Date: Thu, 19 Sep 2019 15:21:03 -0600 Subject: [PATCH] internal/dbutil/sqliteutil: Fix error handling, ensure connections are closed (#3078) * internal/dbutil/sqliteutil: Fix error handling, ensure connections are closed * internal/dbutil/sqliteutil: Separate function to handle conn * internal/dbutil/sqliteutil: Fix names --- internal/dbutil/sqliteutil/migrator.go | 98 +++++++++++---------- internal/dbutil/sqliteutil/migrator_test.go | 20 +++-- 2 files changed, 63 insertions(+), 55 deletions(-) diff --git a/internal/dbutil/sqliteutil/migrator.go b/internal/dbutil/sqliteutil/migrator.go index 1a6e9b9ae..edf59363e 100644 --- a/internal/dbutil/sqliteutil/migrator.go +++ b/internal/dbutil/sqliteutil/migrator.go @@ -12,121 +12,127 @@ import ( "github.com/zeebo/errs" ) +var ( + // ErrMigrateTables is error class for MigrateTables + ErrMigrateTables = errs.Class("migrate tables:") + + // ErrKeepTables is error class for MigrateTables + ErrKeepTables = errs.Class("keep tables:") +) + // MigrateTablesToDatabase copies the specified tables from srcDB into destDB. // All tables in destDB will be dropped other than those specified in // tablesToKeep. func MigrateTablesToDatabase(ctx context.Context, srcDB, destDB *sql.DB, tablesToKeep ...string) error { + err := backupDBs(ctx, srcDB, destDB) + if err != nil { + return ErrMigrateTables.Wrap(err) + } + + // Remove tables we don't want to keep from the cloned destination database. + return ErrMigrateTables.Wrap(KeepTables(ctx, destDB, tablesToKeep...)) +} + +func backupDBs(ctx context.Context, srcDB, destDB *sql.DB) error { // Retrieve the raw Sqlite3 driver connections for the src and dest so that // we can execute the backup API for a corruption safe clone. srcConn, err := srcDB.Conn(ctx) if err != nil { - return errs.Wrap(err) + return ErrMigrateTables.Wrap(err) } + defer func() { + err = errs.Combine(err, ErrMigrateTables.Wrap(srcConn.Close())) + }() + destConn, err := destDB.Conn(ctx) if err != nil { - return errs.Wrap(err) + return ErrMigrateTables.Wrap(err) } + defer func() { + err = errs.Combine(err, ErrMigrateTables.Wrap(destConn.Close())) + }() + // The references to the driver connections are only guaranteed to be valid // for the life of the callback so we must do the work within both callbacks. err = srcConn.Raw(func(srcDriverConn interface{}) error { srcSqliteConn, ok := srcDriverConn.(*sqlite3.SQLiteConn) if !ok { - return errs.New("unable to get database driver") + return ErrMigrateTables.New("unable to get database driver") } - err = destConn.Raw(func(destDriverConn interface{}) error { + err := destConn.Raw(func(destDriverConn interface{}) error { destSqliteConn, ok := destDriverConn.(*sqlite3.SQLiteConn) if !ok { - return errs.New("unable to get database driver") + return ErrMigrateTables.New("unable to get database driver") } - err = backup(ctx, srcSqliteConn, destSqliteConn) - if err != nil { - return errs.New("unable to backup database") - } - return nil + return ErrMigrateTables.Wrap(backupConns(ctx, srcSqliteConn, destSqliteConn)) }) if err != nil { - return errs.Wrap(err) + return ErrMigrateTables.Wrap(err) } return nil }) - if err != nil { - return errs.Wrap(err) - } - - if err := srcConn.Close(); err != nil { - return errs.Wrap(err) - } - if err := destConn.Close(); err != nil { - return errs.Wrap(err) - } - - // Remove tables we don't want to keep from the cloned destination database. - err = KeepTables(ctx, destDB, tablesToKeep...) - if err != nil { - return errs.Wrap(err) - } - return nil + return ErrMigrateTables.Wrap(err) } -// backup executes the sqlite3 backup process that safely ensures that no other +// backupConns executes the sqlite3 backup process that safely ensures that no other // connections to the database accidentally corrupt the source or destination. -func backup(ctx context.Context, sourceDB *sqlite3.SQLiteConn, destDB *sqlite3.SQLiteConn) error { +func backupConns(ctx context.Context, sourceDB *sqlite3.SQLiteConn, destDB *sqlite3.SQLiteConn) error { // "main" represents the main (ie not "temp") database in sqlite3, which is // the database we want to backup, and the appropriate dest in the destDB backup, err := destDB.Backup("main", sourceDB, "main") if err != nil { - return errs.Wrap(err) + return ErrMigrateTables.Wrap(err) } isDone, err := backup.Step(0) if err != nil { - return errs.Wrap(err) + return ErrMigrateTables.Wrap(err) } if isDone { - return errs.New("Backup is done") + return ErrMigrateTables.New("Backup is done") } // Check that the page count and remaining values are reasonable. initialPageCount := backup.PageCount() if initialPageCount <= 0 { - return errs.New("initialPageCount invalid") + return ErrMigrateTables.New("initialPageCount invalid") } initialRemaining := backup.Remaining() if initialRemaining <= 0 { - return errs.New("initialRemaining invalid") + return ErrMigrateTables.New("initialRemaining invalid") } if initialRemaining != initialPageCount { - return errs.New("initialRemaining != initialPageCount") + return ErrMigrateTables.New("initialRemaining != initialPageCount") } // Step -1 is used to copy the entire source database to the destination. isDone, err = backup.Step(-1) if err != nil { - return errs.Wrap(err) + return ErrMigrateTables.Wrap(err) } if !isDone { - return errs.New("Backup not done") + return ErrMigrateTables.New("Backup not done") } // Check that the page count and remaining values are reasonable. finalPageCount := backup.PageCount() if finalPageCount != initialPageCount { - return errs.New("finalPageCount != initialPageCount") + return ErrMigrateTables.New("finalPageCount != initialPageCount") } finalRemaining := backup.Remaining() if finalRemaining != 0 { - return errs.New("finalRemaining invalid") + return ErrMigrateTables.New("finalRemaining invalid") } // Finish the backup. err = backup.Finish() if err != nil { - return errs.Wrap(err) + return ErrMigrateTables.Wrap(err) } return nil } @@ -136,7 +142,7 @@ func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) error { // Get a list of tables excluding sqlite3 system tables. rows, err := db.Query("SELECT name FROM sqlite_master WHERE type ='table' AND name NOT LIKE 'sqlite_%';") if err != nil { - return errs.Wrap(err) + return ErrKeepTables.Wrap(err) } // Collect a list of the tables. We must do this because we can't do DDL @@ -146,13 +152,13 @@ func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) error { var tableName string err = rows.Scan(&tableName) if err != nil { - return errs.Combine(err, rows.Close()) + return errs.Combine(err, ErrKeepTables.Wrap(rows.Close())) } tables = append(tables, tableName) } err = rows.Close() if err != nil { - return errs.Wrap(err) + return ErrKeepTables.Wrap(err) } // Loop over the list of tables and decide which ones to keep and which to drop. @@ -161,7 +167,7 @@ func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) error { // Drop tables we aren't told to keep in the destination database. _, err = db.Exec(fmt.Sprintf("DROP TABLE %s;", tableName)) if err != nil { - return errs.Wrap(err) + return ErrKeepTables.Wrap(err) } } } @@ -170,7 +176,7 @@ func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) error { // data will not actually be reclaimed until the db has been closed. _, err = db.Exec("VACUUM;") if err != nil { - return errs.Wrap(err) + return ErrKeepTables.Wrap(err) } return nil } diff --git a/internal/dbutil/sqliteutil/migrator_test.go b/internal/dbutil/sqliteutil/migrator_test.go index 0a448edd8..f548fc421 100644 --- a/internal/dbutil/sqliteutil/migrator_test.go +++ b/internal/dbutil/sqliteutil/migrator_test.go @@ -4,7 +4,6 @@ package sqliteutil_test import ( - "context" "database/sql" "testing" @@ -12,17 +11,17 @@ import ( "github.com/stretchr/testify/require" "storj.io/storj/internal/dbutil/sqliteutil" + "storj.io/storj/internal/testcontext" ) func TestMigrateTablesToDatabase(t *testing.T) { - ctx := context.Background() - srcDB := newMemDB(t) - destDB := newMemDB(t) + ctx := testcontext.New(t) + defer ctx.Cleanup() - defer func() { - require.NoError(t, srcDB.Close()) - require.NoError(t, destDB.Close()) - }() + srcDB := newMemDB(t) + defer ctx.Check(srcDB.Close) + destDB := newMemDB(t) + defer ctx.Check(srcDB.Close) query := ` CREATE TABLE bobby_jones(I Int); @@ -50,8 +49,11 @@ func TestMigrateTablesToDatabase(t *testing.T) { } func TestKeepTables(t *testing.T) { - ctx := context.Background() + ctx := testcontext.New(t) + defer ctx.Cleanup() + db := newMemDB(t) + defer ctx.Check(db.Close) table1SQL := ` CREATE TABLE table_one(I int);