storj/private/dbutil/dbschema/data.go
Jeff Wendling 074649835b satellite/satellitedb: add some docs and improve some snapshots
This attempts to add a README.md to help create consistent migrations
that maximize our test coverage and do not include unnecessary
statements.

It also adds a feature to have an `-- OLD DATA --` section as well
as a `-- NEW DATA --` section so that we can fix mistakes made in
previous snapshots (like a row that was forgotten to be added when a
table was created) without editing them going forward.

Change-Id: I28a786f8ef163cae1de1bb08f61af1e1104b0a88
2020-05-22 21:27:36 +00:00

161 lines
3.3 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package dbschema
import (
"context"
"fmt"
"sort"
"strings"
"github.com/zeebo/errs"
)
// Data is the database content formatted as strings
type Data struct {
Tables []*TableData
}
// TableData is content of a sql table
type TableData struct {
Name string
Columns []string
Rows []RowData
}
// ColumnData is a value of a column within a row.
type ColumnData struct {
Column string
Value string
}
// String returns a string representation of the column.
func (c ColumnData) String() string {
return fmt.Sprintf("%s:%s", c.Column, c.Value)
}
// RowData is content of a single row
type RowData []ColumnData
// Less returns true if one row is less than the other.
func (row RowData) Less(b RowData) bool {
n := len(row)
if len(b) < n {
n = len(b)
}
for k := 0; k < n; k++ {
if row[k].Value < b[k].Value {
return true
} else if row[k].Value > b[k].Value {
return false
}
}
return len(row) < len(b)
}
// AddTable adds a new table.
func (data *Data) AddTable(table *TableData) {
data.Tables = append(data.Tables, table)
}
// AddRow adds a new row.
func (table *TableData) AddRow(row RowData) error {
if len(row) != len(table.Columns) {
return errs.New("inconsistent row added to table")
}
for i, cdata := range row {
if cdata.Column != table.Columns[i] {
return errs.New("inconsistent row added to table")
}
}
table.Rows = append(table.Rows, row)
return nil
}
// FindTable finds a table by name
func (data *Data) FindTable(tableName string) (*TableData, bool) {
for _, table := range data.Tables {
if table.Name == tableName {
return table, true
}
}
return nil, false
}
// Sort sorts all tables.
func (data *Data) Sort() {
for _, table := range data.Tables {
table.Sort()
}
}
// Sort sorts all rows.
func (table *TableData) Sort() {
sort.Slice(table.Rows, func(i, k int) bool {
return table.Rows[i].Less(table.Rows[k])
})
}
// Clone returns a clone of row data.
func (row RowData) Clone() RowData {
return append(RowData{}, row...)
}
// QueryData loads all data from tables
func QueryData(ctx context.Context, db Queryer, schema *Schema, quoteColumn func(string) string) (*Data, error) {
data := &Data{}
for _, tableSchema := range schema.Tables {
columnNames := tableSchema.ColumnNames()
table := &TableData{
Name: tableSchema.Name,
Columns: columnNames,
}
data.AddTable(table)
// quote column names
quotedColumns := make([]string, len(columnNames))
for i, columnName := range columnNames {
quotedColumns[i] = quoteColumn(columnName)
}
// build query for selecting all values
query := `SELECT ` + strings.Join(quotedColumns, ", ") + ` FROM ` + table.Name
err := func() (err error) {
rows, err := db.QueryContext(ctx, query)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, rows.Close()) }()
row := make(RowData, len(columnNames))
rowargs := make([]interface{}, len(columnNames))
for i := range row {
row[i].Column = columnNames[i]
rowargs[i] = &row[i].Value
}
for rows.Next() {
err := rows.Scan(rowargs...)
if err != nil {
return err
}
if err := table.AddRow(row.Clone()); err != nil {
return err
}
}
return rows.Err()
}()
if err != nil {
return nil, err
}
}
data.Sort()
return data, nil
}