storj/storagenode/storagenodedb/schemagen/main.go

268 lines
5.6 KiB
Go
Raw Normal View History

// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
package main
import (
"bytes"
"context"
"flag"
"fmt"
"go/format"
"io"
"os"
"path/filepath"
"sort"
"strings"
"github.com/zeebo/errs"
"go.uber.org/zap"
"storj.io/private/dbutil/dbschema"
"storj.io/private/dbutil/sqliteutil"
"storj.io/storj/storagenode/storagenodedb"
)
func main() {
outfile := flag.String("o", "", "output file")
flag.Parse()
ctx := context.Background()
log := zap.L()
out, err := runSchemaGen(ctx, log)
if err != nil {
printWithLines(os.Stderr, out)
fmt.Fprintf(os.Stderr, "%v", err)
os.Exit(1)
}
if *outfile == "" {
fmt.Print(string(out))
} else {
err := os.WriteFile(*outfile, out, 0644)
if err != nil {
fmt.Fprintf(os.Stderr, "%v", err)
os.Exit(1)
}
}
}
// printWithLines writes data to w one line at a time, prefixing each line
// with its 1-based line number. The numbering matches the line:column
// positions reported by go/format and go/scanner errors, so a failing
// generated file can be cross-referenced directly against the error.
func printWithLines(w io.Writer, data []byte) {
	for i, line := range strings.Split(string(data), "\n") {
		fmt.Fprintf(w, "%3d: %s\n", i+1, line)
	}
}
// runSchemaGen creates a fresh storagenode database in a temporary directory,
// migrates it to the latest version, and renders the resulting schemas as the
// Go source of a Schema() function in package storagenodedb.
//
// On success it returns gofmt-formatted source. If formatting fails, it
// returns the unformatted source together with the error, so the caller can
// show where the generated code is broken. Errors from closing the database
// or removing the temporary directory are combined into the returned error.
func runSchemaGen(ctx context.Context, log *zap.Logger) (_ []byte, err error) {
	storagePath, err := os.MkdirTemp("", "testdb")
	if err != nil {
		return nil, errs.New("Error getting test storage path: %+w", err)
	}
	defer func() {
		if removeErr := os.RemoveAll(storagePath); removeErr != nil {
			err = errs.Combine(err, removeErr)
		}
	}()

	db, err := storagenodedb.OpenNew(ctx, log, storagenodedb.Config{
		Storage:           storagePath,
		Info:              filepath.Join(storagePath, "piecestore.db"),
		Info2:             filepath.Join(storagePath, "info.db"),
		Pieces:            storagePath,
		TestingDisableWAL: true,
	})
	if err != nil {
		return nil, errs.New("Error creating new storagenode db: %+w", err)
	}
	defer func() {
		if closeErr := db.Close(); closeErr != nil {
			err = errs.Combine(err, closeErr)
		}
	}()

	err = db.MigrateToLatest(ctx)
	if err != nil {
		return nil, errs.New("Error creating tables for storagenode db: %+w", err)
	}

	allSchemas := make(map[string]*dbschema.Schema, len(db.SQLDBs))
	for dbName, dbContainer := range db.SQLDBs {
		schema, err := sqliteutil.QuerySchema(ctx, dbContainer.GetDB())
		if err != nil {
			return nil, errs.New("Error getting schema for db: %+w", err)
		}

		// we don't care about changes in versions table
		schema.DropTable("versions")

		// If tables and indexes of the schema are empty, set to nil
		// to help with comparison to the snapshot.
		if len(schema.Tables) == 0 {
			schema.Tables = nil
		}
		if len(schema.Indexes) == 0 {
			schema.Indexes = nil
		}

		allSchemas[dbName] = schema
	}

	return renderSchemaFile(allSchemas)
}

// renderSchemaFile renders schemas (keyed by database name) as the Go source
// of the generated schema file and runs it through gofmt.
//
// It returns the formatted source on success. When gofmt fails it returns the
// unformatted source together with the error. Write errors are reported
// rather than discarded.
func renderSchemaFile(schemas map[string]*dbschema.Schema) (_ []byte, err error) {
	var buf bytes.Buffer
	// printf becomes a no-op after the first failed write, so err always
	// holds the first failure.
	printf := func(format string, args ...interface{}) {
		if err != nil {
			return
		}
		_, err = fmt.Fprintf(&buf, format, args...)
	}

	printf(`//lint:file-ignore * generated file
// AUTOGENERATED BY storj.io/storj/storagenode/storagenodedb/schemagen.go
// DO NOT EDIT
package storagenodedb
import "storj.io/private/dbutil/dbschema"
func Schema() map[string]*dbschema.Schema {
return map[string]*dbschema.Schema{
`)

	// use a consistent order for the generated file; map iteration is random
	names := make([]string, 0, len(schemas))
	for name := range schemas {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, name := range names {
		printf("%q: {\n", name)
		if err != nil {
			return nil, err
		}
		if writeErr := writeSchemaGoStruct(&buf, schemas[name]); writeErr != nil {
			return nil, errs.New("Error writing schema struct: %+w", writeErr)
		}
		printf("},\n")
	}

	// close bracket for returned map
	printf("}\n")
	// close bracket for Schema() {
	printf("}\n")

	// Check before formatting: the assignment below would otherwise
	// overwrite a failure from the closing printf calls.
	if err != nil {
		return nil, err
	}

	formatted, err := format.Source(buf.Bytes())
	if err != nil {
		return buf.Bytes(), errs.New("Error formatting: %+w", err)
	}
	return formatted, nil
}
// writeSchemaGoStruct writes the fields of a dbschema.Schema composite
// literal (without the surrounding braces) to w.
//
// Tables are written one by one via writeTableGoStruct. Indexes are written
// using prettyValue. An empty Tables or Indexes slice is omitted entirely.
// Writing stops at the first error, which is returned.
func writeSchemaGoStruct(w io.Writer, schema *dbschema.Schema) (err error) {
	// printf becomes a no-op after the first failed write, so err always
	// holds the first failure.
	printf := func(format string, args ...interface{}) {
		if err != nil {
			return
		}
		_, err = fmt.Fprintf(w, format, args...)
	}

	if len(schema.Tables) > 0 {
		printf("Tables: []*dbschema.Table{\n")
		for _, table := range schema.Tables {
			// Bail out before calling writeTableGoStruct: assigning its
			// result to err would otherwise discard an earlier printf failure.
			if err != nil {
				return err
			}
			if err = writeTableGoStruct(w, table); err != nil {
				return err
			}
			// writeTableGoStruct leaves the trailing comma to the caller.
			printf(",\n")
		}
		printf("},\n")
	}

	if len(schema.Indexes) > 0 {
		printf("Indexes: []*dbschema.Index{\n")
		for _, index := range schema.Indexes {
			printf("%v,\n", prettyValue(index))
		}
		printf("},\n")
	}

	return err
}
// writeTableGoStruct writes table as a braced Go composite literal to w. The
// literal has no trailing comma; the caller appends the separator.
//
// Name is always written. PrimaryKey and Unique are written when non-nil, and
// Columns when non-empty (each column via writeColumnGoStruct). Writing stops
// at the first error, which is returned.
func writeTableGoStruct(w io.Writer, table *dbschema.Table) (err error) {
	// printf becomes a no-op after the first failed write, so err always
	// holds the first failure.
	printf := func(format string, args ...interface{}) {
		if err != nil {
			return
		}
		_, err = fmt.Fprintf(w, format, args...)
	}

	printf("{\n")
	defer printf("}")

	printf("Name: %q,\n", table.Name)
	if table.PrimaryKey != nil {
		printf("PrimaryKey: %#v,\n", table.PrimaryKey)
	}
	if table.Unique != nil {
		// Keep the type in the literal (like PrimaryKey). A type-elided literal
		// such as `Unique: {...}` parses, but it does not compile as a struct
		// field value. NOTE(review): assumes Unique is a slice of plain values
		// (e.g. [][]string); confirm against dbschema.Table.
		printf("Unique: %#v,\n", table.Unique)
	}
	if len(table.Columns) > 0 {
		printf("Columns: []*dbschema.Column{\n")
		for _, column := range table.Columns {
			// Bail out before calling writeColumnGoStruct: assigning its
			// result to err would otherwise discard an earlier printf failure.
			if err != nil {
				return err
			}
			if err = writeColumnGoStruct(w, column); err != nil {
				return err
			}
		}
		printf("},\n")
	}

	return err
}
// writeColumnGoStruct writes column to w as a braced Go composite literal
// followed by ",\n", ready to sit inside a []*dbschema.Column literal.
//
// Name, Type and IsNullable are always written. Reference is written (in Go
// syntax) only when it is non-nil. Once a write fails, no further output is
// produced and that first error is returned.
func writeColumnGoStruct(w io.Writer, column *dbschema.Column) (err error) {
	emit := func(format string, args ...interface{}) {
		if err == nil {
			_, err = fmt.Fprintf(w, format, args...)
		}
	}

	emit("{\n")
	emit("Name: %q,\n", column.Name)
	emit("Type: %q,\n", column.Type)
	emit("IsNullable: %t,\n", column.IsNullable)
	if ref := column.Reference; ref != nil {
		emit("Reference: %#v,\n", ref)
	}
	emit("},\n")

	return err
}
// prettyValue formats v with %#v and strips the leading type name, leaving
// the braced literal body. For example, &pkg.T{A:1} becomes {A:1}, which is
// valid Go inside a slice literal whose element type is T or *T.
//
// If the formatted value contains no '{' (e.g. nil or a scalar), the full
// %#v rendering is returned unchanged instead of panicking.
func prettyValue(v interface{}) string {
	s := fmt.Sprintf("%#v", v)
	if p := strings.Index(s, "{"); p >= 0 {
		return s[p:]
	}
	return s
}