diff --git a/internal/dbutil/dbschema/schema.go b/internal/dbutil/dbschema/schema.go new file mode 100644 index 000000000..d97b2df29 --- /dev/null +++ b/internal/dbutil/dbschema/schema.go @@ -0,0 +1,115 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// Package dbschema package implements querying and comparing schemas for testing. +package dbschema + +import ( + "sort" +) + +// Schema is the database structure. +type Schema struct { + Tables []*Table + Indexes []*Index +} + +// Table is a sql table. +type Table struct { + Name string + Columns []*Column + PrimaryKey []string + Unique [][]string +} + +// Column is a sql column. +type Column struct { + Name string + Type string + IsNullable bool + Reference *Reference +} + +// Reference is a column foreign key. +type Reference struct { + Table string + Column string + OnDelete string + OnUpdate string +} + +// Index is an index for a table. +type Index struct { + Name string + Table string + Columns []string + Unique bool +} + +// EnsureTable returns the table with the specified name and creates one if needed. +func (schema *Schema) EnsureTable(tableName string) *Table { + for _, table := range schema.Tables { + if table.Name == tableName { + return table + } + } + table := &Table{Name: tableName} + schema.Tables = append(schema.Tables, table) + return table +} + +// AddColumn adds the column to the table. +func (table *Table) AddColumn(column *Column) { + table.Columns = append(table.Columns, column) +} + +// FindColumn finds a column in the table +func (table *Table) FindColumn(columnName string) (*Column, bool) { + for _, column := range table.Columns { + if column.Name == columnName { + return column, true + } + } + return nil, false +} + +// Sort sorts tables +func (schema *Schema) Sort() { + sort.Slice(schema.Tables, func(i, k int) bool { + return schema.Tables[i].Name < schema.Tables[k].Name + }) + for _, table := range schema.Tables { + table.Sort() + } +} + +// Sort sorts columns, primary keys and unique +func (table *Table) Sort() { + sort.Slice(table.Columns, func(i, k int) bool { + return table.Columns[i].Name < table.Columns[k].Name + }) + + sort.Strings(table.PrimaryKey) + for i := range table.Unique { + sort.Strings(table.Unique[i]) + } + + sort.Slice(table.Unique, func(i, k int) bool { + return lessStrings(table.Unique[i], table.Unique[k]) + }) +} + +func lessStrings(a, b []string) bool { + n := len(a) + if len(b) < n { + n = len(b) + } + for k := 0; k < n; k++ { + if a[k] < b[k] { + return true + } else if a[k] > b[k] { + return false + } + } + return len(a) < len(b) +} diff --git a/internal/dbutil/pgutil/query.go b/internal/dbutil/pgutil/query.go new file mode 100644 index 000000000..3c93a8116 --- /dev/null +++ b/internal/dbutil/pgutil/query.go @@ -0,0 +1,145 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package pgutil + +import ( + "database/sql" + "fmt" + "regexp" + + "github.com/lib/pq" + "github.com/zeebo/errs" + + "storj.io/storj/internal/dbutil/dbschema" +) + +// Queryer is a representation for something that can query. +type Queryer interface { + // Query executes a query that returns rows, typically a SELECT. + Query(query string, args ...interface{}) (*sql.Rows, error) +} + +// QuerySchema loads the schema from postgres database. +func QuerySchema(tx Queryer) (*dbschema.Schema, error) { + schema := &dbschema.Schema{} + + // find tables + err := func() error { + rows, err := tx.Query(` + SELECT table_name, column_name, is_nullable, data_type + FROM information_schema.columns + WHERE table_schema = CURRENT_SCHEMA + `) + if err != nil { + return err + } + defer func() { err = errs.Combine(err, rows.Close()) }() + + for rows.Next() { + var tableName, columnName, isNullable, dataType string + err := rows.Scan(&tableName, &columnName, &isNullable, &dataType) + if err != nil { + return err + } + + table := schema.EnsureTable(tableName) + table.AddColumn(&dbschema.Column{ + Name: columnName, + Type: dataType, + IsNullable: isNullable == "YES", + }) + } + + return rows.Err() + }() + if err != nil { + return schema, err + } + + // find constraints + err = func() error { + rows, err := tx.Query(` + SELECT pg_class.relname AS table_name, + pg_constraint.conname AS constraint_name, + pg_constraint.contype AS constraint_type, + ARRAY_AGG(pg_attribute.attname ORDER BY u.attposition) AS columns, + pg_get_constraintdef(pg_constraint.oid) AS definition + FROM pg_constraint pg_constraint + JOIN LATERAL UNNEST(pg_constraint.conkey) WITH ORDINALITY AS u(attnum, attposition) ON TRUE + JOIN pg_class ON pg_class.oid = pg_constraint.conrelid + JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace + JOIN pg_attribute ON (pg_attribute.attrelid = pg_class.oid AND pg_attribute.attnum = u.attnum) + WHERE pg_namespace.nspname = CURRENT_SCHEMA + GROUP BY constraint_name, constraint_type, table_name, definition + `) + if err != nil { + return err + } + defer func() { err = errs.Combine(err, rows.Close()) }() + + for rows.Next() { + var tableName, constraintName, constraintType string + var columns pq.StringArray + var definition string + + err := rows.Scan(&tableName, &constraintName, &constraintType, &columns, &definition) + if err != nil { + return err + } + + switch constraintType { + case "p": // primary key + table := schema.EnsureTable(tableName) + table.PrimaryKey = ([]string)(columns) + case "f": // foreign key + if len(columns) != 1 { + return fmt.Errorf("expected one column, got: %q", columns) + } + + table := schema.EnsureTable(tableName) + column, ok := table.FindColumn(columns[0]) + if !ok { + return fmt.Errorf("did not find column %q", columns[0]) + } + + matches := rxPostgresForeignKey.FindStringSubmatch(definition) + if len(matches) == 0 { + return fmt.Errorf("unable to parse constraint %q", definition) + } + + column.Reference = &dbschema.Reference{ + Table: matches[1], + Column: matches[2], + OnUpdate: matches[3], + OnDelete: matches[4], + } + case "u": // unique + table := schema.EnsureTable(tableName) + table.Unique = append(table.Unique, columns) + default: + return fmt.Errorf("unhandled constraint type %q", constraintType) + } + + if err != nil { + return err + } + } + return rows.Err() + }() + if err != nil { + return nil, err + } + + // TODO: find indexes + + return schema, nil +} + +// matches FOREIGN KEY (project_id) REFERENCES projects(id) ON UPDATE CASCADE ON DELETE CASCADE +var rxPostgresForeignKey = regexp.MustCompile( + `^FOREIGN KEY \([[:word:]]+\) ` + + `REFERENCES ([[:word:]]+)\(([[:word:]]+)\)` + + `(?:\s*ON UPDATE ([[:word:]]+))?` + + `(?:\s*ON DELETE ([[:word:]]+))?$`, +) diff --git a/internal/dbutil/pgutil/query_test.go b/internal/dbutil/pgutil/query_test.go new file mode 100644 index 000000000..de659888a --- /dev/null +++ b/internal/dbutil/pgutil/query_test.go @@ -0,0 +1,118 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package pgutil_test + +import ( + "database/sql" + "flag" + "os" + "testing" + + _ "github.com/lib/pq" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "storj.io/storj/internal/dbutil/dbschema" + "storj.io/storj/internal/dbutil/pgutil" + "storj.io/storj/internal/testcontext" +) + +const ( + // DefaultPostgresConn is a connstring that works with docker-compose + DefaultPostgresConn = "postgres://storj:storj-pass@test-postgres/teststorj?sslmode=disable" +) + +var ( + // TestPostgres is flag for the postgres test database + TestPostgres = flag.String("postgres-test-db", os.Getenv("STORJ_POSTGRES_TEST"), "PostgreSQL test database connection string") +) + +func TestQuery(t *testing.T) { + if *TestPostgres == "" { + t.Skip("Postgres flag missing, example: -postgres-test-db=" + DefaultPostgresConn) + } + + ctx := testcontext.New(t) + defer ctx.Cleanup() + + schemaName := "pgutil-query-" + pgutil.RandomString(8) + connstr := pgutil.ConnstrWithSchema(*TestPostgres, schemaName) + + db, err := sql.Open("postgres", connstr) + require.NoError(t, err) + + defer ctx.Check(db.Close) + + require.NoError(t, pgutil.CreateSchema(db, schemaName)) + defer func() { + require.NoError(t, pgutil.DropSchema(db, schemaName)) + }() + + emptySchema, err := pgutil.QuerySchema(db) + assert.NoError(t, err) + assert.Equal(t, &dbschema.Schema{}, emptySchema) + + _, err = db.Exec(` + CREATE TABLE users ( + a integer NOT NULL, + b integer NOT NULL, + c text, + UNIQUE (c), + PRIMARY KEY (a) + ); + CREATE TABLE names ( + users_a integer REFERENCES users( a ) ON DELETE CASCADE, + a text NOT NULL, + x text, + b text, + PRIMARY KEY (a, x), + UNIQUE ( x ), + UNIQUE ( a, b ) + ); + `) + require.NoError(t, err) + + schema, err := pgutil.QuerySchema(db) + assert.NoError(t, err) + + expected := &dbschema.Schema{ + Tables: []*dbschema.Table{ + { + Name: "users", + Columns: []*dbschema.Column{ + {Name: "a", Type: "integer", IsNullable: false, Reference: nil}, + {Name: "b", Type: "integer", IsNullable: false, Reference: nil}, + {Name: "c", Type: "text", IsNullable: true, Reference: nil}, + }, + PrimaryKey: []string{"a"}, + Unique: [][]string{ + {"c"}, + }, + }, + { + Name: "names", + Columns: []*dbschema.Column{ + {Name: "users_a", Type: "integer", IsNullable: true, + Reference: &dbschema.Reference{ + Table: "users", + Column: "a", + OnDelete: "CASCADE", + }}, + {Name: "a", Type: "text", IsNullable: false, Reference: nil}, + {Name: "x", Type: "text", IsNullable: false, Reference: nil}, // not null, because primary key + {Name: "b", Type: "text", IsNullable: true, Reference: nil}, + }, + PrimaryKey: []string{"a", "x"}, + Unique: [][]string{ + {"a", "b"}, + {"x"}, + }, + }, + }, + } + + expected.Sort() + schema.Sort() + assert.Equal(t, expected, schema) +} diff --git a/internal/dbutil/pgutil/schema.go b/internal/dbutil/pgutil/schema.go new file mode 100644 index 000000000..f8e638323 --- /dev/null +++ b/internal/dbutil/pgutil/schema.go @@ -0,0 +1,49 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// Package pgutil contains utilities for postgres +package pgutil + +import ( + "crypto/rand" + "database/sql" + "encoding/hex" + "net/url" + "strconv" + "strings" +) + +// RandomString creates a random safe string +func RandomString(n int) string { + data := make([]byte, n) + _, _ = rand.Read(data) + return hex.EncodeToString(data) +} + +// ConnstrWithSchema adds schema to a connection string +func ConnstrWithSchema(connstr, schema string) string { + schema = strings.ToLower(schema) + return connstr + "&search_path=" + url.QueryEscape(schema) +} + +// QuoteSchema quotes schema name for +func QuoteSchema(schema string) string { + return strconv.QuoteToASCII(schema) +} + +// Execer is for executing sql +type Execer interface { + Exec(query string, args ...interface{}) (sql.Result, error) +} + +// CreateSchema creates a schema if it doesn't exist. +func CreateSchema(db Execer, schema string) error { + _, err := db.Exec(`create schema if not exists ` + QuoteSchema(schema) + `;`) + return err +} + +// DropSchema drops the named schema +func DropSchema(db Execer, schema string) error { + _, err := db.Exec(`drop schema ` + QuoteSchema(schema) + ` cascade;`) + return err +} diff --git a/internal/testplanet/run.go b/internal/testplanet/run.go index e98471ce0..b7fd796e8 100644 --- a/internal/testplanet/run.go +++ b/internal/testplanet/run.go @@ -4,8 +4,6 @@ package testplanet import ( - "crypto/rand" - "encoding/hex" "strconv" "strings" "testing" @@ -13,6 +11,7 @@ import ( "github.com/zeebo/errs" "go.uber.org/zap/zaptest" + "storj.io/storj/internal/dbutil/pgutil" "storj.io/storj/internal/testcontext" "storj.io/storj/satellite" "storj.io/storj/satellite/satellitedb" @@ -21,7 +20,7 @@ import ( // Run runs testplanet in multiple configurations. func Run(t *testing.T, config Config, test func(t *testing.T, ctx *testcontext.Context, planet *Planet)) { - schemaSuffix := randomSchemaSuffix() + schemaSuffix := pgutil.RandomString(8) t.Log("schema-suffix ", schemaSuffix) for _, satelliteDB := range satellitedbtest.Databases() { @@ -40,7 +39,7 @@ func Run(t *testing.T, config Config, test func(t *testing.T, ctx *testcontext.C planetConfig.Reconfigure.NewBootstrapDB = nil planetConfig.Reconfigure.NewSatelliteDB = func(index int) (satellite.DB, error) { schema := strings.ToLower(t.Name() + "-satellite/" + strconv.Itoa(index) + "-" + schemaSuffix) - db, err := satellitedb.New(satellitedbtest.WithSchema(satelliteDB.URL, schema)) + db, err := satellitedb.New(pgutil.ConnstrWithSchema(satelliteDB.URL, schema)) if err != nil { t.Fatal(err) } @@ -81,9 +80,3 @@ func (db *satelliteSchema) Close() error { db.DB.Close(), ) } - -func randomSchemaSuffix() string { - var data [8]byte - _, _ = rand.Read(data[:]) - return hex.EncodeToString(data[:]) -} diff --git a/satellite/satellitedb/database.go b/satellite/satellitedb/database.go index 335d63898..e13b4d3f6 100644 --- a/satellite/satellitedb/database.go +++ b/satellite/satellitedb/database.go @@ -4,10 +4,9 @@ package satellitedb import ( - "strconv" - "github.com/zeebo/errs" + "storj.io/storj/internal/dbutil/pgutil" "storj.io/storj/internal/migrate" "storj.io/storj/pkg/accounting" "storj.io/storj/pkg/bwagreement" @@ -60,12 +59,16 @@ func NewInMemory() (satellite.DB, error) { return New("sqlite3://file::memory:?mode=memory") } +// Close is used to close db connection +func (db *DB) Close() error { + return db.db.Close() +} + // CreateSchema creates a schema if it doesn't exist. func (db *DB) CreateSchema(schema string) error { switch db.driver { case "postgres": - _, err := db.db.Exec(`create schema if not exists ` + quoteSchema(schema) + `;`) - return err + return pgutil.CreateSchema(db.db, schema) } return nil } @@ -74,17 +77,11 @@ func (db *DB) CreateSchema(schema string) error { func (db *DB) DropSchema(schema string) error { switch db.driver { case "postgres": - _, err := db.db.Exec(`drop schema ` + quoteSchema(schema) + ` cascade;`) - return err + return pgutil.DropSchema(db.db, schema) } return nil } -// quoteSchema quotes schema name such that it can be used in a postgres query -func quoteSchema(schema string) string { - return strconv.QuoteToASCII(schema) -} - // BandwidthAgreement is a getter for bandwidth agreement repository func (db *DB) BandwidthAgreement() bwagreement.DB { return &bandwidthagreement{db: db.db} @@ -137,8 +134,3 @@ func (db *DB) Console() console.DB { func (db *DB) CreateTables() error { return migrate.Create("database", db.db) } - -// Close is used to close db connection -func (db *DB) Close() error { - return db.db.Close() -} diff --git a/satellite/satellitedb/satellitedbtest/run.go b/satellite/satellitedb/satellitedbtest/run.go index 7d0dc2e81..6de6b4066 100644 --- a/satellite/satellitedb/satellitedbtest/run.go +++ b/satellite/satellitedb/satellitedbtest/run.go @@ -6,16 +6,14 @@ package satellitedbtest // This package should be referenced only in test files! import ( - "crypto/rand" - "encoding/hex" "flag" - "net/url" "os" "strings" "testing" "github.com/zeebo/errs" + "storj.io/storj/internal/dbutil/pgutil" "storj.io/storj/satellite" "storj.io/storj/satellite/satellitedb" ) @@ -47,18 +45,10 @@ func Databases() []Database { } } -// WithSchema adds schema param to connection string. -func WithSchema(connstring string, schema string) string { - if strings.HasPrefix(connstring, "postgres") { - return connstring + "&search_path=" + url.QueryEscape(schema) - } - return connstring -} - // Run method will iterate over all supported databases. Will establish // connection and will create tables for each DB. func Run(t *testing.T, test func(t *testing.T, db satellite.DB)) { - schemaSuffix := randomSchemaSuffix() + schemaSuffix := pgutil.RandomString(8) t.Log("schema-suffix ", schemaSuffix) for _, dbInfo := range Databases() { @@ -71,7 +61,8 @@ func Run(t *testing.T, test func(t *testing.T, db satellite.DB)) { } schema := strings.ToLower(t.Name() + "-satellite/x-" + schemaSuffix) - db, err := satellitedb.New(WithSchema(dbInfo.URL, schema)) + connstr := pgutil.ConnstrWithSchema(dbInfo.URL, schema) + db, err := satellitedb.New(connstr) if err != nil { t.Fatal(err) } @@ -98,9 +89,3 @@ func Run(t *testing.T, test func(t *testing.T, db satellite.DB)) { }) } } - -func randomSchemaSuffix() string { - var data [8]byte - _, _ = rand.Read(data[:]) - return hex.EncodeToString(data[:]) -}