Repair queue isolation level fix (#1466)
Implemented custom SQLite and Postgres Repairqueue Dequeue handlers
This commit is contained in:
parent
7dbdf89f1a
commit
665fd33e3c
@ -53,5 +53,5 @@ type DB interface {
|
||||
// QueryPaymentInfo queries StatDB, Accounting Rollup on nodeID
|
||||
QueryPaymentInfo(ctx context.Context, start time.Time, end time.Time) ([]*CSVRow, error)
|
||||
// DeleteRawBefore deletes all raw tallies prior to some time
|
||||
DeleteRawBefore(latestRollup time.Time) error
|
||||
DeleteRawBefore(ctx context.Context, latestRollup time.Time) error
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ func (r *Rollup) RollupRaws(ctx context.Context) error {
|
||||
var rolledUpRawsHaveBeenSaved bool
|
||||
//todo: write files to disk or whatever we decide to do here
|
||||
if rolledUpRawsHaveBeenSaved {
|
||||
return Error.Wrap(r.db.DeleteRawBefore(latestTally))
|
||||
return Error.Wrap(r.db.DeleteRawBefore(ctx, latestTally))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -88,14 +88,14 @@ func (t *Tally) Tally(ctx context.Context) error {
|
||||
} else {
|
||||
//remove expired records
|
||||
now := time.Now()
|
||||
_, err = t.bwAgreementDB.GetExpired(tallyEnd, now)
|
||||
_, err = t.bwAgreementDB.GetExpired(ctx, tallyEnd, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var expiredOrdersHaveBeenSaved bool
|
||||
//todo: write files to disk or whatever we decide to do here
|
||||
if expiredOrdersHaveBeenSaved {
|
||||
err = t.bwAgreementDB.DeleteExpired(tallyEnd, now)
|
||||
err = t.bwAgreementDB.DeleteExpired(ctx, tallyEnd, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -29,15 +29,15 @@ func TestBandwidthDBAgreement(t *testing.T) {
|
||||
snID, err := testidentity.NewTestIdentity(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_PUT, "1", upID, snID))
|
||||
require.Error(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "1", upID, snID))
|
||||
require.NoError(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "2", upID, snID))
|
||||
require.NoError(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_PUT, "1", upID, snID))
|
||||
require.Error(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "1", upID, snID))
|
||||
require.NoError(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "2", upID, snID))
|
||||
testGetTotals(ctx, t, db.BandwidthAgreement(), snID)
|
||||
testGetUplinkStats(ctx, t, db.BandwidthAgreement(), upID)
|
||||
})
|
||||
}
|
||||
|
||||
func testSaveOrder(t *testing.T, b bwagreement.DB, action pb.BandwidthAction,
|
||||
func testSaveOrder(ctx context.Context, t *testing.T, b bwagreement.DB, action pb.BandwidthAction,
|
||||
serialNum string, upID, snID *identity.FullIdentity) error {
|
||||
rba := &pb.Order{
|
||||
PayerAllocation: pb.OrderLimit{
|
||||
@ -48,7 +48,7 @@ func testSaveOrder(t *testing.T, b bwagreement.DB, action pb.BandwidthAction,
|
||||
Total: 1000,
|
||||
StorageNodeId: snID.ID,
|
||||
}
|
||||
return b.SaveOrder(rba)
|
||||
return b.SaveOrder(ctx, rba)
|
||||
}
|
||||
|
||||
func testGetUplinkStats(ctx context.Context, t *testing.T, b bwagreement.DB, upID *identity.FullIdentity) {
|
||||
|
@ -56,15 +56,15 @@ type SavedOrder struct {
|
||||
// DB stores orders for accounting purposes
|
||||
type DB interface {
|
||||
// SaveOrder saves an order for accounting
|
||||
SaveOrder(*pb.Order) error
|
||||
SaveOrder(context.Context, *pb.Order) error
|
||||
// GetTotals returns the sum of each bandwidth type after (excluding) a given date range
|
||||
GetTotals(context.Context, time.Time, time.Time) (map[storj.NodeID][]int64, error)
|
||||
//GetUplinkStats returns stats about an uplink
|
||||
GetUplinkStats(context.Context, time.Time, time.Time) ([]UplinkStat, error)
|
||||
//GetExpired gets orders that are expired and were created before some time
|
||||
GetExpired(time.Time, time.Time) ([]SavedOrder, error)
|
||||
GetExpired(context.Context, time.Time, time.Time) ([]SavedOrder, error)
|
||||
//DeleteExpired deletes orders that are expired and were created before some time
|
||||
DeleteExpired(time.Time, time.Time) error
|
||||
DeleteExpired(context.Context, time.Time, time.Time) error
|
||||
}
|
||||
|
||||
// Server is an implementation of the pb.BandwidthServer interface
|
||||
@ -112,7 +112,7 @@ func (s *Server) BandwidthAgreements(ctx context.Context, rba *pb.Order) (reply
|
||||
}
|
||||
|
||||
//save and return results
|
||||
if err = s.bwdb.SaveOrder(rba); err != nil {
|
||||
if err = s.bwdb.SaveOrder(ctx, rba); err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") ||
|
||||
strings.Contains(err.Error(), "violates unique constraint") {
|
||||
return reply, pb.ErrPayer.Wrap(auth.ErrSerial.Wrap(err))
|
||||
|
@ -90,8 +90,6 @@ func TestSequential(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParallel(t *testing.T) {
|
||||
t.Skip("logic is broken on database side")
|
||||
|
||||
satellitedbtest.Run(t, func(t *testing.T, db satellite.DB) {
|
||||
ctx := testcontext.New(t)
|
||||
defer ctx.Cleanup()
|
||||
@ -101,7 +99,6 @@ func TestParallel(t *testing.T) {
|
||||
errs := make(chan error, N*2)
|
||||
entries := make(chan *pb.InjuredSegment, N*2)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(N)
|
||||
// Add to queue concurrently
|
||||
for i := 0; i < N; i++ {
|
||||
@ -115,7 +112,6 @@ func TestParallel(t *testing.T) {
|
||||
errs <- err
|
||||
}
|
||||
}(i)
|
||||
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
|
50
pkg/pb/scannerValuer.go
Normal file
50
pkg/pb/scannerValuer.go
Normal file
@ -0,0 +1,50 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package pb
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
proto "github.com/gogo/protobuf/proto"
|
||||
"github.com/zeebo/errs"
|
||||
)
|
||||
|
||||
// scanError classifies failures that occur while decoding a database value
// into a protobuf message (see scan).
|
||||
var scanError = errs.Class("Protobuf Scanner")
|
||||
// valueError classifies failures that occur while encoding a protobuf message
// into a database value (see value).
var valueError = errs.Class("Protobuf Valuer")
|
||||
|
||||
//scan automatically converts database []byte to proto.Messages
|
||||
func scan(msg proto.Message, value interface{}) error {
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return scanError.New("%t was %t, expected []bytes", msg, value)
|
||||
}
|
||||
return scanError.Wrap(proto.Unmarshal(bytes, msg))
|
||||
}
|
||||
|
||||
//value automatically converts proto.Messages to database []byte
|
||||
func value(msg proto.Message) (driver.Value, error) {
|
||||
value, err := proto.Marshal(msg)
|
||||
return value, valueError.Wrap(err)
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
// It decodes the database value into n by delegating to scan, which expects
// the value to be a []byte holding a marshaled message.
|
||||
func (n *InjuredSegment) Scan(value interface{}) error {
|
||||
return scan(n, value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
// It returns the marshaled protobuf bytes of n, produced by the value helper.
|
||||
func (n InjuredSegment) Value() (driver.Value, error) {
|
||||
return value(&n)
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
// It decodes the database value into n by delegating to scan, which expects
// the value to be a []byte holding a marshaled message.
|
||||
func (n *Order) Scan(value interface{}) error {
|
||||
return scan(n, value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
// It returns the marshaled protobuf bytes of n, produced by the value helper.
|
||||
func (n Order) Value() (driver.Value, error) {
|
||||
return value(&n)
|
||||
}
|
@ -177,7 +177,7 @@ func (db *accountingDB) QueryPaymentInfo(ctx context.Context, start time.Time, e
|
||||
LEFT JOIN nodes n ON n.id = r.node_id
|
||||
LEFT JOIN overlay_cache_nodes o ON n.id = o.node_id
|
||||
ORDER BY n.id`
|
||||
rows, err := db.db.DB.Query(db.db.Rebind(sqlStmt), start.UTC(), end.UTC())
|
||||
rows, err := db.db.DB.QueryContext(ctx, db.db.Rebind(sqlStmt), start.UTC(), end.UTC())
|
||||
if err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
@ -206,8 +206,8 @@ func (db *accountingDB) QueryPaymentInfo(ctx context.Context, start time.Time, e
|
||||
}
|
||||
|
||||
// DeleteRawBefore deletes all raw tallies prior to some time
|
||||
func (db *accountingDB) DeleteRawBefore(latestRollup time.Time) error {
|
||||
func (db *accountingDB) DeleteRawBefore(ctx context.Context, latestRollup time.Time) error {
|
||||
var deleteRawSQL = `DELETE FROM accounting_raws WHERE interval_end_time < ?`
|
||||
_, err := db.db.DB.Exec(db.db.Rebind(deleteRawSQL), latestRollup)
|
||||
_, err := db.db.DB.ExecContext(ctx, db.db.Rebind(deleteRawSQL), latestRollup)
|
||||
return err
|
||||
}
|
||||
|
@ -20,9 +20,9 @@ type bandwidthagreement struct {
|
||||
db *dbx.DB
|
||||
}
|
||||
|
||||
func (b *bandwidthagreement) SaveOrder(rba *pb.Order) (err error) {
|
||||
func (b *bandwidthagreement) SaveOrder(ctx context.Context, rba *pb.Order) (err error) {
|
||||
var saveOrderSQL = `INSERT INTO bwagreements ( serialnum, storage_node_id, uplink_id, action, total, created_at, expires_at ) VALUES ( ?, ?, ?, ?, ?, ?, ? )`
|
||||
_, err = b.db.DB.Exec(b.db.Rebind(saveOrderSQL),
|
||||
_, err = b.db.DB.ExecContext(ctx, b.db.Rebind(saveOrderSQL),
|
||||
rba.PayerAllocation.SerialNumber+rba.StorageNodeId.String(),
|
||||
rba.StorageNodeId,
|
||||
rba.PayerAllocation.UplinkId,
|
||||
@ -43,7 +43,7 @@ func (b *bandwidthagreement) GetUplinkStats(ctx context.Context, from, to time.T
|
||||
FROM bwagreements WHERE created_at > ?
|
||||
AND created_at <= ? GROUP BY uplink_id ORDER BY uplink_id`,
|
||||
pb.BandwidthAction_PUT, pb.BandwidthAction_GET)
|
||||
rows, err := b.db.DB.Query(b.db.Rebind(uplinkSQL), from.UTC(), to.UTC())
|
||||
rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(uplinkSQL), from.UTC(), to.UTC())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -71,7 +71,7 @@ func (b *bandwidthagreement) GetTotals(ctx context.Context, from, to time.Time)
|
||||
GROUP BY storage_node_id ORDER BY storage_node_id`, pb.BandwidthAction_PUT,
|
||||
pb.BandwidthAction_GET, pb.BandwidthAction_GET_AUDIT,
|
||||
pb.BandwidthAction_GET_REPAIR, pb.BandwidthAction_PUT_REPAIR)
|
||||
rows, err := b.db.DB.Query(b.db.Rebind(getTotalsSQL), from.UTC(), to.UTC())
|
||||
rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(getTotalsSQL), from.UTC(), to.UTC())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -92,10 +92,10 @@ func (b *bandwidthagreement) GetTotals(ctx context.Context, from, to time.Time)
|
||||
}
|
||||
|
||||
//GetExpired gets orders that are expired and were created before some time
|
||||
func (b *bandwidthagreement) GetExpired(before time.Time, expiredAt time.Time) (orders []bwagreement.SavedOrder, err error) {
|
||||
func (b *bandwidthagreement) GetExpired(ctx context.Context, before time.Time, expiredAt time.Time) (orders []bwagreement.SavedOrder, err error) {
|
||||
var getExpiredSQL = `SELECT serialnum, storage_node_id, uplink_id, action, total, created_at, expires_at
|
||||
FROM bwagreements WHERE created_at < ? AND expires_at < ?`
|
||||
rows, err := b.db.DB.Query(b.db.Rebind(getExpiredSQL), before, expiredAt)
|
||||
rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(getExpiredSQL), before, expiredAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -112,8 +112,8 @@ func (b *bandwidthagreement) GetExpired(before time.Time, expiredAt time.Time) (
|
||||
}
|
||||
|
||||
//DeleteExpired deletes orders that are expired and were created before some time
|
||||
func (b *bandwidthagreement) DeleteExpired(before time.Time, expiredAt time.Time) error {
|
||||
func (b *bandwidthagreement) DeleteExpired(ctx context.Context, before time.Time, expiredAt time.Time) error {
|
||||
var deleteExpiredSQL = `DELETE FROM bwagreements WHERE created_at < ? AND expires_at < ?`
|
||||
_, err := b.db.DB.Exec(b.db.Rebind(deleteExpiredSQL), before, expiredAt)
|
||||
_, err := b.db.DB.ExecContext(ctx, b.db.Rebind(deleteExpiredSQL), before, expiredAt)
|
||||
return err
|
||||
}
|
||||
|
@ -51,10 +51,10 @@ type lockedAccounting struct {
|
||||
}
|
||||
|
||||
// DeleteRawBefore deletes all raw tallies prior to some time
|
||||
func (m *lockedAccounting) DeleteRawBefore(latestRollup time.Time) error {
|
||||
func (m *lockedAccounting) DeleteRawBefore(ctx context.Context, latestRollup time.Time) error {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.db.DeleteRawBefore(latestRollup)
|
||||
return m.db.DeleteRawBefore(ctx, latestRollup)
|
||||
}
|
||||
|
||||
// GetRaw retrieves all raw tallies
|
||||
@ -120,17 +120,17 @@ type lockedBandwidthAgreement struct {
|
||||
}
|
||||
|
||||
// DeleteExpired deletes orders that are expired and were created before some time
|
||||
func (m *lockedBandwidthAgreement) DeleteExpired(a0 time.Time, a1 time.Time) error {
|
||||
func (m *lockedBandwidthAgreement) DeleteExpired(ctx context.Context, a1 time.Time, a2 time.Time) error {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.db.DeleteExpired(a0, a1)
|
||||
return m.db.DeleteExpired(ctx, a1, a2)
|
||||
}
|
||||
|
||||
// GetExpired gets orders that are expired and were created before some time
|
||||
func (m *lockedBandwidthAgreement) GetExpired(a0 time.Time, a1 time.Time) ([]bwagreement.SavedOrder, error) {
|
||||
func (m *lockedBandwidthAgreement) GetExpired(ctx context.Context, a1 time.Time, a2 time.Time) ([]bwagreement.SavedOrder, error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.db.GetExpired(a0, a1)
|
||||
return m.db.GetExpired(ctx, a1, a2)
|
||||
}
|
||||
|
||||
// GetTotalsSince returns the sum of each bandwidth type after (excluding) a given date range
|
||||
@ -148,10 +148,10 @@ func (m *lockedBandwidthAgreement) GetUplinkStats(ctx context.Context, a1 time.T
|
||||
}
|
||||
|
||||
// SaveOrder saves an order for accounting
|
||||
func (m *lockedBandwidthAgreement) SaveOrder(a0 *pb.RenterBandwidthAllocation) error {
|
||||
func (m *lockedBandwidthAgreement) SaveOrder(ctx context.Context, a1 *pb.RenterBandwidthAllocation) error {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.db.SaveOrder(a0)
|
||||
return m.db.SaveOrder(ctx, a1)
|
||||
}
|
||||
|
||||
// CertDB returns database for storing uplink's public key & ID
|
||||
|
@ -5,11 +5,14 @@ package satellitedb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/lib/pq"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
|
||||
"storj.io/storj/pkg/pb"
|
||||
"storj.io/storj/pkg/utils"
|
||||
dbx "storj.io/storj/satellite/satellitedb/dbx"
|
||||
"storj.io/storj/storage"
|
||||
)
|
||||
@ -31,38 +34,53 @@ func (r *repairQueue) Enqueue(ctx context.Context, seg *pb.InjuredSegment) error
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *repairQueue) Dequeue(ctx context.Context) (pb.InjuredSegment, error) {
|
||||
// TODO: fix out of order issue
|
||||
tx, err := r.db.Open(ctx)
|
||||
if err != nil {
|
||||
return pb.InjuredSegment{}, Error.Wrap(err)
|
||||
// postgresDequeue removes one row from injuredsegments in a single statement
// and returns its info column decoded into seg.
//
// The row is chosen by the sub-select: lowest id first, with
// FOR UPDATE SKIP LOCKED so that rows currently locked by another
// transaction are passed over instead of waited on. When the statement
// returns no row, the error is storage.ErrEmptyQueue; any other error from
// the query or from scanning is returned unchanged.
func (r *repairQueue) postgresDequeue(ctx context.Context) (seg pb.InjuredSegment, err error) {
|
||||
err = r.db.DB.QueryRowContext(ctx, `
|
||||
DELETE FROM injuredsegments
|
||||
WHERE id = ( SELECT id FROM injuredsegments ORDER BY id FOR UPDATE SKIP LOCKED LIMIT 1 )
|
||||
RETURNING info
|
||||
`).Scan(&seg)
|
||||
// No row came back from RETURNING: report an empty queue rather than sql.ErrNoRows.
if err == sql.ErrNoRows {
|
||||
err = storage.ErrEmptyQueue.New("")
|
||||
}
|
||||
return seg, err
|
||||
}
|
||||
|
||||
res, err := tx.First_Injuredsegment(ctx)
|
||||
if err != nil {
|
||||
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(err, tx.Rollback()))
|
||||
} else if res == nil {
|
||||
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(storage.ErrEmptyQueue.New(""), tx.Rollback()))
|
||||
// sqliteDequeue removes the injuredsegments row with the lowest id and
// returns its info column decoded into seg.
//
// The select and the delete run inside the callback passed to r.db.WithTx
// (presumably committed on a nil return and rolled back on error — confirm
// against the dbx WithTx helper). The delete must affect exactly one row,
// otherwise an error is returned. When the select finds no row, the error is
// storage.ErrEmptyQueue.
func (r *repairQueue) sqliteDequeue(ctx context.Context) (seg pb.InjuredSegment, err error) {
|
||||
err = r.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
|
||||
var id int64
|
||||
// Read the oldest entry; this assigns the captured named result err.
err = tx.Tx.QueryRowContext(ctx, `SELECT id, info FROM injuredsegments ORDER BY id LIMIT 1`).Scan(&id, &seg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// From here on err is a new variable local to the closure (declared by :=).
res, err := tx.Tx.ExecContext(ctx, r.db.Rebind(`DELETE FROM injuredsegments WHERE id = ?`), id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
count, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Exactly the row that was just read must have been removed.
if count != 1 {
|
||||
return fmt.Errorf("Expected 1, got %d segments deleted", count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// The select found nothing: report an empty queue rather than sql.ErrNoRows.
if err == sql.ErrNoRows {
|
||||
err = storage.ErrEmptyQueue.New("")
|
||||
}
|
||||
return seg, err
|
||||
}
|
||||
|
||||
deleted, err := tx.Delete_Injuredsegment_By_Id(
|
||||
ctx,
|
||||
dbx.Injuredsegment_Id(res.Id),
|
||||
)
|
||||
if err != nil {
|
||||
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(err, tx.Rollback()))
|
||||
} else if !deleted {
|
||||
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(Error.New("Injured segment not deleted"), tx.Rollback()))
|
||||
func (r *repairQueue) Dequeue(ctx context.Context) (seg pb.InjuredSegment, err error) {
|
||||
switch t := r.db.DB.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
return r.sqliteDequeue(ctx)
|
||||
case *pq.Driver:
|
||||
return r.postgresDequeue(ctx)
|
||||
default:
|
||||
return seg, fmt.Errorf("Unsupported database %t", t)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return pb.InjuredSegment{}, Error.Wrap(err)
|
||||
}
|
||||
|
||||
seg := &pb.InjuredSegment{}
|
||||
if err = proto.Unmarshal(res.Info, seg); err != nil {
|
||||
return pb.InjuredSegment{}, Error.Wrap(err)
|
||||
}
|
||||
return *seg, nil
|
||||
}
|
||||
|
||||
func (r *repairQueue) Peekqueue(ctx context.Context, limit int) ([]pb.InjuredSegment, error) {
|
||||
|
Loading…
Reference in New Issue
Block a user