internal/dbutil/sqliteutil: Fix error handling, ensure connections are closed (#3078)
* internal/dbutil/sqliteutil: Fix error handling, ensure connections are closed * internal/dbutil/sqliteutil: Separate function to handle conn * internal/dbutil/sqliteutil: Fix names
This commit is contained in:
parent
53db517154
commit
8b358ef365
@ -12,121 +12,127 @@ import (
|
||||
"github.com/zeebo/errs"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrMigrateTables is error class for MigrateTables
|
||||
ErrMigrateTables = errs.Class("migrate tables:")
|
||||
|
||||
// ErrKeepTables is error class for MigrateTables
|
||||
ErrKeepTables = errs.Class("keep tables:")
|
||||
)
|
||||
|
||||
// MigrateTablesToDatabase copies the specified tables from srcDB into destDB.
|
||||
// All tables in destDB will be dropped other than those specified in
|
||||
// tablesToKeep.
|
||||
func MigrateTablesToDatabase(ctx context.Context, srcDB, destDB *sql.DB, tablesToKeep ...string) error {
|
||||
err := backupDBs(ctx, srcDB, destDB)
|
||||
if err != nil {
|
||||
return ErrMigrateTables.Wrap(err)
|
||||
}
|
||||
|
||||
// Remove tables we don't want to keep from the cloned destination database.
|
||||
return ErrMigrateTables.Wrap(KeepTables(ctx, destDB, tablesToKeep...))
|
||||
}
|
||||
|
||||
func backupDBs(ctx context.Context, srcDB, destDB *sql.DB) error {
|
||||
// Retrieve the raw Sqlite3 driver connections for the src and dest so that
|
||||
// we can execute the backup API for a corruption safe clone.
|
||||
srcConn, err := srcDB.Conn(ctx)
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrMigrateTables.Wrap(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = errs.Combine(err, ErrMigrateTables.Wrap(srcConn.Close()))
|
||||
}()
|
||||
|
||||
destConn, err := destDB.Conn(ctx)
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrMigrateTables.Wrap(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = errs.Combine(err, ErrMigrateTables.Wrap(destConn.Close()))
|
||||
}()
|
||||
|
||||
// The references to the driver connections are only guaranteed to be valid
|
||||
// for the life of the callback so we must do the work within both callbacks.
|
||||
err = srcConn.Raw(func(srcDriverConn interface{}) error {
|
||||
srcSqliteConn, ok := srcDriverConn.(*sqlite3.SQLiteConn)
|
||||
if !ok {
|
||||
return errs.New("unable to get database driver")
|
||||
return ErrMigrateTables.New("unable to get database driver")
|
||||
}
|
||||
|
||||
err = destConn.Raw(func(destDriverConn interface{}) error {
|
||||
err := destConn.Raw(func(destDriverConn interface{}) error {
|
||||
destSqliteConn, ok := destDriverConn.(*sqlite3.SQLiteConn)
|
||||
if !ok {
|
||||
return errs.New("unable to get database driver")
|
||||
return ErrMigrateTables.New("unable to get database driver")
|
||||
}
|
||||
|
||||
err = backup(ctx, srcSqliteConn, destSqliteConn)
|
||||
if err != nil {
|
||||
return errs.New("unable to backup database")
|
||||
}
|
||||
return nil
|
||||
return ErrMigrateTables.Wrap(backupConns(ctx, srcSqliteConn, destSqliteConn))
|
||||
})
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrMigrateTables.Wrap(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
|
||||
if err := srcConn.Close(); err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
if err := destConn.Close(); err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
|
||||
// Remove tables we don't want to keep from the cloned destination database.
|
||||
err = KeepTables(ctx, destDB, tablesToKeep...)
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
return ErrMigrateTables.Wrap(err)
|
||||
}
|
||||
|
||||
// backup executes the sqlite3 backup process that safely ensures that no other
|
||||
// backupConns executes the sqlite3 backup process that safely ensures that no other
|
||||
// connections to the database accidentally corrupt the source or destination.
|
||||
func backup(ctx context.Context, sourceDB *sqlite3.SQLiteConn, destDB *sqlite3.SQLiteConn) error {
|
||||
func backupConns(ctx context.Context, sourceDB *sqlite3.SQLiteConn, destDB *sqlite3.SQLiteConn) error {
|
||||
// "main" represents the main (ie not "temp") database in sqlite3, which is
|
||||
// the database we want to backup, and the appropriate dest in the destDB
|
||||
backup, err := destDB.Backup("main", sourceDB, "main")
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrMigrateTables.Wrap(err)
|
||||
}
|
||||
|
||||
isDone, err := backup.Step(0)
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrMigrateTables.Wrap(err)
|
||||
}
|
||||
if isDone {
|
||||
return errs.New("Backup is done")
|
||||
return ErrMigrateTables.New("Backup is done")
|
||||
}
|
||||
|
||||
// Check that the page count and remaining values are reasonable.
|
||||
initialPageCount := backup.PageCount()
|
||||
if initialPageCount <= 0 {
|
||||
return errs.New("initialPageCount invalid")
|
||||
return ErrMigrateTables.New("initialPageCount invalid")
|
||||
}
|
||||
initialRemaining := backup.Remaining()
|
||||
if initialRemaining <= 0 {
|
||||
return errs.New("initialRemaining invalid")
|
||||
return ErrMigrateTables.New("initialRemaining invalid")
|
||||
}
|
||||
if initialRemaining != initialPageCount {
|
||||
return errs.New("initialRemaining != initialPageCount")
|
||||
return ErrMigrateTables.New("initialRemaining != initialPageCount")
|
||||
}
|
||||
|
||||
// Step -1 is used to copy the entire source database to the destination.
|
||||
isDone, err = backup.Step(-1)
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrMigrateTables.Wrap(err)
|
||||
}
|
||||
if !isDone {
|
||||
return errs.New("Backup not done")
|
||||
return ErrMigrateTables.New("Backup not done")
|
||||
}
|
||||
|
||||
// Check that the page count and remaining values are reasonable.
|
||||
finalPageCount := backup.PageCount()
|
||||
if finalPageCount != initialPageCount {
|
||||
return errs.New("finalPageCount != initialPageCount")
|
||||
return ErrMigrateTables.New("finalPageCount != initialPageCount")
|
||||
}
|
||||
finalRemaining := backup.Remaining()
|
||||
if finalRemaining != 0 {
|
||||
return errs.New("finalRemaining invalid")
|
||||
return ErrMigrateTables.New("finalRemaining invalid")
|
||||
}
|
||||
|
||||
// Finish the backup.
|
||||
err = backup.Finish()
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrMigrateTables.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -136,7 +142,7 @@ func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) error {
|
||||
// Get a list of tables excluding sqlite3 system tables.
|
||||
rows, err := db.Query("SELECT name FROM sqlite_master WHERE type ='table' AND name NOT LIKE 'sqlite_%';")
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrKeepTables.Wrap(err)
|
||||
}
|
||||
|
||||
// Collect a list of the tables. We must do this because we can't do DDL
|
||||
@ -146,13 +152,13 @@ func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) error {
|
||||
var tableName string
|
||||
err = rows.Scan(&tableName)
|
||||
if err != nil {
|
||||
return errs.Combine(err, rows.Close())
|
||||
return errs.Combine(err, ErrKeepTables.Wrap(rows.Close()))
|
||||
}
|
||||
tables = append(tables, tableName)
|
||||
}
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrKeepTables.Wrap(err)
|
||||
}
|
||||
|
||||
// Loop over the list of tables and decide which ones to keep and which to drop.
|
||||
@ -161,7 +167,7 @@ func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) error {
|
||||
// Drop tables we aren't told to keep in the destination database.
|
||||
_, err = db.Exec(fmt.Sprintf("DROP TABLE %s;", tableName))
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrKeepTables.Wrap(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -170,7 +176,7 @@ func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) error {
|
||||
// data will not actually be reclaimed until the db has been closed.
|
||||
_, err = db.Exec("VACUUM;")
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
return ErrKeepTables.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -4,7 +4,6 @@
|
||||
package sqliteutil_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
@ -12,17 +11,17 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"storj.io/storj/internal/dbutil/sqliteutil"
|
||||
"storj.io/storj/internal/testcontext"
|
||||
)
|
||||
|
||||
func TestMigrateTablesToDatabase(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
srcDB := newMemDB(t)
|
||||
destDB := newMemDB(t)
|
||||
ctx := testcontext.New(t)
|
||||
defer ctx.Cleanup()
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, srcDB.Close())
|
||||
require.NoError(t, destDB.Close())
|
||||
}()
|
||||
srcDB := newMemDB(t)
|
||||
defer ctx.Check(srcDB.Close)
|
||||
destDB := newMemDB(t)
|
||||
defer ctx.Check(srcDB.Close)
|
||||
|
||||
query := `
|
||||
CREATE TABLE bobby_jones(I Int);
|
||||
@ -50,8 +49,11 @@ func TestMigrateTablesToDatabase(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestKeepTables(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx := testcontext.New(t)
|
||||
defer ctx.Cleanup()
|
||||
|
||||
db := newMemDB(t)
|
||||
defer ctx.Check(db.Close)
|
||||
|
||||
table1SQL := `
|
||||
CREATE TABLE table_one(I int);
|
||||
|
Loading…
Reference in New Issue
Block a user