Create dbutil package for sqlite (#1311)

This commit is contained in:
Michal Niewrzal 2019-02-15 17:13:00 +01:00 committed by GitHub
parent df20597f67
commit 84c2f991d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 371 additions and 1 deletions

View File

@ -47,7 +47,7 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
return rows.Err()
}()
if err != nil {
return schema, err
return nil, err
}
// find constraints

View File

@ -0,0 +1,78 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package sqliteutil
import (
"database/sql"
"strconv"
"github.com/zeebo/errs"
"storj.io/storj/internal/dbutil/dbschema"
)
// LoadSchemaFromSQL inserts script into connstr and loads schema.
func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
return nil, err
}
defer func() { err = errs.Combine(err, db.Close()) }()
_, err = db.Exec(script)
if err != nil {
return nil, err
}
return QuerySchema(db)
}
// LoadSnapshotFromSQL inserts script into connstr and loads schema.
func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
return nil, err
}
defer func() { err = errs.Combine(err, db.Close()) }()
_, err = db.Exec(script)
if err != nil {
return nil, err
}
snapshot, err := QuerySnapshot(db)
if err != nil {
return nil, err
}
snapshot.Script = script
return snapshot, nil
}
// QuerySnapshot loads snapshot from database
func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) {
schema, err := QuerySchema(db)
if err != nil {
return nil, err
}
data, err := QueryData(db, schema)
if err != nil {
return nil, err
}
return &dbschema.Snapshot{
Version: -1,
Schema: schema,
Data: data,
}, err
}
// QueryData loads all data from tables
func QueryData(db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) {
return dbschema.QueryData(db, schema, func(columnName string) string {
quoted := strconv.Quote(columnName)
return `quote(` + quoted + `) as ` + quoted
})
}

View File

@ -0,0 +1,188 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package sqliteutil
import (
"database/sql"
"regexp"
"strings"
"github.com/zeebo/errs"
"storj.io/storj/internal/dbutil/dbschema"
)
type definition struct {
name string
sql string
}
// QuerySchema loads the schema from sqlite database.
func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) {
schema := &dbschema.Schema{}
tableDefinitions := make([]*definition, 0)
indexDefinitions := make([]*definition, 0)
// find tables and indexes
err := func() error {
rows, err := db.Query(`
SELECT name, type, sql FROM sqlite_master WHERE sql NOT NULL AND name NOT LIKE 'sqlite_%'
`)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, rows.Close()) }()
for rows.Next() {
var defName, defType, defSQL string
err := rows.Scan(&defName, &defType, &defSQL)
if err != nil {
return err
}
if defType == "table" {
tableDefinitions = append(tableDefinitions, &definition{name: defName, sql: defSQL})
} else if defType == "index" {
indexDefinitions = append(indexDefinitions, &definition{name: defName, sql: defSQL})
}
}
return rows.Err()
}()
if err != nil {
return nil, err
}
err = discoverTables(db, schema, tableDefinitions)
if err != nil {
return nil, err
}
err = discoverIndexes(db, schema, indexDefinitions)
if err != nil {
return nil, err
}
schema.Sort()
return schema, nil
}
func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) {
for _, definition := range tableDefinitions {
table := schema.EnsureTable(definition.name)
tableRows, err := db.Query(`PRAGMA table_info(` + definition.name + `)`)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, tableRows.Close()) }()
for tableRows.Next() {
var defaultValue sql.NullString
var index, name, columnType string
var pk int
var notNull bool
err := tableRows.Scan(&index, &name, &columnType, &notNull, &defaultValue, &pk)
if err != nil {
return err
}
column := &dbschema.Column{
Name: name,
Type: columnType,
IsNullable: !notNull && pk == 0,
}
table.AddColumn(column)
if pk > 0 {
if table.PrimaryKey == nil {
table.PrimaryKey = make([]string, 0)
}
table.PrimaryKey = append(table.PrimaryKey, name)
}
}
matches := rxUnique.FindAllStringSubmatch(definition.sql, -1)
for _, match := range matches {
// TODO feel this can be done easier
var columns []string
for _, name := range strings.Split(match[1], ",") {
columns = append(columns, strings.TrimSpace(name))
}
table.Unique = append(table.Unique, columns)
}
keysRows, err := db.Query(`PRAGMA foreign_key_list(` + definition.name + `)`)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, keysRows.Close()) }()
for keysRows.Next() {
var id, sec int
var tableName, from, to, onUpdate, onDelete, match string
err := keysRows.Scan(&id, &sec, &tableName, &from, &to, &onUpdate, &onDelete, &match)
if err != nil {
return err
}
column, found := table.FindColumn(from)
if found {
if onDelete == "NO ACTION" {
onDelete = ""
}
if onUpdate == "NO ACTION" {
onUpdate = ""
}
column.Reference = &dbschema.Reference{
Table: tableName,
Column: to,
OnUpdate: onUpdate,
OnDelete: onDelete,
}
}
}
}
return err
}
func discoverIndexes(db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) {
// TODO improve indexes discovery
for _, definition := range indexDefinitions {
index := &dbschema.Index{
Name: definition.name,
}
schema.Indexes = append(schema.Indexes, index)
indexRows, err := db.Query(`PRAGMA index_info(` + definition.name + `)`)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, indexRows.Close()) }()
for indexRows.Next() {
var name string
var seqno, cid int
err := indexRows.Scan(&seqno, &cid, &name)
if err != nil {
return err
}
index.Columns = append(index.Columns, name)
}
matches := rxIndexTable.FindStringSubmatch(definition.sql)
index.Table = strings.TrimSpace(matches[1])
}
return err
}
var (
// matches UNIQUE (a,b)
rxUnique = regexp.MustCompile(`UNIQUE\s*\((.*?)\)`)
// matches ON (a,b)
rxIndexTable = regexp.MustCompile(`ON\s*(.*)\(`)
)

View File

@ -0,0 +1,104 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package sqliteutil_test
import (
"database/sql"
"testing"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"storj.io/storj/internal/dbutil/dbschema"
"storj.io/storj/internal/dbutil/sqliteutil"
"storj.io/storj/internal/testcontext"
)
func TestQuery(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
db, err := sql.Open("sqlite3", ":memory:")
require.NoError(t, err)
defer ctx.Check(db.Close)
emptySchema, err := sqliteutil.QuerySchema(db)
assert.NoError(t, err)
assert.Equal(t, &dbschema.Schema{}, emptySchema)
_, err = db.Exec(`
CREATE TABLE users (
a integer NOT NULL,
b integer NOT NULL,
c text,
UNIQUE (c),
PRIMARY KEY (a)
);
CREATE TABLE names (
users_a integer REFERENCES users( a ) ON DELETE CASCADE,
a text NOT NULL,
x text,
b text,
PRIMARY KEY (a, x),
UNIQUE ( x ),
UNIQUE ( a, b )
);
CREATE INDEX names_a ON names (a, b);
`)
require.NoError(t, err)
schema, err := sqliteutil.QuerySchema(db)
assert.NoError(t, err)
expected := &dbschema.Schema{
Tables: []*dbschema.Table{
{
Name: "users",
Columns: []*dbschema.Column{
{Name: "a", Type: "integer", IsNullable: false, Reference: nil},
{Name: "b", Type: "integer", IsNullable: false, Reference: nil},
{Name: "c", Type: "text", IsNullable: true, Reference: nil},
},
PrimaryKey: []string{"a"},
Unique: [][]string{
{"c"},
},
},
{
Name: "names",
Columns: []*dbschema.Column{
{Name: "users_a", Type: "integer", IsNullable: true,
Reference: &dbschema.Reference{
Table: "users",
Column: "a",
OnDelete: "CASCADE",
}},
{Name: "a", Type: "text", IsNullable: false, Reference: nil},
{Name: "x", Type: "text", IsNullable: false, Reference: nil}, // not null, because primary key
{Name: "b", Type: "text", IsNullable: true, Reference: nil},
},
PrimaryKey: []string{"a", "x"},
Unique: [][]string{
{"a", "b"},
{"x"},
},
},
},
Indexes: []*dbschema.Index{
{
Name: "names_a",
Table: "names",
Columns: []string{"a", "b"},
},
},
}
expected.Sort()
schema.Sort()
assert.Equal(t, expected, schema)
}