storj/private/dbutil/dbschema/schema.go
Egon Elbre 36c461bd59 private/tagsql: track proper closing of rows and statements
This ensures that rows are closed to avoid leaks.
Also verifies that Err() is called, to ensure that no
error is left behind.

Change-Id: Idd1bec9bf479f40021da67b2c80ce83033149469
2020-06-05 18:25:43 +00:00

259 lines
5.8 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// Package dbschema implements querying and comparing schemas for testing.
package dbschema
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"storj.io/storj/private/tagsql"
)
// Queryer is a representation for something that can query.
//
// It declares two context-aware query methods; any value providing both
// of them satisfies the interface.
type Queryer interface {
// QueryRowContext executes a query that returns a single row.
// The result is a *sql.Row from database/sql.
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
// QueryContext executes a query that returns rows, typically a SELECT.
// The rows are returned as tagsql.Rows together with an error.
QueryContext(ctx context.Context, query string, args ...interface{}) (tagsql.Rows, error)
}
// Schema describes a database structure as a list of tables and a list of indexes.
type Schema struct {
	Tables  []*Table
	Indexes []*Index
}

// String renders the schema as text: a "Tables:" section followed by an
// "Indexes:" section. Each entry is produced by the String method of the
// corresponding table or index, and both sections are passed through indent.
func (schema Schema) String() string {
	var tableTexts []string
	for i := range schema.Tables {
		tableTexts = append(tableTexts, schema.Tables[i].String())
	}

	var indexTexts []string
	for i := range schema.Indexes {
		indexTexts = append(indexTexts, schema.Indexes[i].String())
	}

	const layout = "Tables:\n\t%s\nIndexes:\n\t%s"
	tablesBlock := indent(strings.Join(tableTexts, "\n"))
	indexesBlock := indent(strings.Join(indexTexts, "\n"))
	return fmt.Sprintf(layout, tablesBlock, indexesBlock)
}
// Table describes a single sql table: its name, columns, primary key
// columns and unique column groups.
type Table struct {
	Name       string
	Columns    []*Column
	PrimaryKey []string
	Unique     [][]string
}

// String renders the table as text. Columns are rendered with their own
// String method, the primary key and each unique group are joined with a
// single space, and the multi-line sections are passed through indent.
func (table Table) String() string {
	var columnTexts []string
	for i := range table.Columns {
		columnTexts = append(columnTexts, table.Columns[i].String())
	}

	var uniqueTexts []string
	for _, group := range table.Unique {
		uniqueTexts = append(uniqueTexts, strings.Join(group, " "))
	}

	const layout = "Name: %s\nColumns:\n\t%s\nPrimaryKey: %s\nUniques:\n\t%s"
	columnsBlock := indent(strings.Join(columnTexts, "\n"))
	primaryKey := strings.Join(table.PrimaryKey, " ")
	uniquesBlock := indent(strings.Join(uniqueTexts, "\n"))
	return fmt.Sprintf(layout, table.Name, columnsBlock, primaryKey, uniquesBlock)
}
// Column describes a single sql column, including an optional foreign key
// reference (nil when the column has none).
type Column struct {
	Name       string
	Type       string
	IsNullable bool
	Default    string
	Reference  *Reference
}

// String renders the column as one "Key: value" pair per line. The default
// value is printed quoted (%q) and the reference is formatted with %s.
func (column Column) String() string {
	const layout = "Name: %s\nType: %s\nNullable: %t\nDefault: %q\nReference: %s"
	return fmt.Sprintf(layout,
		column.Name, column.Type, column.IsNullable, column.Default, column.Reference)
}
// Reference describes the foreign key of a column: the referenced table and
// column plus the ON DELETE / ON UPDATE actions.
type Reference struct {
	Table    string
	Column   string
	OnDelete string
	OnUpdate string
}

// String renders the reference as a single line. It is safe to call on a
// nil receiver, in which case it returns "nil".
func (reference *Reference) String() string {
	if reference != nil {
		const layout = "Reference<Table: %s, Column: %s, OnDelete: %s, OnUpdate: %s>"
		return fmt.Sprintf(layout,
			reference.Table, reference.Column, reference.OnDelete, reference.OnUpdate)
	}
	return "nil"
}
// Index describes an index on a table.
type Index struct {
	Name    string
	Table   string
	Columns []string
	Unique  bool
	Partial string // partial expression
}

// String renders the index as a single "Index<...>" entry. Column names are
// joined with a single space and passed through indent; the partial
// expression is printed quoted (%q).
func (index Index) String() string {
	const layout = "Index<Table: %s, Name: %s, Columns: %s, Unique: %t, Partial: %q>"
	columns := indent(strings.Join(index.Columns, " "))
	return fmt.Sprintf(layout, index.Table, index.Name, columns, index.Unique, index.Partial)
}
// EnsureTable looks up the table called tableName and returns it. When no
// such table exists, a new empty table with that name is appended to the
// schema and returned instead.
func (schema *Schema) EnsureTable(tableName string) *Table {
	for i := range schema.Tables {
		if existing := schema.Tables[i]; existing.Name == tableName {
			return existing
		}
	}

	created := &Table{Name: tableName}
	schema.Tables = append(schema.Tables, created)
	return created
}
// DropTable removes the first table named tableName from the schema and
// every index whose Table field equals tableName. Dropping a table that
// does not exist leaves the tables untouched.
//
// Both slices are compacted in place. The slots freed at the end of the
// backing arrays are set to nil so that the removed tables and indexes do
// not stay reachable through the unused tail of the array.
func (schema *Schema) DropTable(tableName string) {
	for i, table := range schema.Tables {
		if table.Name == tableName {
			last := len(schema.Tables) - 1
			copy(schema.Tables[i:], schema.Tables[i+1:])
			// After shifting, the last slot still holds a pointer; clear it.
			schema.Tables[last] = nil
			schema.Tables = schema.Tables[:last]
			break
		}
	}

	// Filter the indexes in place, keeping only those of other tables.
	kept := 0
	for _, index := range schema.Indexes {
		if index.Table == tableName {
			continue
		}
		schema.Indexes[kept] = index
		kept++
	}
	// Clear the leftover tail so dropped indexes can be garbage collected.
	for i := kept; i < len(schema.Indexes); i++ {
		schema.Indexes[i] = nil
	}
	// Cap the capacity so later appends cannot write into the old tail.
	schema.Indexes = schema.Indexes[:kept:kept]
}
// FindIndex returns the first index in the schema whose Name equals name.
// The boolean result reports whether such an index was found; when it is
// false the returned index is nil.
func (schema *Schema) FindIndex(name string) (*Index, bool) {
	for i := range schema.Indexes {
		if candidate := schema.Indexes[i]; candidate.Name == name {
			return candidate, true
		}
	}
	return nil, false
}
// DropIndex removes the first index whose Name equals name. Dropping an
// index that does not exist leaves the schema untouched.
//
// The slice is compacted in place and the slot freed at the end of the
// backing array is set to nil, so the removed index does not stay
// reachable through the unused tail of the array.
func (schema *Schema) DropIndex(name string) {
	for i, idx := range schema.Indexes {
		if idx.Name != name {
			continue
		}
		last := len(schema.Indexes) - 1
		copy(schema.Indexes[i:], schema.Indexes[i+1:])
		// After shifting, the last slot still holds a pointer; clear it.
		schema.Indexes[last] = nil
		schema.Indexes = schema.Indexes[:last]
		return
	}
}
// AddColumn appends column to the end of the table's column list.
// No check is made for an existing column with the same name.
func (table *Table) AddColumn(column *Column) {
	table.Columns = append(table.Columns, column)
}
// FindColumn returns the first column of the table whose Name equals
// columnName. The boolean result reports whether such a column was found;
// when it is false the returned column is nil.
func (table *Table) FindColumn(columnName string) (*Column, bool) {
	for i := range table.Columns {
		if candidate := table.Columns[i]; candidate.Name == columnName {
			return candidate, true
		}
	}
	return nil, false
}
// ColumnNames returns the names of the table's columns, in the same order
// as table.Columns. The result is a newly allocated slice and is non-nil
// even when the table has no columns.
func (table *Table) ColumnNames() []string {
	names := make([]string, 0, len(table.Columns))
	for _, column := range table.Columns {
		names = append(names, column.Name)
	}
	return names
}
// Sort orders the schema in place: tables by name (each table is also
// sorted via Table.Sort), and indexes by table name first and index name
// second.
func (schema *Schema) Sort() {
	tables := schema.Tables
	sort.Slice(tables, func(a, b int) bool {
		return tables[a].Name < tables[b].Name
	})
	for i := range tables {
		tables[i].Sort()
	}

	indexes := schema.Indexes
	sort.Slice(indexes, func(a, b int) bool {
		left, right := indexes[a], indexes[b]
		if left.Table != right.Table {
			return left.Table < right.Table
		}
		return left.Name < right.Name
	})
}
// Sort orders the table in place: columns by name, the primary key column
// names alphabetically, the names inside every unique group alphabetically,
// and finally the unique groups themselves using lessStrings.
func (table *Table) Sort() {
	columns := table.Columns
	sort.Slice(columns, func(a, b int) bool {
		return columns[a].Name < columns[b].Name
	})

	sort.Strings(table.PrimaryKey)

	uniques := table.Unique
	// Each group shares its backing array with table.Unique, so sorting
	// the group sorts the stored names.
	for _, group := range uniques {
		sort.Strings(group)
	}
	sort.Slice(uniques, func(a, b int) bool {
		return lessStrings(uniques[a], uniques[b])
	})
}
// lessStrings reports whether a orders before b. The slices are compared
// element by element; the first differing pair decides the result. When one
// slice is a prefix of the other, the shorter one orders first.
func lessStrings(a, b []string) bool {
	for i := 0; i < len(a) && i < len(b); i++ {
		if a[i] != b[i] {
			return a[i] < b[i]
		}
	}
	return len(a) < len(b)
}
// indent inserts a tab after every newline in lines, so that continuation
// lines are indented by one level, and then trims leading and trailing
// whitespace from the result.
func indent(lines string) string {
	indented := strings.ReplaceAll(lines, "\n", "\n\t")
	return strings.TrimSpace(indented)
}