Repair queue isolation level fix (#1466)

Implemented custom SQLite and Postgres Repairqueue Dequeue handlers
This commit is contained in:
Bill Thorp 2019-03-14 17:12:47 -04:00 committed by GitHub
parent 7dbdf89f1a
commit 665fd33e3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 128 additions and 64 deletions

View File

@ -53,5 +53,5 @@ type DB interface {
// QueryPaymentInfo queries StatDB, Accounting Rollup on nodeID
QueryPaymentInfo(ctx context.Context, start time.Time, end time.Time) ([]*CSVRow, error)
// DeleteRawBefore deletes all raw tallies prior to some time
DeleteRawBefore(latestRollup time.Time) error
DeleteRawBefore(ctx context.Context, latestRollup time.Time) error
}

View File

@ -117,7 +117,7 @@ func (r *Rollup) RollupRaws(ctx context.Context) error {
var rolledUpRawsHaveBeenSaved bool
//todo: write files to disk or whatever we decide to do here
if rolledUpRawsHaveBeenSaved {
return Error.Wrap(r.db.DeleteRawBefore(latestTally))
return Error.Wrap(r.db.DeleteRawBefore(ctx, latestTally))
}
return nil
}

View File

@ -88,14 +88,14 @@ func (t *Tally) Tally(ctx context.Context) error {
} else {
//remove expired records
now := time.Now()
_, err = t.bwAgreementDB.GetExpired(tallyEnd, now)
_, err = t.bwAgreementDB.GetExpired(ctx, tallyEnd, now)
if err != nil {
return err
}
var expiredOrdersHaveBeenSaved bool
//todo: write files to disk or whatever we decide to do here
if expiredOrdersHaveBeenSaved {
err = t.bwAgreementDB.DeleteExpired(tallyEnd, now)
err = t.bwAgreementDB.DeleteExpired(ctx, tallyEnd, now)
if err != nil {
return err
}

View File

@ -29,15 +29,15 @@ func TestBandwidthDBAgreement(t *testing.T) {
snID, err := testidentity.NewTestIdentity(ctx)
require.NoError(t, err)
require.NoError(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_PUT, "1", upID, snID))
require.Error(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "1", upID, snID))
require.NoError(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "2", upID, snID))
require.NoError(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_PUT, "1", upID, snID))
require.Error(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "1", upID, snID))
require.NoError(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "2", upID, snID))
testGetTotals(ctx, t, db.BandwidthAgreement(), snID)
testGetUplinkStats(ctx, t, db.BandwidthAgreement(), upID)
})
}
func testSaveOrder(t *testing.T, b bwagreement.DB, action pb.BandwidthAction,
func testSaveOrder(ctx context.Context, t *testing.T, b bwagreement.DB, action pb.BandwidthAction,
serialNum string, upID, snID *identity.FullIdentity) error {
rba := &pb.Order{
PayerAllocation: pb.OrderLimit{
@ -48,7 +48,7 @@ func testSaveOrder(t *testing.T, b bwagreement.DB, action pb.BandwidthAction,
Total: 1000,
StorageNodeId: snID.ID,
}
return b.SaveOrder(rba)
return b.SaveOrder(ctx, rba)
}
func testGetUplinkStats(ctx context.Context, t *testing.T, b bwagreement.DB, upID *identity.FullIdentity) {

View File

@ -56,15 +56,15 @@ type SavedOrder struct {
// DB stores orders for accounting purposes
type DB interface {
// SaveOrder saves an order for accounting
SaveOrder(*pb.Order) error
SaveOrder(context.Context, *pb.Order) error
// GetTotalsSince returns the sum of each bandwidth type after (exluding) a given date range
GetTotals(context.Context, time.Time, time.Time) (map[storj.NodeID][]int64, error)
//GetTotals returns stats about an uplink
GetUplinkStats(context.Context, time.Time, time.Time) ([]UplinkStat, error)
//GetExpired gets orders that are expired and were created before some time
GetExpired(time.Time, time.Time) ([]SavedOrder, error)
GetExpired(context.Context, time.Time, time.Time) ([]SavedOrder, error)
//DeleteExpired deletes orders that are expired and were created before some time
DeleteExpired(time.Time, time.Time) error
DeleteExpired(context.Context, time.Time, time.Time) error
}
// Server is an implementation of the pb.BandwidthServer interface
@ -112,7 +112,7 @@ func (s *Server) BandwidthAgreements(ctx context.Context, rba *pb.Order) (reply
}
//save and return rersults
if err = s.bwdb.SaveOrder(rba); err != nil {
if err = s.bwdb.SaveOrder(ctx, rba); err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed") ||
strings.Contains(err.Error(), "violates unique constraint") {
return reply, pb.ErrPayer.Wrap(auth.ErrSerial.Wrap(err))

View File

@ -90,8 +90,6 @@ func TestSequential(t *testing.T) {
}
func TestParallel(t *testing.T) {
t.Skip("logic is broken on database side")
satellitedbtest.Run(t, func(t *testing.T, db satellite.DB) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
@ -101,7 +99,6 @@ func TestParallel(t *testing.T) {
errs := make(chan error, N*2)
entries := make(chan *pb.InjuredSegment, N*2)
var wg sync.WaitGroup
wg.Add(N)
// Add to queue concurrently
for i := 0; i < N; i++ {
@ -115,7 +112,6 @@ func TestParallel(t *testing.T) {
errs <- err
}
}(i)
}
wg.Wait()

50
pkg/pb/scannerValuer.go Normal file
View File

@ -0,0 +1,50 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package pb
import (
"database/sql/driver"
proto "github.com/gogo/protobuf/proto"
"github.com/zeebo/errs"
)
// Error is bootstrap web error type
var scanError = errs.Class("Protobuf Scanner")
var valueError = errs.Class("Protobuf Valuer")
//scan automatically converts database []byte to proto.Messages
func scan(msg proto.Message, value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return scanError.New("%t was %t, expected []bytes", msg, value)
}
return scanError.Wrap(proto.Unmarshal(bytes, msg))
}
//value automatically converts proto.Messages to database []byte
func value(msg proto.Message) (driver.Value, error) {
value, err := proto.Marshal(msg)
return value, valueError.Wrap(err)
}
// Scan implements the Scanner interface.
func (n *InjuredSegment) Scan(value interface{}) error {
return scan(n, value)
}
// Value implements the driver Valuer interface.
func (n InjuredSegment) Value() (driver.Value, error) {
return value(&n)
}
// Scan implements the Scanner interface.
func (n *Order) Scan(value interface{}) error {
return scan(n, value)
}
// Value implements the driver Valuer interface.
func (n Order) Value() (driver.Value, error) {
return value(&n)
}

View File

@ -177,7 +177,7 @@ func (db *accountingDB) QueryPaymentInfo(ctx context.Context, start time.Time, e
LEFT JOIN nodes n ON n.id = r.node_id
LEFT JOIN overlay_cache_nodes o ON n.id = o.node_id
ORDER BY n.id`
rows, err := db.db.DB.Query(db.db.Rebind(sqlStmt), start.UTC(), end.UTC())
rows, err := db.db.DB.QueryContext(ctx, db.db.Rebind(sqlStmt), start.UTC(), end.UTC())
if err != nil {
return nil, Error.Wrap(err)
}
@ -206,8 +206,8 @@ func (db *accountingDB) QueryPaymentInfo(ctx context.Context, start time.Time, e
}
// DeleteRawBefore deletes all raw tallies prior to some time
func (db *accountingDB) DeleteRawBefore(latestRollup time.Time) error {
func (db *accountingDB) DeleteRawBefore(ctx context.Context, latestRollup time.Time) error {
var deleteRawSQL = `DELETE FROM accounting_raws WHERE interval_end_time < ?`
_, err := db.db.DB.Exec(db.db.Rebind(deleteRawSQL), latestRollup)
_, err := db.db.DB.ExecContext(ctx, db.db.Rebind(deleteRawSQL), latestRollup)
return err
}

View File

@ -20,9 +20,9 @@ type bandwidthagreement struct {
db *dbx.DB
}
func (b *bandwidthagreement) SaveOrder(rba *pb.Order) (err error) {
func (b *bandwidthagreement) SaveOrder(ctx context.Context, rba *pb.Order) (err error) {
var saveOrderSQL = `INSERT INTO bwagreements ( serialnum, storage_node_id, uplink_id, action, total, created_at, expires_at ) VALUES ( ?, ?, ?, ?, ?, ?, ? )`
_, err = b.db.DB.Exec(b.db.Rebind(saveOrderSQL),
_, err = b.db.DB.ExecContext(ctx, b.db.Rebind(saveOrderSQL),
rba.PayerAllocation.SerialNumber+rba.StorageNodeId.String(),
rba.StorageNodeId,
rba.PayerAllocation.UplinkId,
@ -43,7 +43,7 @@ func (b *bandwidthagreement) GetUplinkStats(ctx context.Context, from, to time.T
FROM bwagreements WHERE created_at > ?
AND created_at <= ? GROUP BY uplink_id ORDER BY uplink_id`,
pb.BandwidthAction_PUT, pb.BandwidthAction_GET)
rows, err := b.db.DB.Query(b.db.Rebind(uplinkSQL), from.UTC(), to.UTC())
rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(uplinkSQL), from.UTC(), to.UTC())
if err != nil {
return nil, err
}
@ -71,7 +71,7 @@ func (b *bandwidthagreement) GetTotals(ctx context.Context, from, to time.Time)
GROUP BY storage_node_id ORDER BY storage_node_id`, pb.BandwidthAction_PUT,
pb.BandwidthAction_GET, pb.BandwidthAction_GET_AUDIT,
pb.BandwidthAction_GET_REPAIR, pb.BandwidthAction_PUT_REPAIR)
rows, err := b.db.DB.Query(b.db.Rebind(getTotalsSQL), from.UTC(), to.UTC())
rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(getTotalsSQL), from.UTC(), to.UTC())
if err != nil {
return nil, err
}
@ -92,10 +92,10 @@ func (b *bandwidthagreement) GetTotals(ctx context.Context, from, to time.Time)
}
//GetExpired gets orders that are expired and were created before some time
func (b *bandwidthagreement) GetExpired(before time.Time, expiredAt time.Time) (orders []bwagreement.SavedOrder, err error) {
func (b *bandwidthagreement) GetExpired(ctx context.Context, before time.Time, expiredAt time.Time) (orders []bwagreement.SavedOrder, err error) {
var getExpiredSQL = `SELECT serialnum, storage_node_id, uplink_id, action, total, created_at, expires_at
FROM bwagreements WHERE created_at < ? AND expires_at < ?`
rows, err := b.db.DB.Query(b.db.Rebind(getExpiredSQL), before, expiredAt)
rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(getExpiredSQL), before, expiredAt)
if err != nil {
return nil, err
}
@ -112,8 +112,8 @@ func (b *bandwidthagreement) GetExpired(before time.Time, expiredAt time.Time) (
}
//DeleteExpired deletes orders that are expired and were created before some time
func (b *bandwidthagreement) DeleteExpired(before time.Time, expiredAt time.Time) error {
func (b *bandwidthagreement) DeleteExpired(ctx context.Context, before time.Time, expiredAt time.Time) error {
var deleteExpiredSQL = `DELETE FROM bwagreements WHERE created_at < ? AND expires_at < ?`
_, err := b.db.DB.Exec(b.db.Rebind(deleteExpiredSQL), before, expiredAt)
_, err := b.db.DB.ExecContext(ctx, b.db.Rebind(deleteExpiredSQL), before, expiredAt)
return err
}

View File

@ -51,10 +51,10 @@ type lockedAccounting struct {
}
// DeleteRawBefore deletes all raw tallies prior to some time
func (m *lockedAccounting) DeleteRawBefore(latestRollup time.Time) error {
func (m *lockedAccounting) DeleteRawBefore(ctx context.Context, latestRollup time.Time) error {
m.Lock()
defer m.Unlock()
return m.db.DeleteRawBefore(latestRollup)
return m.db.DeleteRawBefore(ctx, latestRollup)
}
// GetRaw retrieves all raw tallies
@ -120,17 +120,17 @@ type lockedBandwidthAgreement struct {
}
// DeleteExpired deletes orders that are expired and were created before some time
func (m *lockedBandwidthAgreement) DeleteExpired(a0 time.Time, a1 time.Time) error {
func (m *lockedBandwidthAgreement) DeleteExpired(ctx context.Context, a1 time.Time, a2 time.Time) error {
m.Lock()
defer m.Unlock()
return m.db.DeleteExpired(a0, a1)
return m.db.DeleteExpired(ctx, a1, a2)
}
// GetExpired gets orders that are expired and were created before some time
func (m *lockedBandwidthAgreement) GetExpired(a0 time.Time, a1 time.Time) ([]bwagreement.SavedOrder, error) {
func (m *lockedBandwidthAgreement) GetExpired(ctx context.Context, a1 time.Time, a2 time.Time) ([]bwagreement.SavedOrder, error) {
m.Lock()
defer m.Unlock()
return m.db.GetExpired(a0, a1)
return m.db.GetExpired(ctx, a1, a2)
}
// GetTotalsSince returns the sum of each bandwidth type after (exluding) a given date range
@ -148,10 +148,10 @@ func (m *lockedBandwidthAgreement) GetUplinkStats(ctx context.Context, a1 time.T
}
// SaveOrder saves an order for accounting
func (m *lockedBandwidthAgreement) SaveOrder(a0 *pb.RenterBandwidthAllocation) error {
func (m *lockedBandwidthAgreement) SaveOrder(ctx context.Context, a1 *pb.RenterBandwidthAllocation) error {
m.Lock()
defer m.Unlock()
return m.db.SaveOrder(a0)
return m.db.SaveOrder(ctx, a1)
}
// CertDB returns database for storing uplink's public key & ID

View File

@ -5,11 +5,14 @@ package satellitedb
import (
"context"
"database/sql"
"fmt"
"github.com/golang/protobuf/proto"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/utils"
dbx "storj.io/storj/satellite/satellitedb/dbx"
"storj.io/storj/storage"
)
@ -31,38 +34,53 @@ func (r *repairQueue) Enqueue(ctx context.Context, seg *pb.InjuredSegment) error
return err
}
func (r *repairQueue) Dequeue(ctx context.Context) (pb.InjuredSegment, error) {
// TODO: fix out of order issue
tx, err := r.db.Open(ctx)
if err != nil {
return pb.InjuredSegment{}, Error.Wrap(err)
func (r *repairQueue) postgresDequeue(ctx context.Context) (seg pb.InjuredSegment, err error) {
err = r.db.DB.QueryRowContext(ctx, `
DELETE FROM injuredsegments
WHERE id = ( SELECT id FROM injuredsegments ORDER BY id FOR UPDATE SKIP LOCKED LIMIT 1 )
RETURNING info
`).Scan(&seg)
if err == sql.ErrNoRows {
err = storage.ErrEmptyQueue.New("")
}
return seg, err
}
res, err := tx.First_Injuredsegment(ctx)
if err != nil {
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(err, tx.Rollback()))
} else if res == nil {
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(storage.ErrEmptyQueue.New(""), tx.Rollback()))
func (r *repairQueue) sqliteDequeue(ctx context.Context) (seg pb.InjuredSegment, err error) {
err = r.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
var id int64
err = tx.Tx.QueryRowContext(ctx, `SELECT id, info FROM injuredsegments ORDER BY id LIMIT 1`).Scan(&id, &seg)
if err != nil {
return err
}
res, err := tx.Tx.ExecContext(ctx, r.db.Rebind(`DELETE FROM injuredsegments WHERE id = ?`), id)
if err != nil {
return err
}
count, err := res.RowsAffected()
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected 1, got %d segments deleted", count)
}
return nil
})
if err == sql.ErrNoRows {
err = storage.ErrEmptyQueue.New("")
}
return seg, err
}
deleted, err := tx.Delete_Injuredsegment_By_Id(
ctx,
dbx.Injuredsegment_Id(res.Id),
)
if err != nil {
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(err, tx.Rollback()))
} else if !deleted {
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(Error.New("Injured segment not deleted"), tx.Rollback()))
func (r *repairQueue) Dequeue(ctx context.Context) (seg pb.InjuredSegment, err error) {
switch t := r.db.DB.Driver().(type) {
case *sqlite3.SQLiteDriver:
return r.sqliteDequeue(ctx)
case *pq.Driver:
return r.postgresDequeue(ctx)
default:
return seg, fmt.Errorf("Unsupported database %t", t)
}
if err := tx.Commit(); err != nil {
return pb.InjuredSegment{}, Error.Wrap(err)
}
seg := &pb.InjuredSegment{}
if err = proto.Unmarshal(res.Info, seg); err != nil {
return pb.InjuredSegment{}, Error.Wrap(err)
}
return *seg, nil
}
func (r *repairQueue) Peekqueue(ctx context.Context, limit int) ([]pb.InjuredSegment, error) {