diff --git a/pkg/accounting/db.go b/pkg/accounting/db.go index a606fa472..285d81299 100644 --- a/pkg/accounting/db.go +++ b/pkg/accounting/db.go @@ -53,5 +53,5 @@ type DB interface { // QueryPaymentInfo queries StatDB, Accounting Rollup on nodeID QueryPaymentInfo(ctx context.Context, start time.Time, end time.Time) ([]*CSVRow, error) // DeleteRawBefore deletes all raw tallies prior to some time - DeleteRawBefore(latestRollup time.Time) error + DeleteRawBefore(ctx context.Context, latestRollup time.Time) error } diff --git a/pkg/accounting/rollup/rollup.go b/pkg/accounting/rollup/rollup.go index 7a1d94146..e153361b0 100644 --- a/pkg/accounting/rollup/rollup.go +++ b/pkg/accounting/rollup/rollup.go @@ -117,7 +117,7 @@ func (r *Rollup) RollupRaws(ctx context.Context) error { var rolledUpRawsHaveBeenSaved bool //todo: write files to disk or whatever we decide to do here if rolledUpRawsHaveBeenSaved { - return Error.Wrap(r.db.DeleteRawBefore(latestTally)) + return Error.Wrap(r.db.DeleteRawBefore(ctx, latestTally)) } return nil } diff --git a/pkg/accounting/tally/tally.go b/pkg/accounting/tally/tally.go index 54b86024c..f0ed06e0a 100644 --- a/pkg/accounting/tally/tally.go +++ b/pkg/accounting/tally/tally.go @@ -88,14 +88,14 @@ func (t *Tally) Tally(ctx context.Context) error { } else { //remove expired records now := time.Now() - _, err = t.bwAgreementDB.GetExpired(tallyEnd, now) + _, err = t.bwAgreementDB.GetExpired(ctx, tallyEnd, now) if err != nil { return err } var expiredOrdersHaveBeenSaved bool //todo: write files to disk or whatever we decide to do here if expiredOrdersHaveBeenSaved { - err = t.bwAgreementDB.DeleteExpired(tallyEnd, now) + err = t.bwAgreementDB.DeleteExpired(ctx, tallyEnd, now) if err != nil { return err } diff --git a/pkg/bwagreement/bwagreement_test.go b/pkg/bwagreement/bwagreement_test.go index fe88ce9fa..a2cf597b1 100644 --- a/pkg/bwagreement/bwagreement_test.go +++ b/pkg/bwagreement/bwagreement_test.go @@ -29,15 +29,15 @@ func TestBandwidthDBAgreement(t *testing.T) { snID, err := testidentity.NewTestIdentity(ctx) require.NoError(t, err) - require.NoError(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_PUT, "1", upID, snID)) - require.Error(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "1", upID, snID)) - require.NoError(t, testSaveOrder(t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "2", upID, snID)) + require.NoError(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_PUT, "1", upID, snID)) + require.Error(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "1", upID, snID)) + require.NoError(t, testSaveOrder(ctx, t, db.BandwidthAgreement(), pb.BandwidthAction_GET, "2", upID, snID)) testGetTotals(ctx, t, db.BandwidthAgreement(), snID) testGetUplinkStats(ctx, t, db.BandwidthAgreement(), upID) }) } -func testSaveOrder(t *testing.T, b bwagreement.DB, action pb.BandwidthAction, +func testSaveOrder(ctx context.Context, t *testing.T, b bwagreement.DB, action pb.BandwidthAction, serialNum string, upID, snID *identity.FullIdentity) error { rba := &pb.Order{ PayerAllocation: pb.OrderLimit{ @@ -48,7 +48,7 @@ func testSaveOrder(t *testing.T, b bwagreement.DB, action pb.BandwidthAction, Total: 1000, StorageNodeId: snID.ID, } - return b.SaveOrder(rba) + return b.SaveOrder(ctx, rba) } func testGetUplinkStats(ctx context.Context, t *testing.T, b bwagreement.DB, upID *identity.FullIdentity) { diff --git a/pkg/bwagreement/server.go b/pkg/bwagreement/server.go index 971a8fe3b..a3e8c802f 100644 --- a/pkg/bwagreement/server.go +++ b/pkg/bwagreement/server.go @@ -56,15 +56,15 @@ type SavedOrder struct { // DB stores orders for accounting purposes type DB interface { // SaveOrder saves an order for accounting - SaveOrder(*pb.Order) error + SaveOrder(context.Context, *pb.Order) error // GetTotalsSince returns the sum of each bandwidth type after (exluding) a given date range GetTotals(context.Context, time.Time, time.Time) (map[storj.NodeID][]int64, error) //GetTotals returns stats about an uplink GetUplinkStats(context.Context, time.Time, time.Time) ([]UplinkStat, error) //GetExpired gets orders that are expired and were created before some time - GetExpired(time.Time, time.Time) ([]SavedOrder, error) + GetExpired(context.Context, time.Time, time.Time) ([]SavedOrder, error) //DeleteExpired deletes orders that are expired and were created before some time - DeleteExpired(time.Time, time.Time) error + DeleteExpired(context.Context, time.Time, time.Time) error } // Server is an implementation of the pb.BandwidthServer interface @@ -112,7 +112,7 @@ func (s *Server) BandwidthAgreements(ctx context.Context, rba *pb.Order) (reply } //save and return rersults - if err = s.bwdb.SaveOrder(rba); err != nil { + if err = s.bwdb.SaveOrder(ctx, rba); err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "violates unique constraint") { return reply, pb.ErrPayer.Wrap(auth.ErrSerial.Wrap(err)) diff --git a/pkg/datarepair/queue/queue_test.go b/pkg/datarepair/queue/queue_test.go index b4843c705..79a67a89b 100644 --- a/pkg/datarepair/queue/queue_test.go +++ b/pkg/datarepair/queue/queue_test.go @@ -90,8 +90,6 @@ func TestSequential(t *testing.T) { } func TestParallel(t *testing.T) { - t.Skip("logic is broken on database side") - satellitedbtest.Run(t, func(t *testing.T, db satellite.DB) { ctx := testcontext.New(t) defer ctx.Cleanup() @@ -101,7 +99,6 @@ func TestParallel(t *testing.T) { errs := make(chan error, N*2) entries := make(chan *pb.InjuredSegment, N*2) var wg sync.WaitGroup - wg.Add(N) // Add to queue concurrently for i := 0; i < N; i++ { @@ -115,7 +112,6 @@ func TestParallel(t *testing.T) { errs <- err } }(i) - } wg.Wait() diff --git a/pkg/pb/scannerValuer.go b/pkg/pb/scannerValuer.go new file mode 100644 index 000000000..7b5c83b45 --- /dev/null +++ b/pkg/pb/scannerValuer.go @@ -0,0 +1,50 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package pb + +import ( + "database/sql/driver" + + proto "github.com/gogo/protobuf/proto" + "github.com/zeebo/errs" +) + +// Error is bootstrap web error type +var scanError = errs.Class("Protobuf Scanner") +var valueError = errs.Class("Protobuf Valuer") + +//scan automatically converts database []byte to proto.Messages +func scan(msg proto.Message, value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return scanError.New("%t was %t, expected []bytes", msg, value) + } + return scanError.Wrap(proto.Unmarshal(bytes, msg)) +} + +//value automatically converts proto.Messages to database []byte +func value(msg proto.Message) (driver.Value, error) { + value, err := proto.Marshal(msg) + return value, valueError.Wrap(err) +} + +// Scan implements the Scanner interface. +func (n *InjuredSegment) Scan(value interface{}) error { + return scan(n, value) +} + +// Value implements the driver Valuer interface. +func (n InjuredSegment) Value() (driver.Value, error) { + return value(&n) +} + +// Scan implements the Scanner interface. +func (n *Order) Scan(value interface{}) error { + return scan(n, value) +} + +// Value implements the driver Valuer interface. +func (n Order) Value() (driver.Value, error) { + return value(&n) +} diff --git a/satellite/satellitedb/accounting.go b/satellite/satellitedb/accounting.go index 766798262..5d1a15500 100644 --- a/satellite/satellitedb/accounting.go +++ b/satellite/satellitedb/accounting.go @@ -177,7 +177,7 @@ func (db *accountingDB) QueryPaymentInfo(ctx context.Context, start time.Time, e LEFT JOIN nodes n ON n.id = r.node_id LEFT JOIN overlay_cache_nodes o ON n.id = o.node_id ORDER BY n.id` - rows, err := db.db.DB.Query(db.db.Rebind(sqlStmt), start.UTC(), end.UTC()) + rows, err := db.db.DB.QueryContext(ctx, db.db.Rebind(sqlStmt), start.UTC(), end.UTC()) if err != nil { return nil, Error.Wrap(err) } @@ -206,8 +206,8 @@ func (db *accountingDB) QueryPaymentInfo(ctx context.Context, start time.Time, e } // DeleteRawBefore deletes all raw tallies prior to some time -func (db *accountingDB) DeleteRawBefore(latestRollup time.Time) error { +func (db *accountingDB) DeleteRawBefore(ctx context.Context, latestRollup time.Time) error { var deleteRawSQL = `DELETE FROM accounting_raws WHERE interval_end_time < ?` - _, err := db.db.DB.Exec(db.db.Rebind(deleteRawSQL), latestRollup) + _, err := db.db.DB.ExecContext(ctx, db.db.Rebind(deleteRawSQL), latestRollup) return err } diff --git a/satellite/satellitedb/bandwidthagreement.go b/satellite/satellitedb/bandwidthagreement.go index 01e942e14..41c77e228 100644 --- a/satellite/satellitedb/bandwidthagreement.go +++ b/satellite/satellitedb/bandwidthagreement.go @@ -20,9 +20,9 @@ type bandwidthagreement struct { db *dbx.DB } -func (b *bandwidthagreement) SaveOrder(rba *pb.Order) (err error) { +func (b *bandwidthagreement) SaveOrder(ctx context.Context, rba *pb.Order) (err error) { var saveOrderSQL = `INSERT INTO bwagreements ( serialnum, storage_node_id, uplink_id, action, total, created_at, expires_at ) VALUES ( ?, ?, ?, ?, ?, ?, ? )` - _, err = b.db.DB.Exec(b.db.Rebind(saveOrderSQL), + _, err = b.db.DB.ExecContext(ctx, b.db.Rebind(saveOrderSQL), rba.PayerAllocation.SerialNumber+rba.StorageNodeId.String(), rba.StorageNodeId, rba.PayerAllocation.UplinkId, @@ -43,7 +43,7 @@ func (b *bandwidthagreement) GetUplinkStats(ctx context.Context, from, to time.T FROM bwagreements WHERE created_at > ? AND created_at <= ? GROUP BY uplink_id ORDER BY uplink_id`, pb.BandwidthAction_PUT, pb.BandwidthAction_GET) - rows, err := b.db.DB.Query(b.db.Rebind(uplinkSQL), from.UTC(), to.UTC()) + rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(uplinkSQL), from.UTC(), to.UTC()) if err != nil { return nil, err } @@ -71,7 +71,7 @@ func (b *bandwidthagreement) GetTotals(ctx context.Context, from, to time.Time) GROUP BY storage_node_id ORDER BY storage_node_id`, pb.BandwidthAction_PUT, pb.BandwidthAction_GET, pb.BandwidthAction_GET_AUDIT, pb.BandwidthAction_GET_REPAIR, pb.BandwidthAction_PUT_REPAIR) - rows, err := b.db.DB.Query(b.db.Rebind(getTotalsSQL), from.UTC(), to.UTC()) + rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(getTotalsSQL), from.UTC(), to.UTC()) if err != nil { return nil, err } @@ -92,10 +92,10 @@ func (b *bandwidthagreement) GetTotals(ctx context.Context, from, to time.Time) } //GetExpired gets orders that are expired and were created before some time -func (b *bandwidthagreement) GetExpired(before time.Time, expiredAt time.Time) (orders []bwagreement.SavedOrder, err error) { +func (b *bandwidthagreement) GetExpired(ctx context.Context, before time.Time, expiredAt time.Time) (orders []bwagreement.SavedOrder, err error) { var getExpiredSQL = `SELECT serialnum, storage_node_id, uplink_id, action, total, created_at, expires_at FROM bwagreements WHERE created_at < ? AND expires_at < ?` - rows, err := b.db.DB.Query(b.db.Rebind(getExpiredSQL), before, expiredAt) + rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(getExpiredSQL), before, expiredAt) if err != nil { return nil, err } @@ -112,8 +112,8 @@ func (b *bandwidthagreement) GetExpired(before time.Time, expiredAt time.Time) ( } //DeleteExpired deletes orders that are expired and were created before some time -func (b *bandwidthagreement) DeleteExpired(before time.Time, expiredAt time.Time) error { +func (b *bandwidthagreement) DeleteExpired(ctx context.Context, before time.Time, expiredAt time.Time) error { var deleteExpiredSQL = `DELETE FROM bwagreements WHERE created_at < ? AND expires_at < ?` - _, err := b.db.DB.Exec(b.db.Rebind(deleteExpiredSQL), before, expiredAt) + _, err := b.db.DB.ExecContext(ctx, b.db.Rebind(deleteExpiredSQL), before, expiredAt) return err } diff --git a/satellite/satellitedb/locked.go b/satellite/satellitedb/locked.go index 9a1e451b6..5bc4e2424 100644 --- a/satellite/satellitedb/locked.go +++ b/satellite/satellitedb/locked.go @@ -51,10 +51,10 @@ type lockedAccounting struct { } // DeleteRawBefore deletes all raw tallies prior to some time -func (m *lockedAccounting) DeleteRawBefore(latestRollup time.Time) error { +func (m *lockedAccounting) DeleteRawBefore(ctx context.Context, latestRollup time.Time) error { m.Lock() defer m.Unlock() - return m.db.DeleteRawBefore(latestRollup) + return m.db.DeleteRawBefore(ctx, latestRollup) } // GetRaw retrieves all raw tallies @@ -120,17 +120,17 @@ type lockedBandwidthAgreement struct { } // DeleteExpired deletes orders that are expired and were created before some time -func (m *lockedBandwidthAgreement) DeleteExpired(a0 time.Time, a1 time.Time) error { +func (m *lockedBandwidthAgreement) DeleteExpired(ctx context.Context, a1 time.Time, a2 time.Time) error { m.Lock() defer m.Unlock() - return m.db.DeleteExpired(a0, a1) + return m.db.DeleteExpired(ctx, a1, a2) } // GetExpired gets orders that are expired and were created before some time -func (m *lockedBandwidthAgreement) GetExpired(a0 time.Time, a1 time.Time) ([]bwagreement.SavedOrder, error) { +func (m *lockedBandwidthAgreement) GetExpired(ctx context.Context, a1 time.Time, a2 time.Time) ([]bwagreement.SavedOrder, error) { m.Lock() defer m.Unlock() - return m.db.GetExpired(a0, a1) + return m.db.GetExpired(ctx, a1, a2) } // GetTotalsSince returns the sum of each bandwidth type after (exluding) a given date range @@ -148,10 +148,10 @@ func (m *lockedBandwidthAgreement) GetUplinkStats(ctx context.Context, a1 time.T } // SaveOrder saves an order for accounting -func (m *lockedBandwidthAgreement) SaveOrder(a0 *pb.RenterBandwidthAllocation) error { +func (m *lockedBandwidthAgreement) SaveOrder(ctx context.Context, a1 *pb.RenterBandwidthAllocation) error { m.Lock() defer m.Unlock() - return m.db.SaveOrder(a0) + return m.db.SaveOrder(ctx, a1) } // CertDB returns database for storing uplink's public key & ID diff --git a/satellite/satellitedb/repairqueue.go b/satellite/satellitedb/repairqueue.go index a50dd10b4..6e71d0235 100644 --- a/satellite/satellitedb/repairqueue.go +++ b/satellite/satellitedb/repairqueue.go @@ -5,11 +5,14 @@ package satellitedb import ( "context" + "database/sql" + "fmt" "github.com/golang/protobuf/proto" + "github.com/lib/pq" + sqlite3 "github.com/mattn/go-sqlite3" "storj.io/storj/pkg/pb" - "storj.io/storj/pkg/utils" dbx "storj.io/storj/satellite/satellitedb/dbx" "storj.io/storj/storage" ) @@ -31,38 +34,53 @@ func (r *repairQueue) Enqueue(ctx context.Context, seg *pb.InjuredSegment) error return err } -func (r *repairQueue) Dequeue(ctx context.Context) (pb.InjuredSegment, error) { - // TODO: fix out of order issue - tx, err := r.db.Open(ctx) - if err != nil { - return pb.InjuredSegment{}, Error.Wrap(err) +func (r *repairQueue) postgresDequeue(ctx context.Context) (seg pb.InjuredSegment, err error) { + err = r.db.DB.QueryRowContext(ctx, ` + DELETE FROM injuredsegments + WHERE id = ( SELECT id FROM injuredsegments ORDER BY id FOR UPDATE SKIP LOCKED LIMIT 1 ) + RETURNING info + `).Scan(&seg) + if err == sql.ErrNoRows { + err = storage.ErrEmptyQueue.New("") } + return seg, err +} - res, err := tx.First_Injuredsegment(ctx) - if err != nil { - return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(err, tx.Rollback())) - } else if res == nil { - return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(storage.ErrEmptyQueue.New(""), tx.Rollback())) +func (r *repairQueue) sqliteDequeue(ctx context.Context) (seg pb.InjuredSegment, err error) { + err = r.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error { + var id int64 + err = tx.Tx.QueryRowContext(ctx, `SELECT id, info FROM injuredsegments ORDER BY id LIMIT 1`).Scan(&id, &seg) + if err != nil { + return err + } + res, err := tx.Tx.ExecContext(ctx, r.db.Rebind(`DELETE FROM injuredsegments WHERE id = ?`), id) + if err != nil { + return err + } + count, err := res.RowsAffected() + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected 1, got %d segments deleted", count) + } + return nil + }) + if err == sql.ErrNoRows { + err = storage.ErrEmptyQueue.New("") } + return seg, err +} - deleted, err := tx.Delete_Injuredsegment_By_Id( - ctx, - dbx.Injuredsegment_Id(res.Id), - ) - if err != nil { - return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(err, tx.Rollback())) - } else if !deleted { - return pb.InjuredSegment{}, Error.Wrap(utils.CombineErrors(Error.New("Injured segment not deleted"), tx.Rollback())) +func (r *repairQueue) Dequeue(ctx context.Context) (seg pb.InjuredSegment, err error) { + switch t := r.db.DB.Driver().(type) { + case *sqlite3.SQLiteDriver: + return r.sqliteDequeue(ctx) + case *pq.Driver: + return r.postgresDequeue(ctx) + default: + return seg, fmt.Errorf("Unsupported database %t", t) } - if err := tx.Commit(); err != nil { - return pb.InjuredSegment{}, Error.Wrap(err) - } - - seg := &pb.InjuredSegment{} - if err = proto.Unmarshal(res.Info, seg); err != nil { - return pb.InjuredSegment{}, Error.Wrap(err) - } - return *seg, nil } func (r *repairQueue) Peekqueue(ctx context.Context, limit int) ([]pb.InjuredSegment, error) {