2019-09-17 22:42:40 +01:00
|
|
|
// Copyright (C) 2019 Storj Labs, Inc.
|
|
|
|
// See LICENSE for copying information.
|
|
|
|
|
|
|
|
package sqliteutil
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"database/sql"
|
|
|
|
"fmt"
|
|
|
|
|
|
|
|
"github.com/mattn/go-sqlite3"
|
|
|
|
"github.com/zeebo/errs"
|
|
|
|
)
|
|
|
|
|
2019-09-19 22:21:03 +01:00
|
|
|
var (
|
|
|
|
// ErrMigrateTables is error class for MigrateTables
|
|
|
|
ErrMigrateTables = errs.Class("migrate tables:")
|
|
|
|
|
|
|
|
// ErrKeepTables is error class for MigrateTables
|
|
|
|
ErrKeepTables = errs.Class("keep tables:")
|
|
|
|
)
|
|
|
|
|
2019-09-17 22:42:40 +01:00
|
|
|
// MigrateTablesToDatabase copies the specified tables from srcDB into destDB.
|
|
|
|
// All tables in destDB will be dropped other than those specified in
|
|
|
|
// tablesToKeep.
|
|
|
|
func MigrateTablesToDatabase(ctx context.Context, srcDB, destDB *sql.DB, tablesToKeep ...string) error {
|
2019-09-19 22:21:03 +01:00
|
|
|
err := backupDBs(ctx, srcDB, destDB)
|
|
|
|
if err != nil {
|
|
|
|
return ErrMigrateTables.Wrap(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Remove tables we don't want to keep from the cloned destination database.
|
|
|
|
return ErrMigrateTables.Wrap(KeepTables(ctx, destDB, tablesToKeep...))
|
|
|
|
}
|
|
|
|
|
|
|
|
func backupDBs(ctx context.Context, srcDB, destDB *sql.DB) error {
|
2019-09-17 22:42:40 +01:00
|
|
|
// Retrieve the raw Sqlite3 driver connections for the src and dest so that
|
|
|
|
// we can execute the backup API for a corruption safe clone.
|
|
|
|
srcConn, err := srcDB.Conn(ctx)
|
|
|
|
if err != nil {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.Wrap(err)
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
2019-09-19 22:21:03 +01:00
|
|
|
defer func() {
|
|
|
|
err = errs.Combine(err, ErrMigrateTables.Wrap(srcConn.Close()))
|
|
|
|
}()
|
|
|
|
|
2019-09-17 22:42:40 +01:00
|
|
|
destConn, err := destDB.Conn(ctx)
|
|
|
|
if err != nil {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.Wrap(err)
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
2019-09-19 22:21:03 +01:00
|
|
|
defer func() {
|
|
|
|
err = errs.Combine(err, ErrMigrateTables.Wrap(destConn.Close()))
|
|
|
|
}()
|
|
|
|
|
2019-09-17 22:42:40 +01:00
|
|
|
// The references to the driver connections are only guaranteed to be valid
|
|
|
|
// for the life of the callback so we must do the work within both callbacks.
|
|
|
|
err = srcConn.Raw(func(srcDriverConn interface{}) error {
|
|
|
|
srcSqliteConn, ok := srcDriverConn.(*sqlite3.SQLiteConn)
|
|
|
|
if !ok {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.New("unable to get database driver")
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
2019-09-19 22:21:03 +01:00
|
|
|
err := destConn.Raw(func(destDriverConn interface{}) error {
|
2019-09-17 22:42:40 +01:00
|
|
|
destSqliteConn, ok := destDriverConn.(*sqlite3.SQLiteConn)
|
|
|
|
if !ok {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.New("unable to get database driver")
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.Wrap(backupConns(ctx, srcSqliteConn, destSqliteConn))
|
2019-09-17 22:42:40 +01:00
|
|
|
})
|
|
|
|
if err != nil {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.Wrap(err)
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
})
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.Wrap(err)
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
2019-09-19 22:21:03 +01:00
|
|
|
// backupConns executes the sqlite3 backup process that safely ensures that no other
|
2019-09-17 22:42:40 +01:00
|
|
|
// connections to the database accidentally corrupt the source or destination.
|
2019-09-19 22:21:03 +01:00
|
|
|
func backupConns(ctx context.Context, sourceDB *sqlite3.SQLiteConn, destDB *sqlite3.SQLiteConn) error {
|
2019-09-17 22:42:40 +01:00
|
|
|
// "main" represents the main (ie not "temp") database in sqlite3, which is
|
|
|
|
// the database we want to backup, and the appropriate dest in the destDB
|
|
|
|
backup, err := destDB.Backup("main", sourceDB, "main")
|
|
|
|
if err != nil {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.Wrap(err)
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
isDone, err := backup.Step(0)
|
|
|
|
if err != nil {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.Wrap(err)
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
if isDone {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.New("Backup is done")
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Check that the page count and remaining values are reasonable.
|
|
|
|
initialPageCount := backup.PageCount()
|
|
|
|
if initialPageCount <= 0 {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.New("initialPageCount invalid")
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
initialRemaining := backup.Remaining()
|
|
|
|
if initialRemaining <= 0 {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.New("initialRemaining invalid")
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
if initialRemaining != initialPageCount {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.New("initialRemaining != initialPageCount")
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Step -1 is used to copy the entire source database to the destination.
|
|
|
|
isDone, err = backup.Step(-1)
|
|
|
|
if err != nil {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.Wrap(err)
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
if !isDone {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.New("Backup not done")
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Check that the page count and remaining values are reasonable.
|
|
|
|
finalPageCount := backup.PageCount()
|
|
|
|
if finalPageCount != initialPageCount {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.New("finalPageCount != initialPageCount")
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
finalRemaining := backup.Remaining()
|
|
|
|
if finalRemaining != 0 {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.New("finalRemaining invalid")
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Finish the backup.
|
|
|
|
err = backup.Finish()
|
|
|
|
if err != nil {
|
2019-09-19 22:21:03 +01:00
|
|
|
return ErrMigrateTables.Wrap(err)
|
2019-09-17 22:42:40 +01:00
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// KeepTables drops all the tables except the specified tables to keep.
|
2019-09-23 20:36:46 +01:00
|
|
|
func KeepTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) (err error) {
|
|
|
|
err = dropTables(ctx, db, tablesToKeep...)
|
|
|
|
if err != nil {
|
|
|
|
return ErrKeepTables.Wrap(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// VACUUM the database to reclaim the space used by the dropped tables. The
|
|
|
|
// data will not actually be reclaimed until the db has been closed.
|
|
|
|
// We don't include this in the above transaction because
|
|
|
|
// you can't VACUUM within a transaction with SQLite3.
|
|
|
|
_, err = db.Exec("VACUUM;")
|
|
|
|
if err != nil {
|
|
|
|
return ErrKeepTables.Wrap(err)
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// dropTables drops, in a single transaction, every user table in db whose
// name is not listed in tablesToKeep.
//
// The transaction is committed if every step succeeds and rolled back
// otherwise. Tables whose names begin with "sqlite_" are SQLite-internal and
// are never listed, so they are never dropped. Errors are wrapped in
// ErrKeepTables.
func dropTables(ctx context.Context, db *sql.DB, tablesToKeep ...string) (err error) {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return ErrKeepTables.Wrap(err)
	}

	// err is a named result: roll back on any failure, otherwise commit and
	// report the commit error.
	defer func() {
		if err != nil {
			err = ErrKeepTables.Wrap(errs.Combine(err, tx.Rollback()))
		} else {
			err = ErrKeepTables.Wrap(tx.Commit())
		}
	}()

	// Get a list of tables excluding sqlite3 system tables. In LIKE, '_' is a
	// single-character wildcard, so it is escaped to match only the literal
	// "sqlite_" prefix. Otherwise user tables like "sqlitefoo" would be skipped.
	rows, err := tx.QueryContext(ctx, `SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\';`)
	if err != nil {
		return ErrKeepTables.Wrap(err)
	}

	// Collect a list of the tables. We must do this because we can't do DDL
	// statements like drop tables while a query result is open.
	var tables []string
	for rows.Next() {
		var tableName string
		if err = rows.Scan(&tableName); err != nil {
			return errs.Combine(err, ErrKeepTables.Wrap(rows.Close()))
		}
		tables = append(tables, tableName)
	}
	// rows.Err reports an error that ended the iteration early. Without this
	// check, a partial table list would silently be treated as complete.
	if err = rows.Err(); err != nil {
		return errs.Combine(ErrKeepTables.Wrap(err), ErrKeepTables.Wrap(rows.Close()))
	}
	if err = rows.Close(); err != nil {
		return ErrKeepTables.Wrap(err)
	}

	// Drop every table we aren't told to keep.
	for _, tableName := range tables {
		if tableToKeep(tableName, tablesToKeep) {
			continue
		}
		// Quote the name so that tables whose names contain spaces, quotes or
		// SQL keywords are dropped correctly.
		_, err = tx.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s;", quoteIdentifier(tableName)))
		if err != nil {
			return ErrKeepTables.Wrap(err)
		}
	}

	return nil
}

// quoteIdentifier returns name as a double-quoted SQL identifier, doubling any
// embedded double quote as SQL (and SQLite) requires.
func quoteIdentifier(name string) string {
	quoted := make([]byte, 0, len(name)+2)
	quoted = append(quoted, '"')
	for i := 0; i < len(name); i++ {
		if name[i] == '"' {
			quoted = append(quoted, '"')
		}
		quoted = append(quoted, name[i])
	}
	quoted = append(quoted, '"')
	return string(quoted)
}
|
|
|
|
|
|
|
|
// tableToKeep reports whether table appears in tables. The comparison is an
// exact, case-sensitive string match.
func tableToKeep(table string, tables []string) bool {
	for i := range tables {
		if tables[i] != table {
			continue
		}
		return true
	}
	return false
}
|