Repair queue isolation level fix (#1466)

Implemented custom SQLite and Postgres Repairqueue Dequeue handlers
This commit is contained in:
Bill Thorp 2019-03-14 17:12:47 -04:00 committed by GitHub
parent 7dbdf89f1a
commit 665fd33e3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 128 additions and 64 deletions

View File

@ -53,5 +53,5 @@ type DB interface {
// QueryPaymentInfo queries StatDB, Accounting Rollup on nodeID // QueryPaymentInfo queries StatDB, Accounting Rollup on nodeID
QueryPaymentInfo(ctx context.Context, start time.Time, end time.Time) ([]*CSVRow, error) QueryPaymentInfo(ctx context.Context, start time.Time, end time.Time) ([]*CSVRow, error)
// DeleteRawBefore deletes all raw tallies prior to some time // DeleteRawBefore deletes all raw tallies prior to some time
DeleteRawBefore(latestRollup time.Time) error DeleteRawBefore(ctx context.Context, latestRollup time.Time) error
} }

View File

@ -117,7 +117,7 @@ func (r *Rollup) RollupRaws(ctx context.Context) error {
var rolledUpRawsHaveBeenSaved bool var rolledUpRawsHaveBeenSaved bool
//todo: write files to disk or whatever we decide to do here //todo: write files to disk or whatever we decide to do here
if rolledUpRawsHaveBeenSaved { if rolledUpRawsHaveBeenSaved {
return Error.Wrap(r.db.DeleteRawBefore(latestTally)) return Error.Wrap(r.db.DeleteRawBefore(ctx, latestTally))
} }
return nil return nil
} }

View File

@ -88,14 +88,14 @@ func (t *Tally) Tally(ctx context.Context) error {
} else { } else {
//remove expired records //remove expired records
now := time.Now() now := time.Now()
_, err = t.bwAgreementDB.GetExpired(tallyEnd, now) _, err = t.bwAgreementDB.GetExpired(ctx, tallyEnd, now)
if err != nil { if err != nil {
return err return err
} }
var expiredOrdersHaveBeenSaved bool var expiredOrdersHaveBeenSaved bool
//todo: write files to disk or whatever we decide to do here //todo: write files to disk or whatever we decide to do here
if expiredOrdersHaveBeenSaved { if expiredOrdersHaveBeenSaved {
err = t.bwAgreementDB.DeleteExpired(tallyEnd, now) err = t.bwAgreementDB.DeleteExpired(ctx, tallyEnd, now)
if err != nil { if err != nil {
return err return err
} }

View File

@ -29,15 +29,15 @@ func TestBandwidthDBAgreement(t *testing.T) {
snID, err := testidentity.NewTestIdentity(ctx) snID, err := testidentity.NewTestIdentity(ctx)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_PUT, "1", upID, snID)) require.NoError(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_PUT, "1", upID, snID))
require.Error(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "1", upID, snID)) require.Error(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "1", upID, snID))
require.NoError(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "2", upID, snID)) require.NoError(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "2", upID, snID))
testGetTotals(ctx, t, db.BandwidthAgreement(), snID) testGetTotals(ctx, t, db.BandwidthAgreement(), snID)
testGetUplinkStats(ctx, t, db.BandwidthAgreement(), upID) testGetUplinkStats(ctx, t, db.BandwidthAgreement(), upID)
}) })
} }
func testSaveOrder(t *testing.T, b bwagreement.DB, action pb.BandwidthAction, func testSaveOrder(ctx context.Context, t *testing.T, b bwagreement.DB, action pb.BandwidthAction,
serialNum string, upID, snID *identity.FullIdentity) error { serialNum string, upID, snID *identity.FullIdentity) error {
rba := &pb.Order{ rba := &pb.Order{
PayerAllocation: pb.OrderLimit{ PayerAllocation: pb.OrderLimit{
@ -48,7 +48,7 @@ func testSaveOrder(t *testing.T, b bwagreement.DB, action pb.BandwidthAction,
Total: 1000, Total: 1000,
StorageNodeId: snID.ID, StorageNodeId: snID.ID,
} }
return b.SaveOrder(rba) return b.SaveOrder(ctx, rba)
} }
func testGetUplinkStats(ctx context.Context, t *testing.T, b bwagreement.DB, upID *identity.FullIdentity) { func testGetUplinkStats(ctx context.Context, t *testing.T, b bwagreement.DB, upID *identity.FullIdentity) {

View File

@ -56,15 +56,15 @@ type SavedOrder struct {
// DB stores orders for accounting purposes // DB stores orders for accounting purposes
type DB interface { type DB interface {
// SaveOrder saves an order for accounting // SaveOrder saves an order for accounting
SaveOrder(*pb.Order) error SaveOrder(context.Context, *pb.Order) error
// GetTotalsSince returns the sum of each bandwidth type after (exluding) a given date range // GetTotalsSince returns the sum of each bandwidth type after (exluding) a given date range
GetTotals(context.Context, time.Time, time.Time) (map[storj.NodeID][]int64, error) GetTotals(context.Context, time.Time, time.Time) (map[storj.NodeID][]int64, error)
//GetTotals returns stats about an uplink //GetTotals returns stats about an uplink
GetUplinkStats(context.Context, time.Time, time.Time) ([]UplinkStat, error) GetUplinkStats(context.Context, time.Time, time.Time) ([]UplinkStat, error)
//GetExpired gets orders that are expired and were created before some time //GetExpired gets orders that are expired and were created before some time
GetExpired(time.Time, time.Time) ([]SavedOrder, error) GetExpired(context.Context, time.Time, time.Time) ([]SavedOrder, error)
//DeleteExpired deletes orders that are expired and were created before some time //DeleteExpired deletes orders that are expired and were created before some time
DeleteExpired(time.Time, time.Time) error DeleteExpired(context.Context, time.Time, time.Time) error
} }
// Server is an implementation of the pb.BandwidthServer interface // Server is an implementation of the pb.BandwidthServer interface
@ -112,7 +112,7 @@ func (s *Server) BandwidthAgreements(ctx context.Context, rba *pb.Order) (reply
} }
//save and return rersults //save and return rersults
if err = s.bwdb.SaveOrder(rba); err != nil { if err = s.bwdb.SaveOrder(ctx, rba); err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed") || if strings.Contains(err.Error(), "UNIQUE constraint failed") ||
strings.Contains(err.Error(), "violates unique constraint") { strings.Contains(err.Error(), "violates unique constraint") {
return reply, pb.ErrPayer.Wrap(auth.ErrSerial.Wrap(err)) return reply, pb.ErrPayer.Wrap(auth.ErrSerial.Wrap(err))

View File

@ -90,8 +90,6 @@ func TestSequential(t *testing.T) {
} }
func TestParallel(t *testing.T) { func TestParallel(t *testing.T) {
t.Skip("logic is broken on database side")
satellitedbtest.Run(t, func(t *testing.T, db satellite.DB) { satellitedbtest.Run(t, func(t *testing.T, db satellite.DB) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
defer ctx.Cleanup() defer ctx.Cleanup()
@ -101,7 +99,6 @@ func TestParallel(t *testing.T) {
errs := make(chan error, N*2) errs := make(chan error, N*2)
entries := make(chan *pb.InjuredSegment, N*2) entries := make(chan *pb.InjuredSegment, N*2)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(N) wg.Add(N)
// Add to queue concurrently // Add to queue concurrently
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
@ -115,7 +112,6 @@ func TestParallel(t *testing.T) {
errs <- err errs <- err
} }
}(i) }(i)
} }
wg.Wait() wg.Wait()

50
pkg/pb/scannerValuer.go Normal file
View File

@ -0,0 +1,50 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package pb
import (
"database/sql/driver"
proto "github.com/gogo/protobuf/proto"
"github.com/zeebo/errs"
)
// Error is bootstrap web error type
var scanError = errs.Class("Protobuf Scanner")
var valueError = errs.Class("Protobuf Valuer")
//scan automatically converts database []byte to proto.Messages
func scan(msg proto.Message, value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return scanError.New("%t was %t, expected []bytes", msg, value)
}
return scanError.Wrap(proto.Unmarshal(bytes, msg))
}
//value automatically converts proto.Messages to database []byte
func value(msg proto.Message) (driver.Value, error) {
value, err := proto.Marshal(msg)
return value, valueError.Wrap(err)
}
// Scan implements the Scanner interface.
func (n *InjuredSegment) Scan(value interface{}) error {
return scan(n, value)
}
// Value implements the driver Valuer interface.
func (n InjuredSegment) Value() (driver.Value, error) {
return value(&n)
}
// Scan implements the Scanner interface.
func (n *Order) Scan(value interface{}) error {
return scan(n, value)
}
// Value implements the driver Valuer interface.
func (n Order) Value() (driver.Value, error) {
return value(&n)
}

View File

@ -177,7 +177,7 @@ func (db *accountingDB) QueryPaymentInfo(ctx context.Context, start time.Time, e
LEFT JOIN nodes n ON n.id = r.node_id LEFT JOIN nodes n ON n.id = r.node_id
LEFT JOIN overlay_cache_nodes o ON n.id = o.node_id LEFT JOIN overlay_cache_nodes o ON n.id = o.node_id
ORDER BY n.id` ORDER BY n.id`
rows, err := db.db.DB.Query(db.db.Rebind(sqlStmt), start.UTC(), end.UTC()) rows, err := db.db.DB.QueryContext(ctx, db.db.Rebind(sqlStmt), start.UTC(), end.UTC())
if err != nil { if err != nil {
return nil, Error.Wrap(err) return nil, Error.Wrap(err)
} }
@ -206,8 +206,8 @@ func (db *accountingDB) QueryPaymentInfo(ctx context.Context, start time.Time, e
} }
// DeleteRawBefore deletes all raw tallies prior to some time // DeleteRawBefore deletes all raw tallies prior to some time
func (db *accountingDB) DeleteRawBefore(latestRollup time.Time) error { func (db *accountingDB) DeleteRawBefore(ctx context.Context, latestRollup time.Time) error {
var deleteRawSQL = `DELETE FROM accounting_raws WHERE interval_end_time < ?` var deleteRawSQL = `DELETE FROM accounting_raws WHERE interval_end_time < ?`
_, err := db.db.DB.Exec(db.db.Rebind(deleteRawSQL), latestRollup) _, err := db.db.DB.ExecContext(ctx, db.db.Rebind(deleteRawSQL), latestRollup)
return err return err
} }

View File

@ -20,9 +20,9 @@ type bandwidthagreement struct {
db *dbx.DB db *dbx.DB
} }
func (b *bandwidthagreement) SaveOrder(rba *pb.Order) (err error) { func (b *bandwidthagreement) SaveOrder(ctx context.Context, rba *pb.Order) (err error) {
var saveOrderSQL = `INSERT INTO bwagreements ( serialnum, storage_node_id, uplink_id, action, total, created_at, expires_at ) VALUES ( ?, ?, ?, ?, ?, ?, ? )` var saveOrderSQL = `INSERT INTO bwagreements ( serialnum, storage_node_id, uplink_id, action, total, created_at, expires_at ) VALUES ( ?, ?, ?, ?, ?, ?, ? )`
_, err = b.db.DB.Exec(b.db.Rebind(saveOrderSQL), _, err = b.db.DB.ExecContext(ctx, b.db.Rebind(saveOrderSQL),
rba.PayerAllocation.SerialNumber+rba.StorageNodeId.String(), rba.PayerAllocation.SerialNumber+rba.StorageNodeId.String(),
rba.StorageNodeId, rba.StorageNodeId,
rba.PayerAllocation.UplinkId, rba.PayerAllocation.UplinkId,
@ -43,7 +43,7 @@ func (b *bandwidthagreement) GetUplinkStats(ctx context.Context, from, to time.T
FROM bwagreements WHERE created_at > ? FROM bwagreements WHERE created_at > ?
AND created_at <= ? GROUP BY uplink_id ORDER BY uplink_id`, AND created_at <= ? GROUP BY uplink_id ORDER BY uplink_id`,
pb.BandwidthAction_PUT, pb.BandwidthAction_GET) pb.BandwidthAction_PUT, pb.BandwidthAction_GET)
rows, err := b.db.DB.Query(b.db.Rebind(uplinkSQL), from.UTC(), to.UTC()) rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(uplinkSQL), from.UTC(), to.UTC())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,7 +71,7 @@ func (b *bandwidthagreement) GetTotals(ctx context.Context, from, to time.Time)
GROUP BY storage_node_id ORDER BY storage_node_id`, pb.BandwidthAction_PUT, GROUP BY storage_node_id ORDER BY storage_node_id`, pb.BandwidthAction_PUT,
pb.BandwidthAction_GET, pb.BandwidthAction_GET_AUDIT, pb.BandwidthAction_GET, pb.BandwidthAction_GET_AUDIT,
pb.BandwidthAction_GET_REPAIR, pb.BandwidthAction_PUT_REPAIR) pb.BandwidthAction_GET_REPAIR, pb.BandwidthAction_PUT_REPAIR)
rows, err := b.db.DB.Query(b.db.Rebind(getTotalsSQL), from.UTC(), to.UTC()) rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(getTotalsSQL), from.UTC(), to.UTC())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -92,10 +92,10 @@ func (b *bandwidthagreement) GetTotals(ctx context.Context, from, to time.Time)
} }
//GetExpired gets orders that are expired and were created before some time //GetExpired gets orders that are expired and were created before some time
func (b *bandwidthagreement) GetExpired(before time.Time, expiredAt time.Time) (orders []bwagreement.SavedOrder, err error) { func (b *bandwidthagreement) GetExpired(ctx context.Context, before time.Time, expiredAt time.Time) (orders []bwagreement.SavedOrder, err error) {
var getExpiredSQL = `SELECT serialnum, storage_node_id, uplink_id, action, total, created_at, expires_at var getExpiredSQL = `SELECT serialnum, storage_node_id, uplink_id, action, total, created_at, expires_at
FROM bwagreements WHERE created_at < ? AND expires_at < ?` FROM bwagreements WHERE created_at < ? AND expires_at < ?`
rows, err := b.db.DB.Query(b.db.Rebind(getExpiredSQL), before, expiredAt) rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(getExpiredSQL), before, expiredAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,8 +112,8 @@ func (b *bandwidthagreement) GetExpired(before time.Time, expiredAt time.Time) (
} }
//DeleteExpired deletes orders that are expired and were created before some time //DeleteExpired deletes orders that are expired and were created before some time
func (b *bandwidthagreement) DeleteExpired(before time.Time, expiredAt time.Time) error { func (b *bandwidthagreement) DeleteExpired(ctx context.Context, before time.Time, expiredAt time.Time) error {
var deleteExpiredSQL = `DELETE FROM bwagreements WHERE created_at < ? AND expires_at < ?` var deleteExpiredSQL = `DELETE FROM bwagreements WHERE created_at < ? AND expires_at < ?`
_, err := b.db.DB.Exec(b.db.Rebind(deleteExpiredSQL), before, expiredAt) _, err := b.db.DB.ExecContext(ctx, b.db.Rebind(deleteExpiredSQL), before, expiredAt)
return err return err
} }

View File

@ -51,10 +51,10 @@ type lockedAccounting struct {
} }
// DeleteRawBefore deletes all raw tallies prior to some time // DeleteRawBefore deletes all raw tallies prior to some time
func (m *lockedAccounting) DeleteRawBefore(latestRollup time.Time) error { func (m *lockedAccounting) DeleteRawBefore(ctx context.Context, latestRollup time.Time) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
return m.db.DeleteRawBefore(latestRollup) return m.db.DeleteRawBefore(ctx, latestRollup)
} }
// GetRaw retrieves all raw tallies // GetRaw retrieves all raw tallies
@ -120,17 +120,17 @@ type lockedBandwidthAgreement struct {
} }
// DeleteExpired deletes orders that are expired and were created before some time // DeleteExpired deletes orders that are expired and were created before some time
func (m *lockedBandwidthAgreement) DeleteExpired(a0 time.Time, a1 time.Time) error { func (m *lockedBandwidthAgreement) DeleteExpired(ctx context.Context, a1 time.Time, a2 time.Time) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
return m.db.DeleteExpired(a0, a1) return m.db.DeleteExpired(ctx, a1, a2)
} }
// GetExpired gets orders that are expired and were created before some time // GetExpired gets orders that are expired and were created before some time
func (m *lockedBandwidthAgreement) GetExpired(a0 time.Time, a1 time.Time) ([]bwagreement.SavedOrder, error) { func (m *lockedBandwidthAgreement) GetExpired(ctx context.Context, a1 time.Time, a2 time.Time) ([]bwagreement.SavedOrder, error) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
return m.db.GetExpired(a0, a1) return m.db.GetExpired(ctx, a1, a2)
} }
// GetTotalsSince returns the sum of each bandwidth type after (exluding) a given date range // GetTotalsSince returns the sum of each bandwidth type after (exluding) a given date range
@ -148,10 +148,10 @@ func (m *lockedBandwidthAgreement) GetUplinkStats(ctx context.Context, a1 time.T
} }
// SaveOrder saves an order for accounting // SaveOrder saves an order for accounting
func (m *lockedBandwidthAgreement) SaveOrder(a0 *pb.RenterBandwidthAllocation) error { func (m *lockedBandwidthAgreement) SaveOrder(ctx context.Context, a1 *pb.RenterBandwidthAllocation) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
return m.db.SaveOrder(a0) return m.db.SaveOrder(ctx, a1)
} }
// CertDB returns database for storing uplink's public key & ID // CertDB returns database for storing uplink's public key & ID

View File

@ -5,11 +5,14 @@ package satellitedb
import ( import (
"context" "context"
"database/sql"
"fmt"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
"storj.io/storj/pkg/pb" "storj.io/storj/pkg/pb"
"storj.io/storj/pkg/utils"
dbx "storj.io/storj/satellite/satellitedb/dbx" dbx "storj.io/storj/satellite/satellitedb/dbx"
"storj.io/storj/storage" "storj.io/storj/storage"
) )
@ -31,38 +34,53 @@ func (r *repairQueue) Enqueue(ctx context.Context, seg *pb.InjuredSegment) error
return err return err
} }
func (r *repairQueue) Dequeue(ctx context.Context) (pb.InjuredSegment, error) { func (r *repairQueue) postgresDequeue(ctx context.Context) (seg pb.InjuredSegment, err error) {
// TODO: fix out of order issue err = r.db.DB.QueryRowContext(ctx, `
tx, err := r.db.Open(ctx) DELETE FROM injuredsegments
if err != nil { WHERE id = ( SELECT id FROM injuredsegments ORDER BY id FOR UPDATE SKIP LOCKED LIMIT 1 )
return pb.InjuredSegment{}, Error.Wrap(err) RETURNING info
`).Scan(&seg)
if err == sql.ErrNoRows {
err = storage.ErrEmptyQueue.New("")
} }
return seg, err
}
res, err := tx.First_Injuredsegment(ctx) func (r *repairQueue) sqliteDequeue(ctx context.Context) (seg pb.InjuredSegment, err error) {
if err != nil { err = r.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(err, tx.Rollback())) var id int64
} else if res == nil { err = tx.Tx.QueryRowContext(ctx, `SELECT id, info FROM injuredsegments ORDER BY id LIMIT 1`).Scan(&id, &seg)
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(storage.ErrEmptyQueue.New(""), tx.Rollback())) if err != nil {
return err
}
res, err := tx.Tx.ExecContext(ctx, r.db.Rebind(`DELETE FROM injuredsegments WHERE id = ?`), id)
if err != nil {
return err
}
count, err := res.RowsAffected()
if err != nil {
return err
}
if count != 1 {
return fmt.Errorf("Expected 1, got %d segments deleted", count)
}
return nil
})
if err == sql.ErrNoRows {
err = storage.ErrEmptyQueue.New("")
} }
return seg, err
}
deleted, err := tx.Delete_Injuredsegment_By_Id( func (r *repairQueue) Dequeue(ctx context.Context) (seg pb.InjuredSegment, err error) {
ctx, switch t := r.db.DB.Driver().(type) {
dbx.Injuredsegment_Id(res.Id), case *sqlite3.SQLiteDriver:
) return r.sqliteDequeue(ctx)
if err != nil { case *pq.Driver:
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(err, tx.Rollback())) return r.postgresDequeue(ctx)
} else if !deleted { default:
return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(Error.New("Injured segment not deleted"), tx.Rollback())) return seg, fmt.Errorf("Unsupported database %t", t)
} }
if err := tx.Commit(); err != nil {
return pb.InjuredSegment{}, Error.Wrap(err)
}
seg := &pb.InjuredSegment{}
if err = proto.Unmarshal(res.Info, seg); err != nil {
return pb.InjuredSegment{}, Error.Wrap(err)
}
return *seg, nil
} }
func (r *repairQueue) Peekqueue(ctx context.Context, limit int) ([]pb.InjuredSegment, error) { func (r *repairQueue) Peekqueue(ctx context.Context, limit int) ([]pb.InjuredSegment, error) {