From 84c2f991d298d7aa5421e8b0b2e34b8183b0127b Mon Sep 17 00:00:00 2001 From: Michal Niewrzal Date: Fri, 15 Feb 2019 17:13:00 +0100 Subject: [PATCH] Create dbutil package for sqlite (#1311) --- internal/dbutil/pgutil/query.go | 2 +- internal/dbutil/sqliteutil/db.go | 78 ++++++++++ internal/dbutil/sqliteutil/query.go | 188 +++++++++++++++++++++++ internal/dbutil/sqliteutil/query_test.go | 104 +++++++++++++ 4 files changed, 371 insertions(+), 1 deletion(-) create mode 100644 internal/dbutil/sqliteutil/db.go create mode 100644 internal/dbutil/sqliteutil/query.go create mode 100644 internal/dbutil/sqliteutil/query_test.go diff --git a/internal/dbutil/pgutil/query.go b/internal/dbutil/pgutil/query.go index 88a0a1682..3e1a4064e 100644 --- a/internal/dbutil/pgutil/query.go +++ b/internal/dbutil/pgutil/query.go @@ -47,7 +47,7 @@ func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) { return rows.Err() }() if err != nil { - return schema, err + return nil, err } // find constraints diff --git a/internal/dbutil/sqliteutil/db.go b/internal/dbutil/sqliteutil/db.go new file mode 100644 index 000000000..33f058d96 --- /dev/null +++ b/internal/dbutil/sqliteutil/db.go @@ -0,0 +1,78 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package sqliteutil + +import ( + "database/sql" + "strconv" + + "github.com/zeebo/errs" + + "storj.io/storj/internal/dbutil/dbschema" +) + +// LoadSchemaFromSQL inserts script into connstr and loads schema. +func LoadSchemaFromSQL(script string) (_ *dbschema.Schema, err error) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + return nil, err + } + defer func() { err = errs.Combine(err, db.Close()) }() + + _, err = db.Exec(script) + if err != nil { + return nil, err + } + + return QuerySchema(db) +} + +// LoadSnapshotFromSQL inserts script into connstr and loads schema. +func LoadSnapshotFromSQL(script string) (_ *dbschema.Snapshot, err error) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + return nil, err + } + defer func() { err = errs.Combine(err, db.Close()) }() + + _, err = db.Exec(script) + if err != nil { + return nil, err + } + + snapshot, err := QuerySnapshot(db) + if err != nil { + return nil, err + } + + snapshot.Script = script + return snapshot, nil +} + +// QuerySnapshot loads snapshot from database +func QuerySnapshot(db dbschema.Queryer) (*dbschema.Snapshot, error) { + schema, err := QuerySchema(db) + if err != nil { + return nil, err + } + + data, err := QueryData(db, schema) + if err != nil { + return nil, err + } + + return &dbschema.Snapshot{ + Version: -1, + Schema: schema, + Data: data, + }, err +} + +// QueryData loads all data from tables +func QueryData(db dbschema.Queryer, schema *dbschema.Schema) (*dbschema.Data, error) { + return dbschema.QueryData(db, schema, func(columnName string) string { + quoted := strconv.Quote(columnName) + return `quote(` + quoted + `) as ` + quoted + }) +} diff --git a/internal/dbutil/sqliteutil/query.go b/internal/dbutil/sqliteutil/query.go new file mode 100644 index 000000000..9d9f7be8c --- /dev/null +++ b/internal/dbutil/sqliteutil/query.go @@ -0,0 +1,188 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package sqliteutil + +import ( + "database/sql" + "regexp" + "strings" + + "github.com/zeebo/errs" + + "storj.io/storj/internal/dbutil/dbschema" +) + +type definition struct { + name string + sql string +} + +// QuerySchema loads the schema from sqlite database. +func QuerySchema(db dbschema.Queryer) (*dbschema.Schema, error) { + schema := &dbschema.Schema{} + + tableDefinitions := make([]*definition, 0) + indexDefinitions := make([]*definition, 0) + + // find tables and indexes + err := func() error { + rows, err := db.Query(` + SELECT name, type, sql FROM sqlite_master WHERE sql NOT NULL AND name NOT LIKE 'sqlite_%' + `) + if err != nil { + return err + } + defer func() { err = errs.Combine(err, rows.Close()) }() + + for rows.Next() { + var defName, defType, defSQL string + err := rows.Scan(&defName, &defType, &defSQL) + if err != nil { + return err + } + if defType == "table" { + tableDefinitions = append(tableDefinitions, &definition{name: defName, sql: defSQL}) + } else if defType == "index" { + indexDefinitions = append(indexDefinitions, &definition{name: defName, sql: defSQL}) + } + } + + return rows.Err() + }() + if err != nil { + return nil, err + } + + err = discoverTables(db, schema, tableDefinitions) + if err != nil { + return nil, err + } + + err = discoverIndexes(db, schema, indexDefinitions) + if err != nil { + return nil, err + } + + schema.Sort() + return schema, nil +} + +func discoverTables(db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) { + for _, definition := range tableDefinitions { + table := schema.EnsureTable(definition.name) + + tableRows, err := db.Query(`PRAGMA table_info(` + definition.name + `)`) + if err != nil { + return err + } + defer func() { err = errs.Combine(err, tableRows.Close()) }() + + for tableRows.Next() { + var defaultValue sql.NullString + var index, name, columnType string + var pk int + var notNull bool + err := tableRows.Scan(&index, &name, &columnType, ¬Null, &defaultValue, &pk) + if err != nil { + return err + } + + column := &dbschema.Column{ + Name: name, + Type: columnType, + IsNullable: !notNull && pk == 0, + } + table.AddColumn(column) + if pk > 0 { + if table.PrimaryKey == nil { + table.PrimaryKey = make([]string, 0) + } + table.PrimaryKey = append(table.PrimaryKey, name) + } + + } + + matches := rxUnique.FindAllStringSubmatch(definition.sql, -1) + for _, match := range matches { + // TODO feel this can be done easier + var columns []string + for _, name := range strings.Split(match[1], ",") { + columns = append(columns, strings.TrimSpace(name)) + } + + table.Unique = append(table.Unique, columns) + } + + keysRows, err := db.Query(`PRAGMA foreign_key_list(` + definition.name + `)`) + if err != nil { + return err + } + defer func() { err = errs.Combine(err, keysRows.Close()) }() + + for keysRows.Next() { + var id, sec int + var tableName, from, to, onUpdate, onDelete, match string + err := keysRows.Scan(&id, &sec, &tableName, &from, &to, &onUpdate, &onDelete, &match) + if err != nil { + return err + } + + column, found := table.FindColumn(from) + if found { + if onDelete == "NO ACTION" { + onDelete = "" + } + if onUpdate == "NO ACTION" { + onUpdate = "" + } + column.Reference = &dbschema.Reference{ + Table: tableName, + Column: to, + OnUpdate: onUpdate, + OnDelete: onDelete, + } + } + } + } + return err +} + +func discoverIndexes(db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) { + // TODO improve indexes discovery + for _, definition := range indexDefinitions { + index := &dbschema.Index{ + Name: definition.name, + } + schema.Indexes = append(schema.Indexes, index) + + indexRows, err := db.Query(`PRAGMA index_info(` + definition.name + `)`) + if err != nil { + return err + } + defer func() { err = errs.Combine(err, indexRows.Close()) }() + + for indexRows.Next() { + var name string + var seqno, cid int + err := indexRows.Scan(&seqno, &cid, &name) + if err != nil { + return err + } + + index.Columns = append(index.Columns, name) + } + + matches := rxIndexTable.FindStringSubmatch(definition.sql) + index.Table = strings.TrimSpace(matches[1]) + } + return err +} + +var ( + // matches UNIQUE (a,b) + rxUnique = regexp.MustCompile(`UNIQUE\s*\((.*?)\)`) + + // matches ON (a,b) + rxIndexTable = regexp.MustCompile(`ON\s*(.*)\(`) +) diff --git a/internal/dbutil/sqliteutil/query_test.go b/internal/dbutil/sqliteutil/query_test.go new file mode 100644 index 000000000..c5dd6a882 --- /dev/null +++ b/internal/dbutil/sqliteutil/query_test.go @@ -0,0 +1,104 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package sqliteutil_test + +import ( + "database/sql" + "testing" + + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "storj.io/storj/internal/dbutil/dbschema" + "storj.io/storj/internal/dbutil/sqliteutil" + "storj.io/storj/internal/testcontext" +) + +func TestQuery(t *testing.T) { + ctx := testcontext.New(t) + defer ctx.Cleanup() + + db, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + + defer ctx.Check(db.Close) + + emptySchema, err := sqliteutil.QuerySchema(db) + assert.NoError(t, err) + assert.Equal(t, &dbschema.Schema{}, emptySchema) + + _, err = db.Exec(` + CREATE TABLE users ( + a integer NOT NULL, + b integer NOT NULL, + c text, + UNIQUE (c), + PRIMARY KEY (a) + ); + CREATE TABLE names ( + users_a integer REFERENCES users( a ) ON DELETE CASCADE, + a text NOT NULL, + x text, + b text, + PRIMARY KEY (a, x), + UNIQUE ( x ), + UNIQUE ( a, b ) + ); + CREATE INDEX names_a ON names (a, b); + `) + + require.NoError(t, err) + + schema, err := sqliteutil.QuerySchema(db) + assert.NoError(t, err) + + expected := &dbschema.Schema{ + Tables: []*dbschema.Table{ + { + Name: "users", + Columns: []*dbschema.Column{ + {Name: "a", Type: "integer", IsNullable: false, Reference: nil}, + {Name: "b", Type: "integer", IsNullable: false, Reference: nil}, + {Name: "c", Type: "text", IsNullable: true, Reference: nil}, + }, + PrimaryKey: []string{"a"}, + Unique: [][]string{ + {"c"}, + }, + }, + { + Name: "names", + Columns: []*dbschema.Column{ + {Name: "users_a", Type: "integer", IsNullable: true, + Reference: &dbschema.Reference{ + Table: "users", + Column: "a", + OnDelete: "CASCADE", + }}, + {Name: "a", Type: "text", IsNullable: false, Reference: nil}, + {Name: "x", Type: "text", IsNullable: false, Reference: nil}, // not null, because primary key + {Name: "b", Type: "text", IsNullable: true, Reference: nil}, + }, + PrimaryKey: []string{"a", "x"}, + Unique: [][]string{ + {"a", "b"}, + {"x"}, + }, + }, + }, + Indexes: []*dbschema.Index{ + { + Name: "names_a", + Table: "names", + Columns: []string{"a", "b"}, + }, + }, + } + + expected.Sort() + schema.Sort() + assert.Equal(t, expected, schema) +}