satellite/satellitedb: remove sqlite support (#3296)
This commit is contained in:
parent
89ed997706
commit
3c438f31bd
@ -290,7 +290,7 @@ func BenchmarkOrders(b *testing.B) {
|
||||
ctx := testcontext.New(b)
|
||||
defer ctx.Cleanup()
|
||||
|
||||
counts := []int{50, 100, 250, 500, 999} //sqlite limit of 999
|
||||
counts := []int{50, 100, 250, 500, 1000}
|
||||
for _, c := range counts {
|
||||
c := c
|
||||
satellitedbtest.Bench(b, func(b *testing.B, db satellite.DB) {
|
||||
|
@ -8,8 +8,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@ -74,15 +72,6 @@ func TestOrder(t *testing.T) {
|
||||
|
||||
// TODO: remove dependency on *dbx.DB
|
||||
dbAccess := db.(interface{ TestDBAccess() *dbx.DB }).TestDBAccess()
|
||||
var timeConvertPrefix string
|
||||
switch d := dbAccess.DB.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
timeConvertPrefix = "datetime("
|
||||
case *pq.Driver:
|
||||
timeConvertPrefix = "timezone('utc', "
|
||||
default:
|
||||
t.Errorf("Unsupported database type %t", d)
|
||||
}
|
||||
|
||||
err := dbAccess.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
|
||||
updateList := []struct {
|
||||
@ -94,7 +83,7 @@ func TestOrder(t *testing.T) {
|
||||
{olderRepairPath, time.Now().Add(-3 * time.Hour)},
|
||||
}
|
||||
for _, item := range updateList {
|
||||
res, err := tx.Tx.ExecContext(ctx, dbAccess.Rebind(`UPDATE injuredsegments SET attempted = `+timeConvertPrefix+`?) WHERE path = ?`), item.attempted, item.path)
|
||||
res, err := tx.Tx.ExecContext(ctx, dbAccess.Rebind(`UPDATE injuredsegments SET attempted = timezone('utc', ?) WHERE path = ?`), item.attempted, item.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -6,11 +6,8 @@ package satellitedb
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
"github.com/skyrings/skyring-common/tools/uuid"
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
@ -47,7 +44,7 @@ const (
|
||||
-- If there are more than 1 records within the hour, only the latest will be considered
|
||||
SELECT
|
||||
va.partner_id,
|
||||
%v as hours,
|
||||
date_trunc('hour', bst.interval_start) as hours,
|
||||
bst.project_id,
|
||||
bst.bucket_name,
|
||||
MAX(bst.interval_start) as max_interval
|
||||
@ -109,9 +106,6 @@ const (
|
||||
o.project_id,
|
||||
o.bucket_name;
|
||||
`
|
||||
// DB specific date/time truncations
|
||||
slHour = "datetime(strftime('%Y-%m-%dT%H:00:00', bst.interval_start))"
|
||||
pqHour = "date_trunc('hour', bst.interval_start)"
|
||||
)
|
||||
|
||||
type attributionDB struct {
|
||||
@ -156,17 +150,7 @@ func (keys *attributionDB) Insert(ctx context.Context, info *attribution.Info) (
|
||||
func (keys *attributionDB) QueryAttribution(ctx context.Context, partnerID uuid.UUID, start time.Time, end time.Time) (_ []*attribution.CSVRow, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
var query string
|
||||
switch t := keys.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
query = fmt.Sprintf(valueAttrQuery, slHour)
|
||||
case *pq.Driver:
|
||||
query = fmt.Sprintf(valueAttrQuery, pqHour)
|
||||
default:
|
||||
return nil, Error.New("Unsupported database %t", t)
|
||||
}
|
||||
|
||||
rows, err := keys.db.DB.QueryContext(ctx, keys.db.Rebind(query), partnerID[:], start.UTC(), end.UTC(), partnerID[:], start.UTC(), end.UTC())
|
||||
rows, err := keys.db.DB.QueryContext(ctx, keys.db.Rebind(valueAttrQuery), partnerID[:], start.UTC(), end.UTC(), partnerID[:], start.UTC(), end.UTC())
|
||||
if err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
|
@ -76,12 +76,6 @@ func (db *DB) CreateSchema(schema string) error {
|
||||
// should not be used outside of migration tests.
|
||||
func (db *DB) TestDBAccess() *dbx.DB { return db.db }
|
||||
|
||||
// TestDBAccess for raw database access,
|
||||
// should not be used outside of tests.
|
||||
func (db *locked) TestDBAccess() *dbx.DB {
|
||||
return db.db.(interface{ TestDBAccess() *dbx.DB }).TestDBAccess()
|
||||
}
|
||||
|
||||
// DropSchema drops the named schema
|
||||
func (db *DB) DropSchema(schema string) error {
|
||||
return pgutil.DropSchema(db.db, schema)
|
||||
|
@ -10,8 +10,8 @@ import (
|
||||
"github.com/zeebo/errs"
|
||||
)
|
||||
|
||||
//go:generate dbx.v1 schema -d postgres -d sqlite3 satellitedb.dbx .
|
||||
//go:generate dbx.v1 golang -d postgres -d sqlite3 satellitedb.dbx .
|
||||
//go:generate dbx.v1 schema -d postgres satellitedb.dbx .
|
||||
//go:generate dbx.v1 golang -d postgres satellitedb.dbx .
|
||||
|
||||
func init() {
|
||||
// catch dbx errors
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
"storj.io/storj/pkg/storj"
|
||||
@ -71,20 +70,6 @@ func (db *gracefulexitDB) GetProgress(ctx context.Context, nodeID storj.NodeID)
|
||||
func (db *gracefulexitDB) Enqueue(ctx context.Context, items []gracefulexit.TransferQueueItem) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
switch t := db.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
statement := db.db.Rebind(
|
||||
`INSERT INTO graceful_exit_transfer_queue(node_id, path, piece_num, durability_ratio, queued_at)
|
||||
VALUES (?, ?, ?, ?, ?) ON CONFLICT DO NOTHING;`,
|
||||
)
|
||||
for _, item := range items {
|
||||
_, err = db.db.ExecContext(ctx, statement,
|
||||
item.NodeID.Bytes(), item.Path, item.PieceNum, item.DurabilityRatio, time.Now().UTC())
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
}
|
||||
case *pq.Driver:
|
||||
sort.Slice(items, func(i, k int) bool {
|
||||
compare := bytes.Compare(items[i].NodeID.Bytes(), items[k].NodeID.Bytes())
|
||||
if compare == 0 {
|
||||
@ -104,18 +89,12 @@ func (db *gracefulexitDB) Enqueue(ctx context.Context, items []gracefulexit.Tran
|
||||
durabilities = append(durabilities, item.DurabilityRatio)
|
||||
}
|
||||
|
||||
_, err := db.db.ExecContext(ctx, `
|
||||
_, err = db.db.ExecContext(ctx, db.db.Rebind(`
|
||||
INSERT INTO graceful_exit_transfer_queue(node_id, path, piece_num, durability_ratio, queued_at)
|
||||
SELECT unnest($1::bytea[]), unnest($2::bytea[]), unnest($3::integer[]), unnest($4::float8[]), $5
|
||||
ON CONFLICT DO NOTHING;`, postgresNodeIDList(nodeIDs), pq.ByteaArray(paths), pq.Array(pieceNums), pq.Array(durabilities), time.Now().UTC())
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
default:
|
||||
return Error.New("Unsupported database %t", t)
|
||||
}
|
||||
ON CONFLICT DO NOTHING;`), postgresNodeIDList(nodeIDs), pq.ByteaArray(paths), pq.Array(pieceNums), pq.Array(durabilities), time.Now().UTC())
|
||||
|
||||
return nil
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
||||
// UpdateTransferQueueItem creates a graceful exit transfer queue entry.
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,42 +0,0 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package satellitedb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"storj.io/storj/satellite/console"
|
||||
)
|
||||
|
||||
// BeginTransaction is a method for opening transaction
|
||||
func (m *lockedConsole) BeginTx(ctx context.Context) (console.DBTx, error) {
|
||||
m.Lock()
|
||||
db, err := m.db.BeginTx(ctx)
|
||||
|
||||
txlocked := &lockedConsole{&sync.Mutex{}, db}
|
||||
return &lockedTx{m, txlocked, db, sync.Once{}}, err
|
||||
}
|
||||
|
||||
// lockedTx extends Database with transaction scope
|
||||
type lockedTx struct {
|
||||
parent *lockedConsole
|
||||
*lockedConsole
|
||||
tx console.DBTx
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// Commit is a method for committing and closing transaction
|
||||
func (db *lockedTx) Commit() error {
|
||||
err := db.tx.Commit()
|
||||
db.once.Do(db.parent.Unlock)
|
||||
return err
|
||||
}
|
||||
|
||||
// Rollback is a method for rollback and closing transaction
|
||||
func (db *lockedTx) Rollback() error {
|
||||
err := db.tx.Rollback()
|
||||
db.once.Do(db.parent.Unlock)
|
||||
return err
|
||||
}
|
@ -10,13 +10,10 @@ import (
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
"github.com/skyrings/skyring-common/tools/uuid"
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
"storj.io/storj/internal/dbutil/pgutil"
|
||||
"storj.io/storj/internal/dbutil/sqliteutil"
|
||||
"storj.io/storj/pkg/pb"
|
||||
"storj.io/storj/pkg/storj"
|
||||
"storj.io/storj/satellite/orders"
|
||||
@ -64,7 +61,7 @@ func (db *ordersDB) UseSerialNumber(ctx context.Context, serialNumber storj.Seri
|
||||
)
|
||||
_, err = db.db.ExecContext(ctx, statement, storageNodeID.Bytes(), serialNumber.Bytes())
|
||||
if err != nil {
|
||||
if pgutil.IsConstraintError(err) || sqliteutil.IsConstraintError(err) {
|
||||
if pgutil.IsConstraintError(err) {
|
||||
return nil, orders.ErrUsingSerialNumber.New("serial number already used")
|
||||
}
|
||||
return nil, err
|
||||
@ -142,42 +139,18 @@ func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, projectID u
|
||||
func (db *ordersDB) UpdateStoragenodeBandwidthAllocation(ctx context.Context, storageNodes []storj.NodeID, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
switch t := db.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
statement := db.db.Rebind(
|
||||
`INSERT INTO storagenode_bandwidth_rollups (storagenode_id, interval_start, interval_seconds, action, allocated, settled)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(storagenode_id, interval_start, action)
|
||||
DO UPDATE SET allocated = storagenode_bandwidth_rollups.allocated + excluded.allocated`,
|
||||
)
|
||||
for _, storageNode := range storageNodes {
|
||||
_, err = db.db.ExecContext(ctx, statement,
|
||||
storageNode.Bytes(), intervalStart, defaultIntervalSeconds, action, uint64(amount), 0,
|
||||
)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
case *pq.Driver:
|
||||
// sort nodes to avoid update deadlock
|
||||
sort.Sort(storj.NodeIDList(storageNodes))
|
||||
|
||||
_, err := db.db.ExecContext(ctx, `
|
||||
_, err = db.db.ExecContext(ctx, db.db.Rebind(`
|
||||
INSERT INTO storagenode_bandwidth_rollups
|
||||
(storagenode_id, interval_start, interval_seconds, action, allocated, settled)
|
||||
SELECT unnest($1::bytea[]), $2, $3, $4, $5, $6
|
||||
ON CONFLICT(storagenode_id, interval_start, action)
|
||||
DO UPDATE SET allocated = storagenode_bandwidth_rollups.allocated + excluded.allocated
|
||||
`, postgresNodeIDList(storageNodes), intervalStart, defaultIntervalSeconds, action, uint64(amount), 0)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
default:
|
||||
return Error.New("Unsupported database %t", t)
|
||||
}
|
||||
`), postgresNodeIDList(storageNodes), intervalStart, defaultIntervalSeconds, action, uint64(amount), 0)
|
||||
|
||||
return nil
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
||||
// UpdateStoragenodeBandwidthSettle updates 'settled' bandwidth for given storage node for the given intervalStart time
|
||||
|
@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
"github.com/zeebo/errs"
|
||||
monkit "gopkg.in/spacemonkeygo/monkit.v2"
|
||||
|
||||
@ -205,83 +204,6 @@ func (cache *overlaycache) queryNodes(ctx context.Context, excludedNodes []storj
|
||||
func (cache *overlaycache) queryNodesDistinct(ctx context.Context, excludedNodes []storj.NodeID, excludedIPs []string, count int, safeQuery string, distinctIP bool, args ...interface{}) (_ []*pb.Node, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
switch t := cache.db.DB.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
return cache.sqliteQueryNodesDistinct(ctx, excludedNodes, excludedIPs, count, safeQuery, distinctIP, args...)
|
||||
case *pq.Driver:
|
||||
return cache.postgresQueryNodesDistinct(ctx, excludedNodes, excludedIPs, count, safeQuery, distinctIP, args...)
|
||||
default:
|
||||
return []*pb.Node{}, Error.New("Unsupported database %t", t)
|
||||
}
|
||||
}
|
||||
|
||||
func (cache *overlaycache) sqliteQueryNodesDistinct(ctx context.Context, excludedNodes []storj.NodeID, excludedIPs []string, count int, safeQuery string, distinctIP bool, args ...interface{}) (_ []*pb.Node, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
if count == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
safeExcludeNodes := ""
|
||||
if len(excludedNodes) > 0 {
|
||||
safeExcludeNodes = ` AND id NOT IN (?` + strings.Repeat(", ?", len(excludedNodes)-1) + `)`
|
||||
for _, id := range excludedNodes {
|
||||
args = append(args, id.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
safeExcludeIPs := ""
|
||||
if len(excludedIPs) > 0 {
|
||||
safeExcludeIPs = ` AND last_net NOT IN (?` + strings.Repeat(", ?", len(excludedIPs)-1) + `)`
|
||||
for _, ip := range excludedIPs {
|
||||
args = append(args, ip)
|
||||
}
|
||||
}
|
||||
|
||||
args = append(args, count)
|
||||
|
||||
rows, err := cache.db.Query(cache.db.Rebind(`SELECT id, type, address, last_net,
|
||||
free_bandwidth, free_disk, total_audit_count, audit_success_count,
|
||||
total_uptime_count, uptime_success_count, disqualified, audit_reputation_alpha,
|
||||
audit_reputation_beta, uptime_reputation_alpha, uptime_reputation_beta
|
||||
FROM (SELECT *, Row_number() OVER(PARTITION BY last_net ORDER BY RANDOM()) rn
|
||||
FROM nodes
|
||||
`+safeQuery+safeExcludeNodes+safeExcludeIPs+`) n
|
||||
WHERE rn = 1
|
||||
ORDER BY RANDOM()
|
||||
LIMIT ?`), args...)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { err = errs.Combine(err, rows.Close()) }()
|
||||
var nodes []*pb.Node
|
||||
for rows.Next() {
|
||||
dbNode := &dbx.Node{}
|
||||
err = rows.Scan(&dbNode.Id, &dbNode.Type,
|
||||
&dbNode.Address, &dbNode.LastNet, &dbNode.FreeBandwidth, &dbNode.FreeDisk,
|
||||
&dbNode.TotalAuditCount, &dbNode.AuditSuccessCount,
|
||||
&dbNode.TotalUptimeCount, &dbNode.UptimeSuccessCount, &dbNode.Disqualified,
|
||||
&dbNode.AuditReputationAlpha, &dbNode.AuditReputationBeta,
|
||||
&dbNode.UptimeReputationAlpha, &dbNode.UptimeReputationBeta,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dossier, err := convertDBNode(ctx, dbNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodes = append(nodes, &dossier.Node)
|
||||
}
|
||||
|
||||
return nodes, rows.Err()
|
||||
}
|
||||
|
||||
func (cache *overlaycache) postgresQueryNodesDistinct(ctx context.Context, excludedNodes []storj.NodeID, excludedIPs []string, count int, safeQuery string, distinctIP bool, args ...interface{}) (_ []*pb.Node, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
if count == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@ -375,41 +297,18 @@ func (cache *overlaycache) KnownOffline(ctx context.Context, criteria *overlay.N
|
||||
|
||||
// get offline nodes
|
||||
var rows *sql.Rows
|
||||
switch t := cache.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
args := make([]interface{}, 0, len(nodeIds)+1)
|
||||
for i := range nodeIds {
|
||||
args = append(args, nodeIds[i].Bytes())
|
||||
}
|
||||
args = append(args, time.Now().Add(-criteria.OnlineWindow))
|
||||
|
||||
rows, err = cache.db.Query(cache.db.Rebind(`
|
||||
SELECT id FROM nodes
|
||||
WHERE id IN (?`+strings.Repeat(", ?", len(nodeIds)-1)+`)
|
||||
AND (
|
||||
last_contact_success < last_contact_failure AND last_contact_success < ?
|
||||
)
|
||||
`), args...)
|
||||
|
||||
case *pq.Driver:
|
||||
rows, err = cache.db.Query(`
|
||||
SELECT id FROM nodes
|
||||
WHERE id = any($1::bytea[])
|
||||
AND (
|
||||
last_contact_success < last_contact_failure AND last_contact_success < $2
|
||||
)
|
||||
`, postgresNodeIDList(nodeIds), time.Now().Add(-criteria.OnlineWindow),
|
||||
`), postgresNodeIDList(nodeIds), time.Now().Add(-criteria.OnlineWindow),
|
||||
)
|
||||
default:
|
||||
return nil, Error.New("Unsupported database %t", t)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = errs.Combine(err, rows.Close())
|
||||
}()
|
||||
defer func() { err = errs.Combine(err, rows.Close()) }()
|
||||
|
||||
for rows.Next() {
|
||||
var id storj.NodeID
|
||||
@ -432,39 +331,17 @@ func (cache *overlaycache) KnownUnreliableOrOffline(ctx context.Context, criteri
|
||||
|
||||
// get reliable and online nodes
|
||||
var rows *sql.Rows
|
||||
switch t := cache.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
args := make([]interface{}, 0, len(nodeIds)+3)
|
||||
for i := range nodeIds {
|
||||
args = append(args, nodeIds[i].Bytes())
|
||||
}
|
||||
args = append(args, time.Now().Add(-criteria.OnlineWindow))
|
||||
|
||||
rows, err = cache.db.Query(cache.db.Rebind(`
|
||||
SELECT id FROM nodes
|
||||
WHERE id IN (?`+strings.Repeat(", ?", len(nodeIds)-1)+`)
|
||||
AND disqualified IS NULL
|
||||
AND (last_contact_success > ? OR last_contact_success > last_contact_failure)
|
||||
`), args...)
|
||||
|
||||
case *pq.Driver:
|
||||
rows, err = cache.db.Query(`
|
||||
SELECT id FROM nodes
|
||||
WHERE id = any($1::bytea[])
|
||||
AND disqualified IS NULL
|
||||
AND (last_contact_success > $2 OR last_contact_success > last_contact_failure)
|
||||
`, postgresNodeIDList(nodeIds), time.Now().Add(-criteria.OnlineWindow),
|
||||
`), postgresNodeIDList(nodeIds), time.Now().Add(-criteria.OnlineWindow),
|
||||
)
|
||||
default:
|
||||
return nil, Error.New("Unsupported database %t", t)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = errs.Combine(err, rows.Close())
|
||||
}()
|
||||
defer func() { err = errs.Combine(err, rows.Close()) }()
|
||||
|
||||
goodNodes := make(map[storj.NodeID]struct{}, len(nodeIds))
|
||||
for rows.Next() {
|
||||
@ -932,19 +809,6 @@ func (cache *overlaycache) UpdatePieceCounts(ctx context.Context, pieceCounts ma
|
||||
return counts[i].ID.Less(counts[k].ID)
|
||||
})
|
||||
|
||||
switch t := cache.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
err = cache.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
|
||||
query := tx.Rebind(`UPDATE nodes SET piece_count = ? WHERE id = ?`)
|
||||
for _, count := range counts {
|
||||
_, err := tx.Tx.ExecContext(ctx, query, count.Count, count.ID)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
case *pq.Driver:
|
||||
var nodeIDs []storj.NodeID
|
||||
var countNumbers []int64
|
||||
for _, count := range counts {
|
||||
@ -960,9 +824,6 @@ func (cache *overlaycache) UpdatePieceCounts(ctx context.Context, pieceCounts ma
|
||||
) as update
|
||||
WHERE nodes.id = update.id
|
||||
`, postgresNodeIDList(nodeIDs), pq.Array(countNumbers))
|
||||
default:
|
||||
return Error.New("Unsupported database %t", t)
|
||||
}
|
||||
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
@ -1282,16 +1143,9 @@ func buildUpdateStatement(db *dbx.DB, update updateNodeStats) string {
|
||||
return ""
|
||||
}
|
||||
hexNodeID := hex.EncodeToString(update.NodeID.Bytes())
|
||||
switch db.DB.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
sql += fmt.Sprintf(" WHERE nodes.id = X'%v';\n", hexNodeID)
|
||||
sql += fmt.Sprintf("DELETE FROM pending_audits WHERE pending_audits.node_id = X'%v';\n", hexNodeID)
|
||||
case *pq.Driver:
|
||||
|
||||
sql += fmt.Sprintf(" WHERE nodes.id = decode('%v', 'hex');\n", hexNodeID)
|
||||
sql += fmt.Sprintf("DELETE FROM pending_audits WHERE pending_audits.node_id = decode('%v', 'hex');\n", hexNodeID)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
@ -1449,46 +1303,19 @@ func (cache *overlaycache) UpdateCheckIn(ctx context.Context, node overlay.NodeC
|
||||
return Error.New("error UpdateCheckIn: missing the storage node address")
|
||||
}
|
||||
|
||||
switch t := cache.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
value := pb.Node{
|
||||
Id: node.NodeID,
|
||||
Address: node.Address,
|
||||
LastIp: node.LastIP,
|
||||
}
|
||||
err := cache.UpdateAddress(ctx, &value, config)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
||||
_, err = cache.UpdateUptime(ctx, node.NodeID, node.IsUp, config.UptimeReputationLambda, config.UptimeReputationWeight, config.UptimeReputationDQ)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
||||
pbInfo := pb.InfoResponse{
|
||||
Operator: node.Operator,
|
||||
Capacity: node.Capacity,
|
||||
Type: pb.NodeType_STORAGE,
|
||||
Version: node.Version,
|
||||
}
|
||||
_, err = cache.UpdateNodeInfo(ctx, node.NodeID, &pbInfo)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
case *pq.Driver:
|
||||
// v is a single feedback value that allows us to update both alpha and beta
|
||||
var v float64 = -1
|
||||
if node.IsUp {
|
||||
v = 1
|
||||
}
|
||||
|
||||
uptimeReputationAlpha := config.UptimeReputationLambda*config.UptimeReputationAlpha0 + config.UptimeReputationWeight*(1+v)/2
|
||||
uptimeReputationBeta := config.UptimeReputationLambda*config.UptimeReputationBeta0 + config.UptimeReputationWeight*(1-v)/2
|
||||
semVer, err := version.NewSemVer(node.Version.GetVersion())
|
||||
if err != nil {
|
||||
return Error.New("unable to convert version to semVer")
|
||||
}
|
||||
start := time.Now()
|
||||
|
||||
query := `
|
||||
INSERT INTO nodes
|
||||
(
|
||||
@ -1560,10 +1387,6 @@ func (cache *overlaycache) UpdateCheckIn(ctx context.Context, node overlay.NodeC
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
mon.FloatVal("UpdateCheckIn query execution time (seconds)").Observe(time.Since(start).Seconds())
|
||||
default:
|
||||
return Error.New("Unsupported database %t", t)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -6,13 +6,8 @@ package satellitedb
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
|
||||
"storj.io/storj/internal/dbutil/pgutil"
|
||||
"storj.io/storj/internal/dbutil/sqliteutil"
|
||||
"storj.io/storj/pkg/pb"
|
||||
dbx "storj.io/storj/satellite/satellitedb/dbx"
|
||||
"storj.io/storj/storage"
|
||||
@ -26,7 +21,7 @@ func (r *repairQueue) Insert(ctx context.Context, seg *pb.InjuredSegment) (err e
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
_, err = r.db.ExecContext(ctx, r.db.Rebind(`INSERT INTO injuredsegments ( path, data ) VALUES ( ?, ? )`), seg.Path, seg)
|
||||
if err != nil {
|
||||
if pgutil.IsConstraintError(err) || sqliteutil.IsConstraintError(err) {
|
||||
if pgutil.IsConstraintError(err) {
|
||||
return nil // quietly fail on reinsert
|
||||
}
|
||||
return err
|
||||
@ -34,7 +29,7 @@ func (r *repairQueue) Insert(ctx context.Context, seg *pb.InjuredSegment) (err e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repairQueue) postgresSelect(ctx context.Context) (seg *pb.InjuredSegment, err error) {
|
||||
func (r *repairQueue) Select(ctx context.Context) (seg *pb.InjuredSegment, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
err = r.db.QueryRowContext(ctx, `
|
||||
UPDATE injuredsegments SET attempted = timezone('utc', now()) WHERE path = (
|
||||
@ -42,55 +37,13 @@ func (r *repairQueue) postgresSelect(ctx context.Context) (seg *pb.InjuredSegmen
|
||||
WHERE attempted IS NULL OR attempted < timezone('utc', now()) - interval '1 hour'
|
||||
ORDER BY attempted NULLS FIRST FOR UPDATE SKIP LOCKED LIMIT 1
|
||||
) RETURNING data`).Scan(&seg)
|
||||
if err == sql.ErrNoRows {
|
||||
err = storage.ErrEmptyQueue.New("")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *repairQueue) sqliteSelect(ctx context.Context) (seg *pb.InjuredSegment, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
err = r.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
|
||||
var path []byte
|
||||
err = tx.Tx.QueryRowContext(ctx, r.db.Rebind(`
|
||||
SELECT path, data FROM injuredsegments
|
||||
WHERE attempted IS NULL
|
||||
OR attempted < datetime('now','-1 hours')
|
||||
ORDER BY attempted LIMIT 1`)).Scan(&path, &seg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := tx.Tx.ExecContext(ctx, r.db.Rebind(`UPDATE injuredsegments SET attempted = datetime('now') WHERE path = ?`), path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
count, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count != 1 {
|
||||
return fmt.Errorf("Expected 1, got %d segments deleted", count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err == sql.ErrNoRows {
|
||||
err = storage.ErrEmptyQueue.New("")
|
||||
}
|
||||
return seg, err
|
||||
}
|
||||
|
||||
func (r *repairQueue) Select(ctx context.Context) (seg *pb.InjuredSegment, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
switch t := r.db.DB.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
return r.sqliteSelect(ctx)
|
||||
case *pq.Driver:
|
||||
return r.postgresSelect(ctx)
|
||||
default:
|
||||
return seg, fmt.Errorf("Unsupported database %t", t)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *repairQueue) Delete(ctx context.Context, seg *pb.InjuredSegment) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
_, err = r.db.ExecContext(ctx, r.db.Rebind(`DELETE FROM injuredsegments WHERE path = ?`), seg.Path)
|
||||
|
@ -8,8 +8,6 @@ import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"github.com/skyrings/skyring-common/tools/uuid"
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
@ -91,28 +89,23 @@ func (c *usercredits) Create(ctx context.Context, userCredit console.CreateCredi
|
||||
result sql.Result
|
||||
statement string
|
||||
)
|
||||
switch t := c.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
statement = `
|
||||
INSERT INTO user_credits (user_id, offer_id, credits_earned_in_cents, credits_used_in_cents, expires_at, referred_by, type, created_at)
|
||||
SELECT * FROM (VALUES (?, ?, ?, 0, ?, ?, ?, time('now'))) AS v
|
||||
WHERE COALESCE((SELECT COUNT(offer_id) FROM user_credits WHERE offer_id = ? AND referred_by IS NOT NULL ) < NULLIF(?, 0) , ?);
|
||||
`
|
||||
result, err = dbExec.ExecContext(ctx, c.db.Rebind(statement), userCredit.UserID[:], userCredit.OfferID, userCredit.CreditsEarned.Cents(), userCredit.ExpiresAt, referrerID, userCredit.Type, userCredit.OfferID, userCredit.OfferInfo.RedeemableCap, shouldCreate)
|
||||
case *pq.Driver:
|
||||
statement = `
|
||||
INSERT INTO user_credits (user_id, offer_id, credits_earned_in_cents, credits_used_in_cents, expires_at, referred_by, type, created_at)
|
||||
SELECT * FROM (VALUES (?::bytea, ?::int, ?::int, 0, ?::timestamp, NULLIF(?::bytea, ?::bytea), ?::text, now())) AS v
|
||||
WHERE COALESCE((SELECT COUNT(offer_id) FROM user_credits WHERE offer_id = ? AND referred_by IS NOT NULL ) < NULLIF(?, 0), ?);
|
||||
`
|
||||
result, err = dbExec.ExecContext(ctx, c.db.Rebind(statement), userCredit.UserID[:], userCredit.OfferID, userCredit.CreditsEarned.Cents(), userCredit.ExpiresAt, referrerID, new([]byte), userCredit.Type, userCredit.OfferID, userCredit.OfferInfo.RedeemableCap, shouldCreate)
|
||||
default:
|
||||
return errs.New("unsupported database: %t", t)
|
||||
}
|
||||
result, err = dbExec.ExecContext(ctx, c.db.Rebind(statement),
|
||||
userCredit.UserID[:],
|
||||
userCredit.OfferID,
|
||||
userCredit.CreditsEarned.Cents(),
|
||||
userCredit.ExpiresAt, referrerID, new([]byte),
|
||||
userCredit.Type,
|
||||
userCredit.OfferID,
|
||||
userCredit.OfferInfo.RedeemableCap, shouldCreate)
|
||||
|
||||
if err != nil {
|
||||
// check to see if there's a constraint error
|
||||
if pgutil.IsConstraintError(err) || err == sqlite3.ErrConstraint {
|
||||
if pgutil.IsConstraintError(err) {
|
||||
_, err := dbExec.ExecContext(ctx, c.db.Rebind(`UPDATE offers SET status = ? AND expires_at = ? WHERE id = ?`), rewards.Done, time.Now().UTC(), userCredit.OfferID)
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
@ -138,24 +131,11 @@ func (c *usercredits) Create(ctx context.Context, userCredit console.CreateCredi
|
||||
|
||||
// UpdateEarnedCredits updates user credits after user activated their account
|
||||
func (c *usercredits) UpdateEarnedCredits(ctx context.Context, userID uuid.UUID) error {
|
||||
var statement string
|
||||
|
||||
switch t := c.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
statement = `
|
||||
UPDATE user_credits
|
||||
SET credits_earned_in_cents =
|
||||
(SELECT invitee_credit_in_cents FROM offers WHERE id = offer_id)
|
||||
WHERE user_id = ? AND credits_earned_in_cents = 0`
|
||||
case *pq.Driver:
|
||||
statement = `
|
||||
statement := `
|
||||
UPDATE user_credits SET credits_earned_in_cents = offers.invitee_credit_in_cents
|
||||
FROM offers
|
||||
WHERE user_id = ? AND credits_earned_in_cents = 0 AND offer_id = offers.id
|
||||
`
|
||||
default:
|
||||
return errs.New("Unsupported database %t", t)
|
||||
}
|
||||
|
||||
result, err := c.db.DB.ExecContext(ctx, c.db.Rebind(statement), userID[:])
|
||||
if err != nil {
|
||||
@ -215,18 +195,9 @@ func (c *usercredits) UpdateAvailableCredits(ctx context.Context, creditsToCharg
|
||||
|
||||
values = append(values, rowIds...)
|
||||
|
||||
var statement string
|
||||
switch t := c.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
statement = generateQuery(len(availableCredits), false)
|
||||
case *pq.Driver:
|
||||
statement = generateQuery(len(availableCredits), true)
|
||||
default:
|
||||
return creditsToCharge, errs.New("Unsupported database %t", t)
|
||||
}
|
||||
statement := generateQuery(len(availableCredits), true)
|
||||
|
||||
_, err = tx.Tx.ExecContext(ctx, c.db.Rebind(`UPDATE user_credits SET
|
||||
credits_used_in_cents = CASE `+statement), values...)
|
||||
_, err = tx.Tx.ExecContext(ctx, c.db.Rebind(`UPDATE user_credits SET credits_used_in_cents = CASE `+statement), values...)
|
||||
if err != nil {
|
||||
return creditsToCharge, errs.Wrap(errs.Combine(err, tx.Rollback()))
|
||||
}
|
||||
|
@ -1,390 +0,0 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/tools/go/ast/astutil"
|
||||
"golang.org/x/tools/go/packages"
|
||||
"golang.org/x/tools/imports"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var outputPath string
|
||||
var packageName string
|
||||
var typeFullyQualifedName string
|
||||
|
||||
flag.StringVar(&outputPath, "o", "", "output file name")
|
||||
flag.StringVar(&packageName, "p", "", "output package name")
|
||||
flag.StringVar(&typeFullyQualifedName, "i", "", "interface to generate code for")
|
||||
flag.Parse()
|
||||
|
||||
if outputPath == "" || packageName == "" || typeFullyQualifedName == "" {
|
||||
fmt.Println("missing argument")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var code Code
|
||||
|
||||
code.Imports = map[string]bool{}
|
||||
code.Ignore = map[string]bool{
|
||||
"error": true,
|
||||
}
|
||||
code.IgnoreMethods = map[string]bool{
|
||||
"BeginTx": true,
|
||||
}
|
||||
code.OutputPackage = packageName
|
||||
code.Config = &packages.Config{
|
||||
Mode: packages.LoadAllSyntax,
|
||||
}
|
||||
code.Wrapped = map[string]bool{}
|
||||
code.AdditionalNesting = map[string]int{"Console": 1}
|
||||
|
||||
// e.g. storj.io/storj/satellite.DB
|
||||
p := strings.LastIndexByte(typeFullyQualifedName, '.')
|
||||
code.Package = typeFullyQualifedName[:p] // storj.io/storj/satellite
|
||||
code.Type = typeFullyQualifedName[p+1:] // DB
|
||||
code.QualifiedType = path.Base(code.Package) + "." + code.Type
|
||||
|
||||
var err error
|
||||
code.Roots, err = packages.Load(code.Config, code.Package)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
code.PrintLocked()
|
||||
code.PrintPreamble()
|
||||
|
||||
unformatted := code.Bytes()
|
||||
|
||||
imports.LocalPrefix = "storj.io"
|
||||
formatted, err := imports.Process(outputPath, unformatted, nil)
|
||||
if err != nil {
|
||||
fmt.Println(string(unformatted))
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if outputPath == "" {
|
||||
fmt.Println(string(formatted))
|
||||
return
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(outputPath, formatted, 0644)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Methods is the common interface for types having methods.
|
||||
type Methods interface {
|
||||
Method(i int) *types.Func
|
||||
NumMethods() int
|
||||
}
|
||||
|
||||
// Code is the information for generating the code.
|
||||
type Code struct {
|
||||
Config *packages.Config
|
||||
Package string
|
||||
Type string
|
||||
QualifiedType string
|
||||
Roots []*packages.Package
|
||||
|
||||
OutputPackage string
|
||||
|
||||
Imports map[string]bool
|
||||
Ignore map[string]bool
|
||||
IgnoreMethods map[string]bool
|
||||
Wrapped map[string]bool
|
||||
AdditionalNesting map[string]int
|
||||
|
||||
Preamble bytes.Buffer
|
||||
Source bytes.Buffer
|
||||
}
|
||||
|
||||
// Bytes returns all code merged together
|
||||
func (code *Code) Bytes() []byte {
|
||||
var all bytes.Buffer
|
||||
all.Write(code.Preamble.Bytes())
|
||||
all.Write(code.Source.Bytes())
|
||||
return all.Bytes()
|
||||
}
|
||||
|
||||
// PrintPreamble creates package header and imports.
|
||||
func (code *Code) PrintPreamble() {
|
||||
w := &code.Preamble
|
||||
fmt.Fprintf(w, "// Code generated by lockedgen using 'go generate'. DO NOT EDIT.\n\n")
|
||||
fmt.Fprintf(w, "// Copyright (C) 2019 Storj Labs, Inc.\n")
|
||||
fmt.Fprintf(w, "// See LICENSE for copying information.\n\n")
|
||||
fmt.Fprintf(w, "package %v\n\n", code.OutputPackage)
|
||||
fmt.Fprintf(w, "import (\n")
|
||||
|
||||
var imports []string
|
||||
for imp := range code.Imports {
|
||||
imports = append(imports, imp)
|
||||
}
|
||||
sort.Strings(imports)
|
||||
for _, imp := range imports {
|
||||
fmt.Fprintf(w, " %q\n", imp)
|
||||
}
|
||||
fmt.Fprintf(w, ")\n\n")
|
||||
}
|
||||
|
||||
// PrintLocked writes locked wrapper and methods.
|
||||
func (code *Code) PrintLocked() {
|
||||
code.Imports["sync"] = true
|
||||
code.Imports["storj.io/statellite"] = true
|
||||
|
||||
code.Printf("// locked implements a locking wrapper around satellite.DB.\n")
|
||||
code.Printf("type locked struct {\n")
|
||||
code.Printf(" sync.Locker\n")
|
||||
code.Printf(" db %v\n", code.QualifiedType)
|
||||
code.Printf("}\n\n")
|
||||
|
||||
code.Printf("// newLocked returns database wrapped with locker.\n")
|
||||
code.Printf("func newLocked(db %v) %v {\n", code.QualifiedType, code.QualifiedType)
|
||||
code.Printf(" return &locked{&sync.Mutex{}, db}\n")
|
||||
code.Printf("}\n\n")
|
||||
|
||||
// find the satellite.DB type info
|
||||
dbObject := code.Roots[0].Types.Scope().Lookup(code.Type)
|
||||
methods := dbObject.Type().Underlying().(Methods)
|
||||
|
||||
for i := 0; i < methods.NumMethods(); i++ {
|
||||
code.PrintLockedFunc("locked", methods.Method(i), code.AdditionalNesting[methods.Method(i).Name()]+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Printf writes formatted text to source.
|
||||
func (code *Code) Printf(format string, a ...interface{}) {
|
||||
fmt.Fprintf(&code.Source, format, a...)
|
||||
}
|
||||
|
||||
// PrintSignature prints method signature.
|
||||
func (code *Code) PrintSignature(sig *types.Signature) {
|
||||
code.PrintSignatureTuple(sig.Params(), true)
|
||||
if sig.Results().Len() > 0 {
|
||||
code.Printf(" ")
|
||||
code.PrintSignatureTuple(sig.Results(), false)
|
||||
}
|
||||
}
|
||||
|
||||
// PrintSignatureTuple prints method tuple, params or results.
|
||||
func (code *Code) PrintSignatureTuple(tuple *types.Tuple, needsNames bool) {
|
||||
code.Printf("(")
|
||||
defer code.Printf(")")
|
||||
|
||||
for i := 0; i < tuple.Len(); i++ {
|
||||
if i > 0 {
|
||||
code.Printf(", ")
|
||||
}
|
||||
|
||||
param := tuple.At(i)
|
||||
if code.PrintName(tuple.At(i), i, needsNames) {
|
||||
code.Printf(" ")
|
||||
}
|
||||
code.PrintType(param.Type())
|
||||
}
|
||||
}
|
||||
|
||||
// PrintCall prints a call using the specified signature.
|
||||
func (code *Code) PrintCall(sig *types.Signature) {
|
||||
code.Printf("(")
|
||||
defer code.Printf(")")
|
||||
|
||||
params := sig.Params()
|
||||
for i := 0; i < params.Len(); i++ {
|
||||
if i != 0 {
|
||||
code.Printf(", ")
|
||||
}
|
||||
code.PrintName(params.At(i), i, true)
|
||||
}
|
||||
}
|
||||
|
||||
// PrintName prints an appropriate name from signature tuple.
|
||||
func (code *Code) PrintName(v *types.Var, index int, needsNames bool) bool {
|
||||
name := v.Name()
|
||||
if needsNames && name == "" {
|
||||
if v.Type().String() == "context.Context" {
|
||||
code.Printf("ctx")
|
||||
return true
|
||||
}
|
||||
code.Printf("a%d", index)
|
||||
return true
|
||||
}
|
||||
code.Printf("%s", name)
|
||||
return name != ""
|
||||
}
|
||||
|
||||
// PrintType prints short form of type t.
|
||||
func (code *Code) PrintType(t types.Type) {
|
||||
types.WriteType(&code.Source, t, (*types.Package).Name)
|
||||
}
|
||||
|
||||
func typeName(typ types.Type) string {
|
||||
var body bytes.Buffer
|
||||
types.WriteType(&body, typ, (*types.Package).Name)
|
||||
return body.String()
|
||||
}
|
||||
|
||||
// IncludeImports imports all types referenced in the signature.
|
||||
func (code *Code) IncludeImports(sig *types.Signature) {
|
||||
var tmp bytes.Buffer
|
||||
types.WriteSignature(&tmp, sig, func(p *types.Package) string {
|
||||
code.Imports[p.Path()] = true
|
||||
return p.Name()
|
||||
})
|
||||
}
|
||||
|
||||
// NeedsWrapper checks whether method result needs a wrapper type.
|
||||
func (code *Code) NeedsWrapper(method *types.Func) bool {
|
||||
if code.IgnoreMethods[method.Name()] {
|
||||
return false
|
||||
}
|
||||
|
||||
sig := method.Type().Underlying().(*types.Signature)
|
||||
return sig.Results().Len() == 1 && !code.Ignore[sig.Results().At(0).Type().String()]
|
||||
}
|
||||
|
||||
// WrapperTypeName returns an appropriate name for the wrapper type.
|
||||
func (code *Code) WrapperTypeName(method *types.Func) string {
|
||||
return "locked" + method.Name()
|
||||
}
|
||||
|
||||
// PrintLockedFunc prints a method with locking and defers the actual logic to method.
|
||||
func (code *Code) PrintLockedFunc(receiverType string, method *types.Func, nestingDepth int) {
|
||||
if code.IgnoreMethods[method.Name()] {
|
||||
return
|
||||
}
|
||||
|
||||
sig := method.Type().Underlying().(*types.Signature)
|
||||
code.IncludeImports(sig)
|
||||
|
||||
doc := strings.TrimSpace(code.MethodDoc(method))
|
||||
if doc != "" {
|
||||
for _, line := range strings.Split(doc, "\n") {
|
||||
code.Printf("// %s\n", line)
|
||||
}
|
||||
}
|
||||
code.Printf("func (m *%s) %s", receiverType, method.Name())
|
||||
code.PrintSignature(sig)
|
||||
code.Printf(" {\n")
|
||||
|
||||
code.Printf(" m.Lock(); defer m.Unlock()\n")
|
||||
if !code.NeedsWrapper(method) {
|
||||
code.Printf(" return m.db.%s", method.Name())
|
||||
code.PrintCall(sig)
|
||||
code.Printf("\n")
|
||||
code.Printf("}\n\n")
|
||||
return
|
||||
}
|
||||
|
||||
code.Printf(" return &%s{m.Locker, ", code.WrapperTypeName(method))
|
||||
code.Printf("m.db.%s", method.Name())
|
||||
code.PrintCall(sig)
|
||||
code.Printf("}\n")
|
||||
code.Printf("}\n\n")
|
||||
|
||||
if nestingDepth > 0 {
|
||||
code.PrintWrapper(method, nestingDepth-1)
|
||||
}
|
||||
}
|
||||
|
||||
// PrintWrapper prints wrapper for the result type of method.
|
||||
func (code *Code) PrintWrapper(method *types.Func, nestingDepth int) {
|
||||
sig := method.Type().Underlying().(*types.Signature)
|
||||
results := sig.Results()
|
||||
result := results.At(0).Type()
|
||||
|
||||
receiverType := code.WrapperTypeName(method)
|
||||
|
||||
if code.Wrapped[receiverType] {
|
||||
return
|
||||
}
|
||||
code.Wrapped[receiverType] = true
|
||||
|
||||
code.Printf("// %s implements locking wrapper for %s\n", receiverType, typeName(result))
|
||||
code.Printf("type %s struct {\n", receiverType)
|
||||
code.Printf(" sync.Locker\n")
|
||||
code.Printf(" db %s\n", typeName(result))
|
||||
code.Printf("}\n\n")
|
||||
|
||||
methods := result.Underlying().(Methods)
|
||||
for i := 0; i < methods.NumMethods(); i++ {
|
||||
code.PrintLockedFunc(receiverType, methods.Method(i), nestingDepth)
|
||||
}
|
||||
}
|
||||
|
||||
// MethodDoc finds documentation for the specified method.
|
||||
func (code *Code) MethodDoc(method *types.Func) string {
|
||||
file := code.FindASTFile(method.Pos())
|
||||
if file == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
path, exact := astutil.PathEnclosingInterval(file, method.Pos(), method.Pos())
|
||||
if !exact {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, p := range path {
|
||||
switch decl := p.(type) {
|
||||
case *ast.Field:
|
||||
return decl.Doc.Text()
|
||||
case *ast.GenDecl:
|
||||
return decl.Doc.Text()
|
||||
case *ast.FuncDecl:
|
||||
return decl.Doc.Text()
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// FindASTFile finds the *ast.File at the specified position.
|
||||
func (code *Code) FindASTFile(pos token.Pos) *ast.File {
|
||||
seen := map[*packages.Package]bool{}
|
||||
|
||||
// find searches pos recursively from p and its dependencies.
|
||||
var find func(p *packages.Package) *ast.File
|
||||
find = func(p *packages.Package) *ast.File {
|
||||
if seen[p] {
|
||||
return nil
|
||||
}
|
||||
seen[p] = true
|
||||
|
||||
for _, file := range p.Syntax {
|
||||
if file.Pos() <= pos && pos <= file.End() {
|
||||
return file
|
||||
}
|
||||
}
|
||||
|
||||
for _, dep := range p.Imports {
|
||||
if file := find(dep); file != nil {
|
||||
return file
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, root := range code.Roots {
|
||||
if file := find(root); file != nil {
|
||||
return file
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user