satellite/satellitedb: remove sqlite support (#3296)

This commit is contained in:
Egon Elbre 2019-10-19 00:27:57 +03:00 committed by GitHub
parent 89ed997706
commit 3c438f31bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 121 additions and 7356 deletions

View File

@ -290,7 +290,7 @@ func BenchmarkOrders(b *testing.B) {
ctx := testcontext.New(b)
defer ctx.Cleanup()
counts := []int{50, 100, 250, 500, 999} //sqlite limit of 999
counts := []int{50, 100, 250, 500, 1000}
for _, c := range counts {
c := c
satellitedbtest.Bench(b, func(b *testing.B, db satellite.DB) {

View File

@ -8,8 +8,6 @@ import (
"testing"
"time"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -74,15 +72,6 @@ func TestOrder(t *testing.T) {
// TODO: remove dependency on *dbx.DB
dbAccess := db.(interface{ TestDBAccess() *dbx.DB }).TestDBAccess()
var timeConvertPrefix string
switch d := dbAccess.DB.Driver().(type) {
case *sqlite3.SQLiteDriver:
timeConvertPrefix = "datetime("
case *pq.Driver:
timeConvertPrefix = "timezone('utc', "
default:
t.Errorf("Unsupported database type %t", d)
}
err := dbAccess.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
updateList := []struct {
@ -94,7 +83,7 @@ func TestOrder(t *testing.T) {
{olderRepairPath, time.Now().Add(-3 * time.Hour)},
}
for _, item := range updateList {
res, err := tx.Tx.ExecContext(ctx, dbAccess.Rebind(`UPDATE injuredsegments SET attempted = `+timeConvertPrefix+`?) WHERE path = ?`), item.attempted, item.path)
res, err := tx.Tx.ExecContext(ctx, dbAccess.Rebind(`UPDATE injuredsegments SET attempted = timezone('utc', ?) WHERE path = ?`), item.attempted, item.path)
if err != nil {
return err
}

View File

@ -6,11 +6,8 @@ package satellitedb
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
@ -47,7 +44,7 @@ const (
-- If there are more than 1 records within the hour, only the latest will be considered
SELECT
va.partner_id,
%v as hours,
date_trunc('hour', bst.interval_start) as hours,
bst.project_id,
bst.bucket_name,
MAX(bst.interval_start) as max_interval
@ -109,9 +106,6 @@ const (
o.project_id,
o.bucket_name;
`
// DB specific date/time truncations
slHour = "datetime(strftime('%Y-%m-%dT%H:00:00', bst.interval_start))"
pqHour = "date_trunc('hour', bst.interval_start)"
)
type attributionDB struct {
@ -156,17 +150,7 @@ func (keys *attributionDB) Insert(ctx context.Context, info *attribution.Info) (
func (keys *attributionDB) QueryAttribution(ctx context.Context, partnerID uuid.UUID, start time.Time, end time.Time) (_ []*attribution.CSVRow, err error) {
defer mon.Task()(&ctx)(&err)
var query string
switch t := keys.db.Driver().(type) {
case *sqlite3.SQLiteDriver:
query = fmt.Sprintf(valueAttrQuery, slHour)
case *pq.Driver:
query = fmt.Sprintf(valueAttrQuery, pqHour)
default:
return nil, Error.New("Unsupported database %t", t)
}
rows, err := keys.db.DB.QueryContext(ctx, keys.db.Rebind(query), partnerID[:], start.UTC(), end.UTC(), partnerID[:], start.UTC(), end.UTC())
rows, err := keys.db.DB.QueryContext(ctx, keys.db.Rebind(valueAttrQuery), partnerID[:], start.UTC(), end.UTC(), partnerID[:], start.UTC(), end.UTC())
if err != nil {
return nil, Error.Wrap(err)
}

View File

@ -76,12 +76,6 @@ func (db *DB) CreateSchema(schema string) error {
// should not be used outside of migration tests.
func (db *DB) TestDBAccess() *dbx.DB { return db.db }
// TestDBAccess for raw database access,
// should not be used outside of tests.
func (db *locked) TestDBAccess() *dbx.DB {
return db.db.(interface{ TestDBAccess() *dbx.DB }).TestDBAccess()
}
// DropSchema drops the named schema
func (db *DB) DropSchema(schema string) error {
return pgutil.DropSchema(db.db, schema)

View File

@ -10,8 +10,8 @@ import (
"github.com/zeebo/errs"
)
//go:generate dbx.v1 schema -d postgres -d sqlite3 satellitedb.dbx .
//go:generate dbx.v1 golang -d postgres -d sqlite3 satellitedb.dbx .
//go:generate dbx.v1 schema -d postgres satellitedb.dbx .
//go:generate dbx.v1 golang -d postgres satellitedb.dbx .
func init() {
// catch dbx errors

File diff suppressed because it is too large Load Diff

View File

@ -11,7 +11,6 @@ import (
"time"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
"github.com/zeebo/errs"
"storj.io/storj/pkg/storj"
@ -71,20 +70,6 @@ func (db *gracefulexitDB) GetProgress(ctx context.Context, nodeID storj.NodeID)
func (db *gracefulexitDB) Enqueue(ctx context.Context, items []gracefulexit.TransferQueueItem) (err error) {
defer mon.Task()(&ctx)(&err)
switch t := db.db.Driver().(type) {
case *sqlite3.SQLiteDriver:
statement := db.db.Rebind(
`INSERT INTO graceful_exit_transfer_queue(node_id, path, piece_num, durability_ratio, queued_at)
VALUES (?, ?, ?, ?, ?) ON CONFLICT DO NOTHING;`,
)
for _, item := range items {
_, err = db.db.ExecContext(ctx, statement,
item.NodeID.Bytes(), item.Path, item.PieceNum, item.DurabilityRatio, time.Now().UTC())
if err != nil {
return Error.Wrap(err)
}
}
case *pq.Driver:
sort.Slice(items, func(i, k int) bool {
compare := bytes.Compare(items[i].NodeID.Bytes(), items[k].NodeID.Bytes())
if compare == 0 {
@ -104,18 +89,12 @@ func (db *gracefulexitDB) Enqueue(ctx context.Context, items []gracefulexit.Tran
durabilities = append(durabilities, item.DurabilityRatio)
}
_, err := db.db.ExecContext(ctx, `
_, err = db.db.ExecContext(ctx, db.db.Rebind(`
INSERT INTO graceful_exit_transfer_queue(node_id, path, piece_num, durability_ratio, queued_at)
SELECT unnest($1::bytea[]), unnest($2::bytea[]), unnest($3::integer[]), unnest($4::float8[]), $5
ON CONFLICT DO NOTHING;`, postgresNodeIDList(nodeIDs), pq.ByteaArray(paths), pq.Array(pieceNums), pq.Array(durabilities), time.Now().UTC())
if err != nil {
return Error.Wrap(err)
}
default:
return Error.New("Unsupported database %t", t)
}
ON CONFLICT DO NOTHING;`), postgresNodeIDList(nodeIDs), pq.ByteaArray(paths), pq.Array(pieceNums), pq.Array(durabilities), time.Now().UTC())
return nil
return Error.Wrap(err)
}
// UpdateTransferQueueItem creates a graceful exit transfer queue entry.

File diff suppressed because it is too large Load Diff

View File

@ -1,42 +0,0 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package satellitedb
import (
"context"
"sync"
"storj.io/storj/satellite/console"
)
// BeginTransaction is a method for opening transaction
func (m *lockedConsole) BeginTx(ctx context.Context) (console.DBTx, error) {
m.Lock()
db, err := m.db.BeginTx(ctx)
txlocked := &lockedConsole{&sync.Mutex{}, db}
return &lockedTx{m, txlocked, db, sync.Once{}}, err
}
// lockedTx extends Database with transaction scope
type lockedTx struct {
parent *lockedConsole
*lockedConsole
tx console.DBTx
once sync.Once
}
// Commit is a method for committing and closing transaction
func (db *lockedTx) Commit() error {
err := db.tx.Commit()
db.once.Do(db.parent.Unlock)
return err
}
// Rollback is a method for rollback and closing transaction
func (db *lockedTx) Rollback() error {
err := db.tx.Rollback()
db.once.Do(db.parent.Unlock)
return err
}

View File

@ -10,13 +10,10 @@ import (
"sort"
"time"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/internal/dbutil/pgutil"
"storj.io/storj/internal/dbutil/sqliteutil"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
"storj.io/storj/satellite/orders"
@ -64,7 +61,7 @@ func (db *ordersDB) UseSerialNumber(ctx context.Context, serialNumber storj.Seri
)
_, err = db.db.ExecContext(ctx, statement, storageNodeID.Bytes(), serialNumber.Bytes())
if err != nil {
if pgutil.IsConstraintError(err) || sqliteutil.IsConstraintError(err) {
if pgutil.IsConstraintError(err) {
return nil, orders.ErrUsingSerialNumber.New("serial number already used")
}
return nil, err
@ -142,42 +139,18 @@ func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, projectID u
func (db *ordersDB) UpdateStoragenodeBandwidthAllocation(ctx context.Context, storageNodes []storj.NodeID, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
switch t := db.db.Driver().(type) {
case *sqlite3.SQLiteDriver:
statement := db.db.Rebind(
`INSERT INTO storagenode_bandwidth_rollups (storagenode_id, interval_start, interval_seconds, action, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(storagenode_id, interval_start, action)
DO UPDATE SET allocated = storagenode_bandwidth_rollups.allocated + excluded.allocated`,
)
for _, storageNode := range storageNodes {
_, err = db.db.ExecContext(ctx, statement,
storageNode.Bytes(), intervalStart, defaultIntervalSeconds, action, uint64(amount), 0,
)
if err != nil {
return Error.Wrap(err)
}
}
case *pq.Driver:
// sort nodes to avoid update deadlock
sort.Sort(storj.NodeIDList(storageNodes))
_, err := db.db.ExecContext(ctx, `
_, err = db.db.ExecContext(ctx, db.db.Rebind(`
INSERT INTO storagenode_bandwidth_rollups
(storagenode_id, interval_start, interval_seconds, action, allocated, settled)
SELECT unnest($1::bytea[]), $2, $3, $4, $5, $6
ON CONFLICT(storagenode_id, interval_start, action)
DO UPDATE SET allocated = storagenode_bandwidth_rollups.allocated + excluded.allocated
`, postgresNodeIDList(storageNodes), intervalStart, defaultIntervalSeconds, action, uint64(amount), 0)
if err != nil {
return Error.Wrap(err)
}
default:
return Error.New("Unsupported database %t", t)
}
`), postgresNodeIDList(storageNodes), intervalStart, defaultIntervalSeconds, action, uint64(amount), 0)
return nil
return Error.Wrap(err)
}
// UpdateStoragenodeBandwidthSettle updates 'settled' bandwidth for given storage node for the given intervalStart time

View File

@ -13,7 +13,6 @@ import (
"time"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
"github.com/zeebo/errs"
monkit "gopkg.in/spacemonkeygo/monkit.v2"
@ -205,83 +204,6 @@ func (cache *overlaycache) queryNodes(ctx context.Context, excludedNodes []storj
func (cache *overlaycache) queryNodesDistinct(ctx context.Context, excludedNodes []storj.NodeID, excludedIPs []string, count int, safeQuery string, distinctIP bool, args ...interface{}) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
switch t := cache.db.DB.Driver().(type) {
case *sqlite3.SQLiteDriver:
return cache.sqliteQueryNodesDistinct(ctx, excludedNodes, excludedIPs, count, safeQuery, distinctIP, args...)
case *pq.Driver:
return cache.postgresQueryNodesDistinct(ctx, excludedNodes, excludedIPs, count, safeQuery, distinctIP, args...)
default:
return []*pb.Node{}, Error.New("Unsupported database %t", t)
}
}
func (cache *overlaycache) sqliteQueryNodesDistinct(ctx context.Context, excludedNodes []storj.NodeID, excludedIPs []string, count int, safeQuery string, distinctIP bool, args ...interface{}) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
if count == 0 {
return nil, nil
}
safeExcludeNodes := ""
if len(excludedNodes) > 0 {
safeExcludeNodes = ` AND id NOT IN (?` + strings.Repeat(", ?", len(excludedNodes)-1) + `)`
for _, id := range excludedNodes {
args = append(args, id.Bytes())
}
}
safeExcludeIPs := ""
if len(excludedIPs) > 0 {
safeExcludeIPs = ` AND last_net NOT IN (?` + strings.Repeat(", ?", len(excludedIPs)-1) + `)`
for _, ip := range excludedIPs {
args = append(args, ip)
}
}
args = append(args, count)
rows, err := cache.db.Query(cache.db.Rebind(`SELECT id, type, address, last_net,
free_bandwidth, free_disk, total_audit_count, audit_success_count,
total_uptime_count, uptime_success_count, disqualified, audit_reputation_alpha,
audit_reputation_beta, uptime_reputation_alpha, uptime_reputation_beta
FROM (SELECT *, Row_number() OVER(PARTITION BY last_net ORDER BY RANDOM()) rn
FROM nodes
`+safeQuery+safeExcludeNodes+safeExcludeIPs+`) n
WHERE rn = 1
ORDER BY RANDOM()
LIMIT ?`), args...)
if err != nil {
return nil, err
}
defer func() { err = errs.Combine(err, rows.Close()) }()
var nodes []*pb.Node
for rows.Next() {
dbNode := &dbx.Node{}
err = rows.Scan(&dbNode.Id, &dbNode.Type,
&dbNode.Address, &dbNode.LastNet, &dbNode.FreeBandwidth, &dbNode.FreeDisk,
&dbNode.TotalAuditCount, &dbNode.AuditSuccessCount,
&dbNode.TotalUptimeCount, &dbNode.UptimeSuccessCount, &dbNode.Disqualified,
&dbNode.AuditReputationAlpha, &dbNode.AuditReputationBeta,
&dbNode.UptimeReputationAlpha, &dbNode.UptimeReputationBeta,
)
if err != nil {
return nil, err
}
dossier, err := convertDBNode(ctx, dbNode)
if err != nil {
return nil, err
}
nodes = append(nodes, &dossier.Node)
}
return nodes, rows.Err()
}
func (cache *overlaycache) postgresQueryNodesDistinct(ctx context.Context, excludedNodes []storj.NodeID, excludedIPs []string, count int, safeQuery string, distinctIP bool, args ...interface{}) (_ []*pb.Node, err error) {
defer mon.Task()(&ctx)(&err)
if count == 0 {
return nil, nil
}
@ -375,41 +297,18 @@ func (cache *overlaycache) KnownOffline(ctx context.Context, criteria *overlay.N
// get offline nodes
var rows *sql.Rows
switch t := cache.db.Driver().(type) {
case *sqlite3.SQLiteDriver:
args := make([]interface{}, 0, len(nodeIds)+1)
for i := range nodeIds {
args = append(args, nodeIds[i].Bytes())
}
args = append(args, time.Now().Add(-criteria.OnlineWindow))
rows, err = cache.db.Query(cache.db.Rebind(`
SELECT id FROM nodes
WHERE id IN (?`+strings.Repeat(", ?", len(nodeIds)-1)+`)
AND (
last_contact_success < last_contact_failure AND last_contact_success < ?
)
`), args...)
case *pq.Driver:
rows, err = cache.db.Query(`
SELECT id FROM nodes
WHERE id = any($1::bytea[])
AND (
last_contact_success < last_contact_failure AND last_contact_success < $2
)
`, postgresNodeIDList(nodeIds), time.Now().Add(-criteria.OnlineWindow),
`), postgresNodeIDList(nodeIds), time.Now().Add(-criteria.OnlineWindow),
)
default:
return nil, Error.New("Unsupported database %t", t)
}
if err != nil {
return nil, err
}
defer func() {
err = errs.Combine(err, rows.Close())
}()
defer func() { err = errs.Combine(err, rows.Close()) }()
for rows.Next() {
var id storj.NodeID
@ -432,39 +331,17 @@ func (cache *overlaycache) KnownUnreliableOrOffline(ctx context.Context, criteri
// get reliable and online nodes
var rows *sql.Rows
switch t := cache.db.Driver().(type) {
case *sqlite3.SQLiteDriver:
args := make([]interface{}, 0, len(nodeIds)+3)
for i := range nodeIds {
args = append(args, nodeIds[i].Bytes())
}
args = append(args, time.Now().Add(-criteria.OnlineWindow))
rows, err = cache.db.Query(cache.db.Rebind(`
SELECT id FROM nodes
WHERE id IN (?`+strings.Repeat(", ?", len(nodeIds)-1)+`)
AND disqualified IS NULL
AND (last_contact_success > ? OR last_contact_success > last_contact_failure)
`), args...)
case *pq.Driver:
rows, err = cache.db.Query(`
SELECT id FROM nodes
WHERE id = any($1::bytea[])
AND disqualified IS NULL
AND (last_contact_success > $2 OR last_contact_success > last_contact_failure)
`, postgresNodeIDList(nodeIds), time.Now().Add(-criteria.OnlineWindow),
`), postgresNodeIDList(nodeIds), time.Now().Add(-criteria.OnlineWindow),
)
default:
return nil, Error.New("Unsupported database %t", t)
}
if err != nil {
return nil, err
}
defer func() {
err = errs.Combine(err, rows.Close())
}()
defer func() { err = errs.Combine(err, rows.Close()) }()
goodNodes := make(map[storj.NodeID]struct{}, len(nodeIds))
for rows.Next() {
@ -932,19 +809,6 @@ func (cache *overlaycache) UpdatePieceCounts(ctx context.Context, pieceCounts ma
return counts[i].ID.Less(counts[k].ID)
})
switch t := cache.db.Driver().(type) {
case *sqlite3.SQLiteDriver:
err = cache.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
query := tx.Rebind(`UPDATE nodes SET piece_count = ? WHERE id = ?`)
for _, count := range counts {
_, err := tx.Tx.ExecContext(ctx, query, count.Count, count.ID)
if err != nil {
return Error.Wrap(err)
}
}
return nil
})
case *pq.Driver:
var nodeIDs []storj.NodeID
var countNumbers []int64
for _, count := range counts {
@ -960,9 +824,6 @@ func (cache *overlaycache) UpdatePieceCounts(ctx context.Context, pieceCounts ma
) as update
WHERE nodes.id = update.id
`, postgresNodeIDList(nodeIDs), pq.Array(countNumbers))
default:
return Error.New("Unsupported database %t", t)
}
return Error.Wrap(err)
}
@ -1282,16 +1143,9 @@ func buildUpdateStatement(db *dbx.DB, update updateNodeStats) string {
return ""
}
hexNodeID := hex.EncodeToString(update.NodeID.Bytes())
switch db.DB.Driver().(type) {
case *sqlite3.SQLiteDriver:
sql += fmt.Sprintf(" WHERE nodes.id = X'%v';\n", hexNodeID)
sql += fmt.Sprintf("DELETE FROM pending_audits WHERE pending_audits.node_id = X'%v';\n", hexNodeID)
case *pq.Driver:
sql += fmt.Sprintf(" WHERE nodes.id = decode('%v', 'hex');\n", hexNodeID)
sql += fmt.Sprintf("DELETE FROM pending_audits WHERE pending_audits.node_id = decode('%v', 'hex');\n", hexNodeID)
default:
return ""
}
return sql
}
@ -1449,46 +1303,19 @@ func (cache *overlaycache) UpdateCheckIn(ctx context.Context, node overlay.NodeC
return Error.New("error UpdateCheckIn: missing the storage node address")
}
switch t := cache.db.Driver().(type) {
case *sqlite3.SQLiteDriver:
value := pb.Node{
Id: node.NodeID,
Address: node.Address,
LastIp: node.LastIP,
}
err := cache.UpdateAddress(ctx, &value, config)
if err != nil {
return Error.Wrap(err)
}
_, err = cache.UpdateUptime(ctx, node.NodeID, node.IsUp, config.UptimeReputationLambda, config.UptimeReputationWeight, config.UptimeReputationDQ)
if err != nil {
return Error.Wrap(err)
}
pbInfo := pb.InfoResponse{
Operator: node.Operator,
Capacity: node.Capacity,
Type: pb.NodeType_STORAGE,
Version: node.Version,
}
_, err = cache.UpdateNodeInfo(ctx, node.NodeID, &pbInfo)
if err != nil {
return Error.Wrap(err)
}
case *pq.Driver:
// v is a single feedback value that allows us to update both alpha and beta
var v float64 = -1
if node.IsUp {
v = 1
}
uptimeReputationAlpha := config.UptimeReputationLambda*config.UptimeReputationAlpha0 + config.UptimeReputationWeight*(1+v)/2
uptimeReputationBeta := config.UptimeReputationLambda*config.UptimeReputationBeta0 + config.UptimeReputationWeight*(1-v)/2
semVer, err := version.NewSemVer(node.Version.GetVersion())
if err != nil {
return Error.New("unable to convert version to semVer")
}
start := time.Now()
query := `
INSERT INTO nodes
(
@ -1560,10 +1387,6 @@ func (cache *overlaycache) UpdateCheckIn(ctx context.Context, node overlay.NodeC
if err != nil {
return Error.Wrap(err)
}
mon.FloatVal("UpdateCheckIn query execution time (seconds)").Observe(time.Since(start).Seconds())
default:
return Error.New("Unsupported database %t", t)
}
return nil
}

View File

@ -6,13 +6,8 @@ package satellitedb
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
"storj.io/storj/internal/dbutil/pgutil"
"storj.io/storj/internal/dbutil/sqliteutil"
"storj.io/storj/pkg/pb"
dbx "storj.io/storj/satellite/satellitedb/dbx"
"storj.io/storj/storage"
@ -26,7 +21,7 @@ func (r *repairQueue) Insert(ctx context.Context, seg *pb.InjuredSegment) (err e
defer mon.Task()(&ctx)(&err)
_, err = r.db.ExecContext(ctx, r.db.Rebind(`INSERT INTO injuredsegments ( path, data ) VALUES ( ?, ? )`), seg.Path, seg)
if err != nil {
if pgutil.IsConstraintError(err) || sqliteutil.IsConstraintError(err) {
if pgutil.IsConstraintError(err) {
return nil // quietly fail on reinsert
}
return err
@ -34,7 +29,7 @@ func (r *repairQueue) Insert(ctx context.Context, seg *pb.InjuredSegment) (err e
return nil
}
func (r *repairQueue) postgresSelect(ctx context.Context) (seg *pb.InjuredSegment, err error) {
func (r *repairQueue) Select(ctx context.Context) (seg *pb.InjuredSegment, err error) {
defer mon.Task()(&ctx)(&err)
err = r.db.QueryRowContext(ctx, `
UPDATE injuredsegments SET attempted = timezone('utc', now()) WHERE path = (
@ -42,55 +37,13 @@ func (r *repairQueue) postgresSelect(ctx context.Context) (seg *pb.InjuredSegmen
WHERE attempted IS NULL OR attempted < timezone('utc', now()) - interval '1 hour'
ORDER BY attempted NULLS FIRST FOR UPDATE SKIP LOCKED LIMIT 1
) RETURNING data`).Scan(&seg)
if err == sql.ErrNoRows {
err = storage.ErrEmptyQueue.New("")
}
return
}
func (r *repairQueue) sqliteSelect(ctx context.Context) (seg *pb.InjuredSegment, err error) {
defer mon.Task()(&ctx)(&err)
err = r.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
var path []byte
err = tx.Tx.QueryRowContext(ctx, r.db.Rebind(`
SELECT path, data FROM injuredsegments
WHERE attempted IS NULL
OR attempted < datetime('now','-1 hours')
ORDER BY attempted LIMIT 1`)).Scan(&path, &seg)
if err != nil {
return err
}
res, err := tx.Tx.ExecContext(ctx, r.db.Rebind(`UPDATE injuredsegments SET attempted = datetime('now') WHERE path = ?`), path)
if err != nil {
return err
}
count, err := res.RowsAffected()
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected 1, got %d segments deleted", count)
}
return nil
})
if err == sql.ErrNoRows {
err = storage.ErrEmptyQueue.New("")
}
return seg, err
}
func (r *repairQueue) Select(ctx context.Context) (seg *pb.InjuredSegment, err error) {
defer mon.Task()(&ctx)(&err)
switch t := r.db.DB.Driver().(type) {
case *sqlite3.SQLiteDriver:
return r.sqliteSelect(ctx)
case *pq.Driver:
return r.postgresSelect(ctx)
default:
return seg, fmt.Errorf("Unsupported database %t", t)
}
}
func (r *repairQueue) Delete(ctx context.Context, seg *pb.InjuredSegment) (err error) {
defer mon.Task()(&ctx)(&err)
_, err = r.db.ExecContext(ctx, r.db.Rebind(`DELETE FROM injuredsegments WHERE path = ?`), seg.Path)

View File

@ -8,8 +8,6 @@ import (
"database/sql"
"time"
"github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
@ -91,28 +89,23 @@ func (c *usercredits) Create(ctx context.Context, userCredit console.CreateCredi
result sql.Result
statement string
)
switch t := c.db.Driver().(type) {
case *sqlite3.SQLiteDriver:
statement = `
INSERT INTO user_credits (user_id, offer_id, credits_earned_in_cents, credits_used_in_cents, expires_at, referred_by, type, created_at)
SELECT * FROM (VALUES (?, ?, ?, 0, ?, ?, ?, time('now'))) AS v
WHERE COALESCE((SELECT COUNT(offer_id) FROM user_credits WHERE offer_id = ? AND referred_by IS NOT NULL ) < NULLIF(?, 0) , ?);
`
result, err = dbExec.ExecContext(ctx, c.db.Rebind(statement), userCredit.UserID[:], userCredit.OfferID, userCredit.CreditsEarned.Cents(), userCredit.ExpiresAt, referrerID, userCredit.Type, userCredit.OfferID, userCredit.OfferInfo.RedeemableCap, shouldCreate)
case *pq.Driver:
statement = `
INSERT INTO user_credits (user_id, offer_id, credits_earned_in_cents, credits_used_in_cents, expires_at, referred_by, type, created_at)
SELECT * FROM (VALUES (?::bytea, ?::int, ?::int, 0, ?::timestamp, NULLIF(?::bytea, ?::bytea), ?::text, now())) AS v
WHERE COALESCE((SELECT COUNT(offer_id) FROM user_credits WHERE offer_id = ? AND referred_by IS NOT NULL ) < NULLIF(?, 0), ?);
`
result, err = dbExec.ExecContext(ctx, c.db.Rebind(statement), userCredit.UserID[:], userCredit.OfferID, userCredit.CreditsEarned.Cents(), userCredit.ExpiresAt, referrerID, new([]byte), userCredit.Type, userCredit.OfferID, userCredit.OfferInfo.RedeemableCap, shouldCreate)
default:
return errs.New("unsupported database: %t", t)
}
result, err = dbExec.ExecContext(ctx, c.db.Rebind(statement),
userCredit.UserID[:],
userCredit.OfferID,
userCredit.CreditsEarned.Cents(),
userCredit.ExpiresAt, referrerID, new([]byte),
userCredit.Type,
userCredit.OfferID,
userCredit.OfferInfo.RedeemableCap, shouldCreate)
if err != nil {
// check to see if there's a constraint error
if pgutil.IsConstraintError(err) || err == sqlite3.ErrConstraint {
if pgutil.IsConstraintError(err) {
_, err := dbExec.ExecContext(ctx, c.db.Rebind(`UPDATE offers SET status = ? AND expires_at = ? WHERE id = ?`), rewards.Done, time.Now().UTC(), userCredit.OfferID)
if err != nil {
return errs.Wrap(err)
@ -138,24 +131,11 @@ func (c *usercredits) Create(ctx context.Context, userCredit console.CreateCredi
// UpdateEarnedCredits updates user credits after user activated their account
func (c *usercredits) UpdateEarnedCredits(ctx context.Context, userID uuid.UUID) error {
var statement string
switch t := c.db.Driver().(type) {
case *sqlite3.SQLiteDriver:
statement = `
UPDATE user_credits
SET credits_earned_in_cents =
(SELECT invitee_credit_in_cents FROM offers WHERE id = offer_id)
WHERE user_id = ? AND credits_earned_in_cents = 0`
case *pq.Driver:
statement = `
statement := `
UPDATE user_credits SET credits_earned_in_cents = offers.invitee_credit_in_cents
FROM offers
WHERE user_id = ? AND credits_earned_in_cents = 0 AND offer_id = offers.id
`
default:
return errs.New("Unsupported database %t", t)
}
result, err := c.db.DB.ExecContext(ctx, c.db.Rebind(statement), userID[:])
if err != nil {
@ -215,18 +195,9 @@ func (c *usercredits) UpdateAvailableCredits(ctx context.Context, creditsToCharg
values = append(values, rowIds...)
var statement string
switch t := c.db.Driver().(type) {
case *sqlite3.SQLiteDriver:
statement = generateQuery(len(availableCredits), false)
case *pq.Driver:
statement = generateQuery(len(availableCredits), true)
default:
return creditsToCharge, errs.New("Unsupported database %t", t)
}
statement := generateQuery(len(availableCredits), true)
_, err = tx.Tx.ExecContext(ctx, c.db.Rebind(`UPDATE user_credits SET
credits_used_in_cents = CASE `+statement), values...)
_, err = tx.Tx.ExecContext(ctx, c.db.Rebind(`UPDATE user_credits SET credits_used_in_cents = CASE `+statement), values...)
if err != nil {
return creditsToCharge, errs.Wrap(errs.Combine(err, tx.Rollback()))
}

View File

@ -1,390 +0,0 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// +build ignore
package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/token"
"go/types"
"io/ioutil"
"os"
"path"
"sort"
"strings"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
)
func main() {
var outputPath string
var packageName string
var typeFullyQualifedName string
flag.StringVar(&outputPath, "o", "", "output file name")
flag.StringVar(&packageName, "p", "", "output package name")
flag.StringVar(&typeFullyQualifedName, "i", "", "interface to generate code for")
flag.Parse()
if outputPath == "" || packageName == "" || typeFullyQualifedName == "" {
fmt.Println("missing argument")
os.Exit(1)
}
var code Code
code.Imports = map[string]bool{}
code.Ignore = map[string]bool{
"error": true,
}
code.IgnoreMethods = map[string]bool{
"BeginTx": true,
}
code.OutputPackage = packageName
code.Config = &packages.Config{
Mode: packages.LoadAllSyntax,
}
code.Wrapped = map[string]bool{}
code.AdditionalNesting = map[string]int{"Console": 1}
// e.g. storj.io/storj/satellite.DB
p := strings.LastIndexByte(typeFullyQualifedName, '.')
code.Package = typeFullyQualifedName[:p] // storj.io/storj/satellite
code.Type = typeFullyQualifedName[p+1:] // DB
code.QualifiedType = path.Base(code.Package) + "." + code.Type
var err error
code.Roots, err = packages.Load(code.Config, code.Package)
if err != nil {
panic(err)
}
code.PrintLocked()
code.PrintPreamble()
unformatted := code.Bytes()
imports.LocalPrefix = "storj.io"
formatted, err := imports.Process(outputPath, unformatted, nil)
if err != nil {
fmt.Println(string(unformatted))
panic(err)
}
if outputPath == "" {
fmt.Println(string(formatted))
return
}
err = ioutil.WriteFile(outputPath, formatted, 0644)
if err != nil {
panic(err)
}
}
// Methods is the common interface for types having methods.
type Methods interface {
Method(i int) *types.Func
NumMethods() int
}
// Code is the information for generating the code.
type Code struct {
Config *packages.Config
Package string
Type string
QualifiedType string
Roots []*packages.Package
OutputPackage string
Imports map[string]bool
Ignore map[string]bool
IgnoreMethods map[string]bool
Wrapped map[string]bool
AdditionalNesting map[string]int
Preamble bytes.Buffer
Source bytes.Buffer
}
// Bytes returns all code merged together
func (code *Code) Bytes() []byte {
var all bytes.Buffer
all.Write(code.Preamble.Bytes())
all.Write(code.Source.Bytes())
return all.Bytes()
}
// PrintPreamble creates package header and imports.
func (code *Code) PrintPreamble() {
w := &code.Preamble
fmt.Fprintf(w, "// Code generated by lockedgen using 'go generate'. DO NOT EDIT.\n\n")
fmt.Fprintf(w, "// Copyright (C) 2019 Storj Labs, Inc.\n")
fmt.Fprintf(w, "// See LICENSE for copying information.\n\n")
fmt.Fprintf(w, "package %v\n\n", code.OutputPackage)
fmt.Fprintf(w, "import (\n")
var imports []string
for imp := range code.Imports {
imports = append(imports, imp)
}
sort.Strings(imports)
for _, imp := range imports {
fmt.Fprintf(w, " %q\n", imp)
}
fmt.Fprintf(w, ")\n\n")
}
// PrintLocked writes locked wrapper and methods.
func (code *Code) PrintLocked() {
code.Imports["sync"] = true
code.Imports["storj.io/statellite"] = true
code.Printf("// locked implements a locking wrapper around satellite.DB.\n")
code.Printf("type locked struct {\n")
code.Printf(" sync.Locker\n")
code.Printf(" db %v\n", code.QualifiedType)
code.Printf("}\n\n")
code.Printf("// newLocked returns database wrapped with locker.\n")
code.Printf("func newLocked(db %v) %v {\n", code.QualifiedType, code.QualifiedType)
code.Printf(" return &locked{&sync.Mutex{}, db}\n")
code.Printf("}\n\n")
// find the satellite.DB type info
dbObject := code.Roots[0].Types.Scope().Lookup(code.Type)
methods := dbObject.Type().Underlying().(Methods)
for i := 0; i < methods.NumMethods(); i++ {
code.PrintLockedFunc("locked", methods.Method(i), code.AdditionalNesting[methods.Method(i).Name()]+1)
}
}
// Printf writes formatted text to source.
func (code *Code) Printf(format string, a ...interface{}) {
fmt.Fprintf(&code.Source, format, a...)
}
// PrintSignature prints method signature.
func (code *Code) PrintSignature(sig *types.Signature) {
code.PrintSignatureTuple(sig.Params(), true)
if sig.Results().Len() > 0 {
code.Printf(" ")
code.PrintSignatureTuple(sig.Results(), false)
}
}
// PrintSignatureTuple prints method tuple, params or results.
func (code *Code) PrintSignatureTuple(tuple *types.Tuple, needsNames bool) {
code.Printf("(")
defer code.Printf(")")
for i := 0; i < tuple.Len(); i++ {
if i > 0 {
code.Printf(", ")
}
param := tuple.At(i)
if code.PrintName(tuple.At(i), i, needsNames) {
code.Printf(" ")
}
code.PrintType(param.Type())
}
}
// PrintCall prints a call using the specified signature.
func (code *Code) PrintCall(sig *types.Signature) {
code.Printf("(")
defer code.Printf(")")
params := sig.Params()
for i := 0; i < params.Len(); i++ {
if i != 0 {
code.Printf(", ")
}
code.PrintName(params.At(i), i, true)
}
}
// PrintName prints an appropriate name from signature tuple.
func (code *Code) PrintName(v *types.Var, index int, needsNames bool) bool {
name := v.Name()
if needsNames && name == "" {
if v.Type().String() == "context.Context" {
code.Printf("ctx")
return true
}
code.Printf("a%d", index)
return true
}
code.Printf("%s", name)
return name != ""
}
// PrintType prints short form of type t.
func (code *Code) PrintType(t types.Type) {
types.WriteType(&code.Source, t, (*types.Package).Name)
}
func typeName(typ types.Type) string {
var body bytes.Buffer
types.WriteType(&body, typ, (*types.Package).Name)
return body.String()
}
// IncludeImports imports all types referenced in the signature.
func (code *Code) IncludeImports(sig *types.Signature) {
var tmp bytes.Buffer
types.WriteSignature(&tmp, sig, func(p *types.Package) string {
code.Imports[p.Path()] = true
return p.Name()
})
}
// NeedsWrapper checks whether method result needs a wrapper type.
func (code *Code) NeedsWrapper(method *types.Func) bool {
if code.IgnoreMethods[method.Name()] {
return false
}
sig := method.Type().Underlying().(*types.Signature)
return sig.Results().Len() == 1 && !code.Ignore[sig.Results().At(0).Type().String()]
}
// WrapperTypeName returns an appropriate name for the wrapper type.
func (code *Code) WrapperTypeName(method *types.Func) string {
return "locked" + method.Name()
}
// PrintLockedFunc prints a method with locking and defers the actual logic to method.
func (code *Code) PrintLockedFunc(receiverType string, method *types.Func, nestingDepth int) {
if code.IgnoreMethods[method.Name()] {
return
}
sig := method.Type().Underlying().(*types.Signature)
code.IncludeImports(sig)
doc := strings.TrimSpace(code.MethodDoc(method))
if doc != "" {
for _, line := range strings.Split(doc, "\n") {
code.Printf("// %s\n", line)
}
}
code.Printf("func (m *%s) %s", receiverType, method.Name())
code.PrintSignature(sig)
code.Printf(" {\n")
code.Printf(" m.Lock(); defer m.Unlock()\n")
if !code.NeedsWrapper(method) {
code.Printf(" return m.db.%s", method.Name())
code.PrintCall(sig)
code.Printf("\n")
code.Printf("}\n\n")
return
}
code.Printf(" return &%s{m.Locker, ", code.WrapperTypeName(method))
code.Printf("m.db.%s", method.Name())
code.PrintCall(sig)
code.Printf("}\n")
code.Printf("}\n\n")
if nestingDepth > 0 {
code.PrintWrapper(method, nestingDepth-1)
}
}
// PrintWrapper prints wrapper for the result type of method.
func (code *Code) PrintWrapper(method *types.Func, nestingDepth int) {
sig := method.Type().Underlying().(*types.Signature)
results := sig.Results()
result := results.At(0).Type()
receiverType := code.WrapperTypeName(method)
if code.Wrapped[receiverType] {
return
}
code.Wrapped[receiverType] = true
code.Printf("// %s implements locking wrapper for %s\n", receiverType, typeName(result))
code.Printf("type %s struct {\n", receiverType)
code.Printf(" sync.Locker\n")
code.Printf(" db %s\n", typeName(result))
code.Printf("}\n\n")
methods := result.Underlying().(Methods)
for i := 0; i < methods.NumMethods(); i++ {
code.PrintLockedFunc(receiverType, methods.Method(i), nestingDepth)
}
}
// MethodDoc finds documentation for the specified method.
func (code *Code) MethodDoc(method *types.Func) string {
file := code.FindASTFile(method.Pos())
if file == nil {
return ""
}
path, exact := astutil.PathEnclosingInterval(file, method.Pos(), method.Pos())
if !exact {
return ""
}
for _, p := range path {
switch decl := p.(type) {
case *ast.Field:
return decl.Doc.Text()
case *ast.GenDecl:
return decl.Doc.Text()
case *ast.FuncDecl:
return decl.Doc.Text()
}
}
return ""
}
// FindASTFile finds the *ast.File at the specified position.
func (code *Code) FindASTFile(pos token.Pos) *ast.File {
seen := map[*packages.Package]bool{}
// find searches pos recursively from p and its dependencies.
var find func(p *packages.Package) *ast.File
find = func(p *packages.Package) *ast.File {
if seen[p] {
return nil
}
seen[p] = true
for _, file := range p.Syntax {
if file.Pos() <= pos && pos <= file.End() {
return file
}
}
for _, dep := range p.Imports {
if file := find(dep); file != nil {
return file
}
}
return nil
}
for _, root := range code.Roots {
if file := find(root); file != nil {
return file
}
}
return nil
}