storj/multinode/multinodedb/migrate_test.go

243 lines
6.4 KiB
Go
Raw Normal View History

// Copyright (C) 2021 Storj Labs, Inc.
// See LICENSE for copying information.
package multinodedb_test
import (
"context"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/zeebo/errs"
"go.uber.org/zap/zaptest"
"storj.io/common/testcontext"
"storj.io/private/dbutil/dbschema"
"storj.io/private/dbutil/pgtest"
"storj.io/private/dbutil/pgutil"
"storj.io/private/dbutil/sqliteutil"
"storj.io/private/dbutil/tempdb"
"storj.io/storj/multinode/multinodedb"
)
// TestMigrateSQLite3 runs every SQLite3 migration step against an in-memory
// database and compares the resulting schema and data with the snapshot of the
// same version loaded from testdata/sqlite3.*. After the last step the schema
// is also compared with the one produced by executing db.Schema().
func TestMigrateSQLite3(t *testing.T) {
	ctx := testcontext.NewWithTimeout(t, 8*time.Minute)
	defer ctx.Cleanup()

	log := zaptest.NewLogger(t)

	dbURL := "sqlite3://file::memory:"
	db, err := multinodedb.Open(ctx, log, dbURL)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, db.Close())
	}()

	// get snapshots
	// find all sqlite3 sql files
	matches, err := filepath.Glob("testdata/sqlite3.*")
	require.NoError(t, err)

	snapshots := new(dbschema.Snapshots)
	snapshots.List = make([]*dbschema.Snapshot, len(matches))

	for i, match := range matches {
		version := parseTestdataVersion(match, "sqlite3")
		// parseTestdataVersion reports failure as -1 and returns no error,
		// so only the offending file name is useful in the message.
		require.True(t, version >= 0, "invalid testdata file %q", match)

		scriptData, err := os.ReadFile(match)
		// require.NoError already prints err; the message only adds context.
		require.NoError(t, err, "could not read testdata file for version %d", version)

		// exec per snapshot??
		snapshot, err := sqliteutil.LoadSnapshotFromSQL(ctx, string(scriptData))
		require.NoError(t, err)
		snapshot.Version = version

		snapshots.List[i] = snapshot
	}

	snapshots.Sort()

	// get latest schema
	schema, err := sqliteutil.LoadSchemaFromSQL(ctx, db.Schema())
	require.NoError(t, err)

	var finalSchema *dbschema.Schema

	migration := db.SQLite3Migration()
	for i, step := range migration.Steps {
		tag := fmt.Sprintf("#%d - v%d", i, step.Version)

		// every migration step must have a matching snapshot file
		expected, ok := snapshots.FindVersion(step.Version)
		require.True(t, ok, "no snapshot found for step %s", tag)

		// migrate up to (and including) this step only
		err = migration.TargetVersion(step.Version).Run(ctx, log)
		require.NoError(t, err, tag)

		// insert the data the snapshot declares as new for this version
		if newData := expected.LookupSection(dbschema.NewData); newData != "" {
			_, err = db.ExecContext(ctx, newData)
			require.NoError(t, err, tag)
		}

		currentSchema, err := sqliteutil.QuerySchema(ctx, db)
		require.NoError(t, err, tag)

		// the versions table is dropped from the queried schema before
		// comparing it with the snapshot
		currentSchema.DropTable("versions")

		currentData, err := sqliteutil.QueryData(ctx, db, currentSchema)
		require.NoError(t, err, tag)

		require.Equal(t, expected.Schema, currentSchema, tag)
		require.Equal(t, expected.Data, currentData, tag)

		finalSchema = currentSchema
	}

	// verify that we also match the dbx version
	require.Equal(t, schema, finalSchema, "result of all migration scripts did not match dbx schema")
}
// TestMigratePostgres runs every Postgres migration step against a uniquely
// named temporary database and compares the resulting schema and data with the
// snapshot of the same version loaded from testdata/postgres.*. After the last
// step the schema is also compared with the one produced by executing
// db.Schema().
func TestMigratePostgres(t *testing.T) {
	ctx := testcontext.NewWithTimeout(t, 8*time.Minute)
	defer ctx.Cleanup()

	log := zaptest.NewLogger(t)

	connStr := pgtest.PickPostgres(t)

	// create tempDB
	tempDB, err := tempdb.OpenUnique(ctx, connStr, "migrate")
	require.NoError(t, err)
	defer func() {
		require.NoError(t, tempDB.Close())
	}()

	db, err := multinodedb.Open(ctx, log, tempDB.ConnStr)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, db.Close())
	}()

	// get snapshots
	// find all postgres sql files
	matches, err := filepath.Glob("testdata/postgres.*")
	require.NoError(t, err)

	snapshots := new(dbschema.Snapshots)
	snapshots.List = make([]*dbschema.Snapshot, len(matches))

	for i, match := range matches {
		version := parseTestdataVersion(match, "postgres")
		// parseTestdataVersion reports failure as -1 and returns no error,
		// so only the offending file name is useful in the message.
		require.True(t, version >= 0, "invalid testdata file %q", match)

		scriptData, err := os.ReadFile(match)
		// require.NoError already prints err; the message only adds context.
		require.NoError(t, err, "could not read testdata file for version %d", version)

		snapshot, err := loadSnapshotFromSQLPostgres(ctx, connStr, string(scriptData))
		require.NoError(t, err)
		snapshot.Version = version

		snapshots.List[i] = snapshot
	}

	snapshots.Sort()

	// get latest schema
	schema, err := loadSchemaFromSQLPostgres(ctx, connStr, db.Schema())
	require.NoError(t, err)

	var finalSchema *dbschema.Schema

	migration := db.PostgresMigration()
	for i, step := range migration.Steps {
		tag := fmt.Sprintf("#%d - v%d", i, step.Version)

		// every migration step must have a matching snapshot file
		expected, ok := snapshots.FindVersion(step.Version)
		require.True(t, ok, "no snapshot found for step %s", tag)

		// migrate up to (and including) this step only
		err = migration.TargetVersion(step.Version).Run(ctx, log)
		require.NoError(t, err, tag)

		// insert the data the snapshot declares as new for this version
		if newData := expected.LookupSection(dbschema.NewData); newData != "" {
			_, err = db.ExecContext(ctx, newData)
			require.NoError(t, err, tag)
		}

		currentSchema, err := pgutil.QuerySchema(ctx, db)
		require.NoError(t, err, tag)

		// the versions table is dropped from the queried schema before
		// comparing it with the snapshot
		currentSchema.DropTable("versions")

		currentData, err := pgutil.QueryData(ctx, db, currentSchema)
		require.NoError(t, err, tag)

		require.Equal(t, expected.Schema, currentSchema, tag)
		require.Equal(t, expected.Data, currentData, tag)

		finalSchema = currentSchema
	}

	// verify that we also match the dbx version
	require.Equal(t, schema, finalSchema, "result of all migration scripts did not match dbx schema")
}
// parseTestdataVersion extracts the numeric version from a testdata file path
// of the form "testdata/<impl>.v<N>.sql". The path is lower-cased and converted
// to forward slashes before the prefix and suffix are stripped; impl is used
// as given. It returns -1 when the remaining text is not a valid integer.
func parseTestdataVersion(path string, impl string) int {
	normalized := filepath.ToSlash(strings.ToLower(path))
	digits := strings.TrimSuffix(
		strings.TrimPrefix(normalized, "testdata/"+impl+".v"),
		".sql",
	)

	if version, convErr := strconv.Atoi(digits); convErr == nil {
		return version
	}
	return -1
}
// loadSnapshotFromSQLPostgres opens a unique temporary database on connstr,
// executes the Main, MainData and NewData sections of script against it (in
// that order) and returns the snapshot queried from the result, with the parsed
// sections attached. The temporary database is closed before returning; a close
// error is combined into the returned error.
func loadSnapshotFromSQLPostgres(ctx context.Context, connstr, script string) (_ *dbschema.Snapshot, err error) {
	tempDB, err := tempdb.OpenUnique(ctx, connstr, "load-schema")
	if err != nil {
		return nil, err
	}
	defer func() {
		err = errs.Combine(err, tempDB.Close())
	}()

	sections := dbschema.NewSections(script)

	// Each section is executed as its own statement batch; the first failure
	// aborts the load.
	scripts := []string{
		sections.LookupSection(dbschema.Main),
		sections.LookupSection(dbschema.MainData),
		sections.LookupSection(dbschema.NewData),
	}
	for _, sectionSQL := range scripts {
		if _, execErr := tempDB.ExecContext(ctx, sectionSQL); execErr != nil {
			return nil, execErr
		}
	}

	result, queryErr := pgutil.QuerySnapshot(ctx, tempDB)
	if queryErr != nil {
		return nil, queryErr
	}
	result.Sections = sections

	return result, nil
}
// loadSchemaFromSQLPostgres opens a unique temporary database on connstr,
// executes script against it and returns the schema queried from the result.
// The temporary database is closed before returning; a close error is combined
// into the returned error.
func loadSchemaFromSQLPostgres(ctx context.Context, connstr, script string) (_ *dbschema.Schema, err error) {
	tempDB, err := tempdb.OpenUnique(ctx, connstr, "load-schema")
	if err != nil {
		return nil, err
	}
	defer func() {
		err = errs.Combine(err, tempDB.Close())
	}()

	if _, execErr := tempDB.ExecContext(ctx, script); execErr != nil {
		return nil, execErr
	}

	return pgutil.QuerySchema(ctx, tempDB)
}