create dbutil package for migration testing (#1305)

This commit is contained in:
Egon Elbre 2019-02-13 18:06:34 +02:00 committed by GitHub
parent 2c5716e874
commit 497fb756fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 442 additions and 45 deletions

View File

@ -0,0 +1,115 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// Package dbschema package implements querying and comparing schemas for testing.
package dbschema
import (
"sort"
)
// Schema is the database structure.
type Schema struct {
Tables []*Table
Indexes []*Index
}
// Table is a sql table.
type Table struct {
Name string
Columns []*Column
PrimaryKey []string
Unique [][]string
}
// Column is a sql column.
type Column struct {
Name string
Type string
IsNullable bool
Reference *Reference
}
// Reference is a column foreign key.
type Reference struct {
Table string
Column string
OnDelete string
OnUpdate string
}
// Index is an index for a table.
type Index struct {
Name string
Table string
Columns []string
Unique bool
}
// EnsureTable returns the table with the specified name and creates one if needed.
func (schema *Schema) EnsureTable(tableName string) *Table {
for _, table := range schema.Tables {
if table.Name == tableName {
return table
}
}
table := &Table{Name: tableName}
schema.Tables = append(schema.Tables, table)
return table
}
// AddColumn adds the column to the table.
func (table *Table) AddColumn(column *Column) {
table.Columns = append(table.Columns, column)
}
// FindColumn finds a column in the table
func (table *Table) FindColumn(columnName string) (*Column, bool) {
for _, column := range table.Columns {
if column.Name == columnName {
return column, true
}
}
return nil, false
}
// Sort sorts tables
func (schema *Schema) Sort() {
sort.Slice(schema.Tables, func(i, k int) bool {
return schema.Tables[i].Name < schema.Tables[k].Name
})
for _, table := range schema.Tables {
table.Sort()
}
}
// Sort sorts columns, primary keys and unique
func (table *Table) Sort() {
sort.Slice(table.Columns, func(i, k int) bool {
return table.Columns[i].Name < table.Columns[k].Name
})
sort.Strings(table.PrimaryKey)
for i := range table.Unique {
sort.Strings(table.Unique[i])
}
sort.Slice(table.Unique, func(i, k int) bool {
return lessStrings(table.Unique[i], table.Unique[k])
})
}
func lessStrings(a, b []string) bool {
n := len(a)
if len(b) < n {
n = len(b)
}
for k := 0; k < n; k++ {
if a[k] < b[k] {
return true
} else if a[k] > b[k] {
return false
}
}
return len(a) < len(b)
}

View File

@ -0,0 +1,145 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package pgutil
import (
"database/sql"
"fmt"
"regexp"
"github.com/lib/pq"
"github.com/zeebo/errs"
"storj.io/storj/internal/dbutil/dbschema"
)
// Queryer is a representation for something that can query.
type Queryer interface {
// Query executes a query that returns rows, typically a SELECT.
Query(query string, args ...interface{}) (*sql.Rows, error)
}
// QuerySchema loads the schema from postgres database.
func QuerySchema(tx Queryer) (*dbschema.Schema, error) {
schema := &dbschema.Schema{}
// find tables
err := func() error {
rows, err := tx.Query(`
SELECT table_name, column_name, is_nullable, data_type
FROM information_schema.columns
WHERE table_schema = CURRENT_SCHEMA
`)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, rows.Close()) }()
for rows.Next() {
var tableName, columnName, isNullable, dataType string
err := rows.Scan(&tableName, &columnName, &isNullable, &dataType)
if err != nil {
return err
}
table := schema.EnsureTable(tableName)
table.AddColumn(&dbschema.Column{
Name: columnName,
Type: dataType,
IsNullable: isNullable == "YES",
})
}
return rows.Err()
}()
if err != nil {
return schema, err
}
// find constraints
err = func() error {
rows, err := tx.Query(`
SELECT pg_class.relname AS table_name,
pg_constraint.conname AS constraint_name,
pg_constraint.contype AS constraint_type,
ARRAY_AGG(pg_attribute.attname ORDER BY u.attposition) AS columns,
pg_get_constraintdef(pg_constraint.oid) AS definition
FROM pg_constraint pg_constraint
JOIN LATERAL UNNEST(pg_constraint.conkey) WITH ORDINALITY AS u(attnum, attposition) ON TRUE
JOIN pg_class ON pg_class.oid = pg_constraint.conrelid
JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
JOIN pg_attribute ON (pg_attribute.attrelid = pg_class.oid AND pg_attribute.attnum = u.attnum)
WHERE pg_namespace.nspname = CURRENT_SCHEMA
GROUP BY constraint_name, constraint_type, table_name, definition
`)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, rows.Close()) }()
for rows.Next() {
var tableName, constraintName, constraintType string
var columns pq.StringArray
var definition string
err := rows.Scan(&tableName, &constraintName, &constraintType, &columns, &definition)
if err != nil {
return err
}
switch constraintType {
case "p": // primary key
table := schema.EnsureTable(tableName)
table.PrimaryKey = ([]string)(columns)
case "f": // foreign key
if len(columns) != 1 {
return fmt.Errorf("expected one column, got: %q", columns)
}
table := schema.EnsureTable(tableName)
column, ok := table.FindColumn(columns[0])
if !ok {
return fmt.Errorf("did not find column %q", columns[0])
}
matches := rxPostgresForeignKey.FindStringSubmatch(definition)
if len(matches) == 0 {
return fmt.Errorf("unable to parse constraint %q", definition)
}
column.Reference = &dbschema.Reference{
Table: matches[1],
Column: matches[2],
OnUpdate: matches[3],
OnDelete: matches[4],
}
case "u": // unique
table := schema.EnsureTable(tableName)
table.Unique = append(table.Unique, columns)
default:
return fmt.Errorf("unhandled constraint type %q", constraintType)
}
if err != nil {
return err
}
}
return rows.Err()
}()
if err != nil {
return nil, err
}
// TODO: find indexes
return schema, nil
}
// matches FOREIGN KEY (project_id) REFERENCES projects(id) ON UPDATE CASCADE ON DELETE CASCADE
var rxPostgresForeignKey = regexp.MustCompile(
`^FOREIGN KEY \([[:word:]]+\) ` +
`REFERENCES ([[:word:]]+)\(([[:word:]]+)\)` +
`(?:\s*ON UPDATE ([[:word:]]+))?` +
`(?:\s*ON DELETE ([[:word:]]+))?$`,
)

View File

@ -0,0 +1,118 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package pgutil_test
import (
"database/sql"
"flag"
"os"
"testing"
_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"storj.io/storj/internal/dbutil/dbschema"
"storj.io/storj/internal/dbutil/pgutil"
"storj.io/storj/internal/testcontext"
)
const (
// DefaultPostgresConn is a connstring that works with docker-compose
DefaultPostgresConn = "postgres://storj:storj-pass@test-postgres/teststorj?sslmode=disable"
)
var (
// TestPostgres is flag for the postgres test database
TestPostgres = flag.String("postgres-test-db", os.Getenv("STORJ_POSTGRES_TEST"), "PostgreSQL test database connection string")
)
func TestQuery(t *testing.T) {
if *TestPostgres == "" {
t.Skip("Postgres flag missing, example: -postgres-test-db=" + DefaultPostgresConn)
}
ctx := testcontext.New(t)
defer ctx.Cleanup()
schemaName := "pgutil-query-" + pgutil.RandomString(8)
connstr := pgutil.ConnstrWithSchema(*TestPostgres, schemaName)
db, err := sql.Open("postgres", connstr)
require.NoError(t, err)
defer ctx.Check(db.Close)
require.NoError(t, pgutil.CreateSchema(db, schemaName))
defer func() {
require.NoError(t, pgutil.DropSchema(db, schemaName))
}()
emptySchema, err := pgutil.QuerySchema(db)
assert.NoError(t, err)
assert.Equal(t, &dbschema.Schema{}, emptySchema)
_, err = db.Exec(`
CREATE TABLE users (
a integer NOT NULL,
b integer NOT NULL,
c text,
UNIQUE (c),
PRIMARY KEY (a)
);
CREATE TABLE names (
users_a integer REFERENCES users( a ) ON DELETE CASCADE,
a text NOT NULL,
x text,
b text,
PRIMARY KEY (a, x),
UNIQUE ( x ),
UNIQUE ( a, b )
);
`)
require.NoError(t, err)
schema, err := pgutil.QuerySchema(db)
assert.NoError(t, err)
expected := &dbschema.Schema{
Tables: []*dbschema.Table{
{
Name: "users",
Columns: []*dbschema.Column{
{Name: "a", Type: "integer", IsNullable: false, Reference: nil},
{Name: "b", Type: "integer", IsNullable: false, Reference: nil},
{Name: "c", Type: "text", IsNullable: true, Reference: nil},
},
PrimaryKey: []string{"a"},
Unique: [][]string{
{"c"},
},
},
{
Name: "names",
Columns: []*dbschema.Column{
{Name: "users_a", Type: "integer", IsNullable: true,
Reference: &dbschema.Reference{
Table: "users",
Column: "a",
OnDelete: "CASCADE",
}},
{Name: "a", Type: "text", IsNullable: false, Reference: nil},
{Name: "x", Type: "text", IsNullable: false, Reference: nil}, // not null, because primary key
{Name: "b", Type: "text", IsNullable: true, Reference: nil},
},
PrimaryKey: []string{"a", "x"},
Unique: [][]string{
{"a", "b"},
{"x"},
},
},
},
}
expected.Sort()
schema.Sort()
assert.Equal(t, expected, schema)
}

View File

@ -0,0 +1,49 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// Package pgutil contains utilities for postgres
package pgutil
import (
"crypto/rand"
"database/sql"
"encoding/hex"
"net/url"
"strconv"
"strings"
)
// RandomString creates a random safe string
func RandomString(n int) string {
data := make([]byte, n)
_, _ = rand.Read(data)
return hex.EncodeToString(data)
}
// ConnstrWithSchema adds schema to a connection string
func ConnstrWithSchema(connstr, schema string) string {
schema = strings.ToLower(schema)
return connstr + "&search_path=" + url.QueryEscape(schema)
}
// QuoteSchema quotes schema name for
func QuoteSchema(schema string) string {
return strconv.QuoteToASCII(schema)
}
// Execer is for executing sql
type Execer interface {
Exec(query string, args ...interface{}) (sql.Result, error)
}
// CreateSchema creates a schema if it doesn't exist.
func CreateSchema(db Execer, schema string) error {
_, err := db.Exec(`create schema if not exists ` + QuoteSchema(schema) + `;`)
return err
}
// DropSchema drops the named schema
func DropSchema(db Execer, schema string) error {
_, err := db.Exec(`drop schema ` + QuoteSchema(schema) + ` cascade;`)
return err
}

View File

@ -4,8 +4,6 @@
package testplanet package testplanet
import ( import (
"crypto/rand"
"encoding/hex"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
@ -13,6 +11,7 @@ import (
"github.com/zeebo/errs" "github.com/zeebo/errs"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
"storj.io/storj/internal/dbutil/pgutil"
"storj.io/storj/internal/testcontext" "storj.io/storj/internal/testcontext"
"storj.io/storj/satellite" "storj.io/storj/satellite"
"storj.io/storj/satellite/satellitedb" "storj.io/storj/satellite/satellitedb"
@ -21,7 +20,7 @@ import (
// Run runs testplanet in multiple configurations. // Run runs testplanet in multiple configurations.
func Run(t *testing.T, config Config, test func(t *testing.T, ctx *testcontext.Context, planet *Planet)) { func Run(t *testing.T, config Config, test func(t *testing.T, ctx *testcontext.Context, planet *Planet)) {
schemaSuffix := randomSchemaSuffix() schemaSuffix := pgutil.RandomString(8)
t.Log("schema-suffix ", schemaSuffix) t.Log("schema-suffix ", schemaSuffix)
for _, satelliteDB := range satellitedbtest.Databases() { for _, satelliteDB := range satellitedbtest.Databases() {
@ -40,7 +39,7 @@ func Run(t *testing.T, config Config, test func(t *testing.T, ctx *testcontext.C
planetConfig.Reconfigure.NewBootstrapDB = nil planetConfig.Reconfigure.NewBootstrapDB = nil
planetConfig.Reconfigure.NewSatelliteDB = func(index int) (satellite.DB, error) { planetConfig.Reconfigure.NewSatelliteDB = func(index int) (satellite.DB, error) {
schema := strings.ToLower(t.Name() + "-satellite/" + strconv.Itoa(index) + "-" + schemaSuffix) schema := strings.ToLower(t.Name() + "-satellite/" + strconv.Itoa(index) + "-" + schemaSuffix)
db, err := satellitedb.New(satellitedbtest.WithSchema(satelliteDB.URL, schema)) db, err := satellitedb.New(pgutil.ConnstrWithSchema(satelliteDB.URL, schema))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -81,9 +80,3 @@ func (db *satelliteSchema) Close() error {
db.DB.Close(), db.DB.Close(),
) )
} }
func randomSchemaSuffix() string {
var data [8]byte
_, _ = rand.Read(data[:])
return hex.EncodeToString(data[:])
}

View File

@ -4,10 +4,9 @@
package satellitedb package satellitedb
import ( import (
"strconv"
"github.com/zeebo/errs" "github.com/zeebo/errs"
"storj.io/storj/internal/dbutil/pgutil"
"storj.io/storj/internal/migrate" "storj.io/storj/internal/migrate"
"storj.io/storj/pkg/accounting" "storj.io/storj/pkg/accounting"
"storj.io/storj/pkg/bwagreement" "storj.io/storj/pkg/bwagreement"
@ -60,12 +59,16 @@ func NewInMemory() (satellite.DB, error) {
return New("sqlite3://file::memory:?mode=memory") return New("sqlite3://file::memory:?mode=memory")
} }
// Close is used to close db connection
func (db *DB) Close() error {
return db.db.Close()
}
// CreateSchema creates a schema if it doesn't exist. // CreateSchema creates a schema if it doesn't exist.
func (db *DB) CreateSchema(schema string) error { func (db *DB) CreateSchema(schema string) error {
switch db.driver { switch db.driver {
case "postgres": case "postgres":
_, err := db.db.Exec(`create schema if not exists ` + quoteSchema(schema) + `;`) return pgutil.CreateSchema(db.db, schema)
return err
} }
return nil return nil
} }
@ -74,17 +77,11 @@ func (db *DB) CreateSchema(schema string) error {
func (db *DB) DropSchema(schema string) error { func (db *DB) DropSchema(schema string) error {
switch db.driver { switch db.driver {
case "postgres": case "postgres":
_, err := db.db.Exec(`drop schema ` + quoteSchema(schema) + ` cascade;`) return pgutil.DropSchema(db.db, schema)
return err
} }
return nil return nil
} }
// quoteSchema quotes schema name such that it can be used in a postgres query
func quoteSchema(schema string) string {
return strconv.QuoteToASCII(schema)
}
// BandwidthAgreement is a getter for bandwidth agreement repository // BandwidthAgreement is a getter for bandwidth agreement repository
func (db *DB) BandwidthAgreement() bwagreement.DB { func (db *DB) BandwidthAgreement() bwagreement.DB {
return &bandwidthagreement{db: db.db} return &bandwidthagreement{db: db.db}
@ -137,8 +134,3 @@ func (db *DB) Console() console.DB {
func (db *DB) CreateTables() error { func (db *DB) CreateTables() error {
return migrate.Create("database", db.db) return migrate.Create("database", db.db)
} }
// Close is used to close db connection
func (db *DB) Close() error {
return db.db.Close()
}

View File

@ -6,16 +6,14 @@ package satellitedbtest
// This package should be referenced only in test files! // This package should be referenced only in test files!
import ( import (
"crypto/rand"
"encoding/hex"
"flag" "flag"
"net/url"
"os" "os"
"strings" "strings"
"testing" "testing"
"github.com/zeebo/errs" "github.com/zeebo/errs"
"storj.io/storj/internal/dbutil/pgutil"
"storj.io/storj/satellite" "storj.io/storj/satellite"
"storj.io/storj/satellite/satellitedb" "storj.io/storj/satellite/satellitedb"
) )
@ -47,18 +45,10 @@ func Databases() []Database {
} }
} }
// WithSchema adds schema param to connection string.
func WithSchema(connstring string, schema string) string {
if strings.HasPrefix(connstring, "postgres") {
return connstring + "&search_path=" + url.QueryEscape(schema)
}
return connstring
}
// Run method will iterate over all supported databases. Will establish // Run method will iterate over all supported databases. Will establish
// connection and will create tables for each DB. // connection and will create tables for each DB.
func Run(t *testing.T, test func(t *testing.T, db satellite.DB)) { func Run(t *testing.T, test func(t *testing.T, db satellite.DB)) {
schemaSuffix := randomSchemaSuffix() schemaSuffix := pgutil.RandomString(8)
t.Log("schema-suffix ", schemaSuffix) t.Log("schema-suffix ", schemaSuffix)
for _, dbInfo := range Databases() { for _, dbInfo := range Databases() {
@ -71,7 +61,8 @@ func Run(t *testing.T, test func(t *testing.T, db satellite.DB)) {
} }
schema := strings.ToLower(t.Name() + "-satellite/x-" + schemaSuffix) schema := strings.ToLower(t.Name() + "-satellite/x-" + schemaSuffix)
db, err := satellitedb.New(WithSchema(dbInfo.URL, schema)) connstr := pgutil.ConnstrWithSchema(dbInfo.URL, schema)
db, err := satellitedb.New(connstr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -98,9 +89,3 @@ func Run(t *testing.T, test func(t *testing.T, db satellite.DB)) {
}) })
} }
} }
func randomSchemaSuffix() string {
var data [8]byte
_, _ = rand.Read(data[:])
return hex.EncodeToString(data[:])
}