satellite/satellitedb: ensure that we process orders in order (#2950)
When transactions are handled in different orders, there is a potential for a deadlock.
parent c35ad5cbfc
commit a3e0955e16
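The idea behind the fix, shown here as a minimal standalone Go sketch rather than satellite code (the two sync.Mutex values stand in for database row locks, and the worker names are made up): if one settlement transaction takes its locks in the order A, B while another takes them in the order B, A, each can end up waiting on a lock the other already holds. Making every transaction acquire its locks in one agreed order, as both workers do below, removes that cycle; the commit achieves the same effect by sorting requests by serial number and bucket updates by bucket ID before touching the corresponding rows.

package main

import (
    "fmt"
    "sync"
)

func main() {
    // two locks standing in for two database rows touched by both transactions
    var rowA, rowB sync.Mutex
    var wg sync.WaitGroup

    worker := func(name string) {
        defer wg.Done()
        // both workers acquire the locks in the same fixed order: rowA, then rowB.
        // If one of them locked rowB first, the two could block each other forever.
        rowA.Lock()
        rowB.Lock()
        fmt.Println(name, "settled both rows")
        rowB.Unlock()
        rowA.Unlock()
    }

    wg.Add(2)
    go worker("worker-1")
    go worker("worker-2")
    wg.Wait()
}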
@@ -110,12 +110,12 @@ func (id NodeID) IsZero() bool {
 // Bytes returns raw bytes of the id
 func (id NodeID) Bytes() []byte { return id[:] }
 
-// Less returns whether id is smaller than b in lexiographic order
-func (id NodeID) Less(b NodeID) bool {
+// Less returns whether id is smaller than other in lexicographic order.
+func (id NodeID) Less(other NodeID) bool {
     for k, v := range id {
-        if v < b[k] {
+        if v < other[k] {
             return true
-        } else if v > b[k] {
+        } else if v > other[k] {
             return false
         }
     }
@@ -44,6 +44,18 @@ func (id SerialNumber) IsZero() bool {
     return id == SerialNumber{}
 }
 
+// Less returns whether id is smaller than other in lexicographic order.
+func (id SerialNumber) Less(other SerialNumber) bool {
+    for k, v := range id {
+        if v < other[k] {
+            return true
+        } else if v > other[k] {
+            return false
+        }
+    }
+    return false
+}
+
 // String representation of the serial number
 func (id SerialNumber) String() string { return serialNumberEncoding.EncodeToString(id.Bytes()) }
 
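As a usage sketch of the Less method added above (standalone code with a hypothetical four-byte SerialNumber type standing in for the real storj.SerialNumber), a lexicographic Less is exactly the comparison sort.Slice needs to put serial numbers into one deterministic order, which is how ProcessOrders sorts its requests further down in this change.

package main

import (
    "fmt"
    "sort"
)

// SerialNumber is a hypothetical fixed-size identifier used only for this example.
type SerialNumber [4]byte

// Less returns whether id is smaller than other in lexicographic order.
func (id SerialNumber) Less(other SerialNumber) bool {
    for k, v := range id {
        if v < other[k] {
            return true
        } else if v > other[k] {
            return false
        }
    }
    return false
}

func main() {
    serials := []SerialNumber{
        {2, 0, 0, 1},
        {1, 9, 9, 9},
        {2, 0, 0, 0},
    }

    // sort ascending using the lexicographic comparison
    sort.Slice(serials, func(i, k int) bool {
        return serials[i].Less(serials[k])
    })

    fmt.Println(serials) // [[1 9 9 9] [2 0 0 0] [2 0 0 1]]
}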
@@ -4,12 +4,10 @@
 package satellitedb
 
 import (
+    "bytes"
     "context"
     "database/sql"
-    "encoding/hex"
-    "fmt"
-    "strings"
+    "sort"
     "time"
 
     "github.com/lib/pq"
@@ -27,6 +25,11 @@ import (
 
 const defaultIntervalSeconds = int(time.Hour / time.Second)
 
+var (
+    // ErrDifferentStorageNodes is returned when ProcessOrders gets orders from different storage nodes.
+    ErrDifferentStorageNodes = errs.Class("different storage nodes")
+)
+
 type ordersDB struct {
     db *dbx.DB
 }
@@ -229,7 +232,9 @@ func (db *ordersDB) UnuseSerialNumber(ctx context.Context, serialNumber storj.Se
     return err
 }
 
-// ProcessOrders take a list of order requests and "settles" them in one transaction
+// ProcessOrders take a list of order requests and "settles" them in one transaction.
+//
+// ProcessOrders requires that all orders come from the same storage node.
 func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.ProcessOrderRequest) (responses []*orders.ProcessOrderResponse, err error) {
     defer mon.Task()(&ctx)(&err)
 
@@ -237,6 +242,19 @@ func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.Proces
         return nil, err
     }
 
+    // check that all requests are from the same storage node
+    storageNodeID := requests[0].OrderLimit.StorageNodeId
+    for _, req := range requests[1:] {
+        if req.OrderLimit.StorageNodeId != storageNodeID {
+            return nil, ErrDifferentStorageNodes.New("requests from different storage nodes %v and %v", storageNodeID, req.OrderLimit.StorageNodeId)
+        }
+    }
+
+    // sort requests by serial number, all of them should be from the same storage node
+    sort.Slice(requests, func(i, k int) bool {
+        return requests[i].OrderLimit.SerialNumber.Less(requests[k].OrderLimit.SerialNumber)
+    })
+
     tx, err := db.db.Begin()
     if err != nil {
         return nil, errs.Wrap(err)
@@ -252,169 +270,142 @@ func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.Proces
     now := time.Now().UTC()
     intervalStart := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
 
-    rejectedRequests := make(map[storj.SerialNumber]bool)
-    reject := func(serialNumber storj.SerialNumber) {
-        r := &orders.ProcessOrderResponse{
-            SerialNumber: serialNumber,
-            Status:       pb.SettlementResponse_REJECTED,
-        }
-        rejectedRequests[serialNumber] = true
-        responses = append(responses, r)
-    }
+    rejected := make(map[storj.SerialNumber]bool)
+    bucketBySerial := make(map[storj.SerialNumber][]byte)
 
-    // processes the insert to used serials table individually so we can handle
-    // the case where the order has already been processed. Duplicates and previously
-    // processed orders are rejected
+    // load the bucket id and insert into used serials table
     for _, request := range requests {
-        // avoid the PG error "current transaction is aborted, commands ignored until end of transaction block" if the below insert fails due any constraint.
-        // see https://www.postgresql.org/message-id/13131805-BCBB-42DF-953B-27EE36AAF213%40yahoo.com
-        _, err = tx.Exec("savepoint sp")
+        row := tx.QueryRow(db.db.Rebind(`
+            SELECT id, bucket_id
+            FROM serial_numbers
+            WHERE serial_number = ?
+        `), request.OrderLimit.SerialNumber)
+
+        var serialNumberID int64
+        var bucketID []byte
+        if err := row.Scan(&serialNumberID, &bucketID); err != nil {
+            rejected[request.OrderLimit.SerialNumber] = true
+            continue
+        }
+
+        var result sql.Result
+        var count int64
+
+        // try to insert the serial number
+        result, err = tx.Exec(db.db.Rebind(`
+            INSERT INTO used_serials(serial_number_id, storage_node_id)
+            VALUES (?, ?)
+            ON CONFLICT DO NOTHING
+        `), serialNumberID, storageNodeID)
         if err != nil {
             return nil, Error.Wrap(err)
         }
 
-        insert := "INSERT INTO used_serials (serial_number_id, storage_node_id) SELECT id, ? FROM serial_numbers WHERE serial_number = ?"
-
-        _, err = tx.Exec(db.db.Rebind(insert), request.OrderLimit.StorageNodeId.Bytes(), request.OrderLimit.SerialNumber.Bytes())
+        // if we didn't update any rows, then it must already exist
+        count, err = result.RowsAffected()
         if err != nil {
-            if pgutil.IsConstraintError(err) || sqliteutil.IsConstraintError(err) {
-                reject(request.OrderLimit.SerialNumber)
-                // rollback to the savepoint before the insert failed
-                _, err = tx.Exec("rollback to savepoint sp")
-                if err != nil {
-                    return nil, Error.Wrap(err)
-                }
-            } else {
-                _, rerr := tx.Exec("rollback to savepoint sp")
-                return nil, errs.Combine(Error.Wrap(err), Error.Wrap(rerr))
-            }
+            return nil, Error.Wrap(err)
         }
-        _, err = tx.Exec("release savepoint sp")
+        if count == 0 {
+            rejected[request.OrderLimit.SerialNumber] = true
+            continue
+        }
+
+        bucketBySerial[request.OrderLimit.SerialNumber] = bucketID
     }
 
+    // add up amount by action
+    var largestAction pb.PieceAction
+    amountByAction := map[pb.PieceAction]int64{}
+    for _, request := range requests {
+        if rejected[request.OrderLimit.SerialNumber] {
+            continue
+        }
+        limit, order := request.OrderLimit, request.Order
+        amountByAction[limit.Action] += order.Amount
+        if largestAction < limit.Action {
+            largestAction = limit.Action
+        }
+    }
+
+    // do action updates for storage node
+    for action := pb.PieceAction(0); action <= largestAction; action++ {
+        amount := amountByAction[action]
+        if amount == 0 {
+            continue
+        }
+
+        _, err := tx.Exec(db.db.Rebind(`
+            INSERT INTO storagenode_bandwidth_rollups
+                (storagenode_id, interval_start, interval_seconds, action, allocated, settled)
+            VALUES (?, ?, ?, ?, ?, ?)
+            ON CONFLICT (storagenode_id, interval_start, action)
+            DO UPDATE SET settled = storagenode_bandwidth_rollups.settled + ?
+        `), storageNodeID, intervalStart, defaultIntervalSeconds, action, 0, amount, amount)
+        if err != nil {
+            return nil, Error.Wrap(err)
+        }
+    }
+
-    // call to get all the bucket IDs
-    query := db.buildGetBucketIdsQuery(len(requests))
-    statement := db.db.Rebind(query)
-
-    args := make([]interface{}, len(requests))
-    for i, request := range requests {
-        args[i] = request.OrderLimit.SerialNumber.Bytes()
+    // sort bucket updates
+    type bucketUpdate struct {
+        bucketID []byte
+        action   pb.PieceAction
+        amount   int64
     }
 
-    rows, err := tx.Query(statement, args...)
-    if err != nil {
-        return nil, errs.Wrap(err)
-    }
-    bucketMap := make(map[storj.SerialNumber][]byte)
-    for rows.Next() {
-        var serialNumber, bucketID []byte
-        err := rows.Scan(&serialNumber, &bucketID)
-        if err != nil {
-            return nil, errs.Wrap(err)
-        }
-        sn, err := storj.SerialNumberFromBytes(serialNumber)
-        if err != nil {
-            return nil, errs.Wrap(err)
-        }
-        bucketMap[sn] = bucketID
-    }
-
-    // build all the bandwidth updates into one sql statement
-    var updateRollupStatement string
+    var bucketUpdates []bucketUpdate
     for _, request := range requests {
-        _, rejected := rejectedRequests[request.OrderLimit.SerialNumber]
-        if rejected {
+        if rejected[request.OrderLimit.SerialNumber] {
             continue
         }
-        bucketID, ok := bucketMap[request.OrderLimit.SerialNumber]
-        if !ok {
-            reject(request.OrderLimit.SerialNumber)
-            continue
-        }
-
-        projectID, bucketName, err := orders.SplitBucketID(bucketID)
-        if err != nil {
-            return nil, errs.Wrap(err)
-        }
-
-        stmt, err := db.buildUpdateBucketBandwidthRollupStatements(request.OrderLimit, request.Order, projectID[:], bucketName, intervalStart)
-        if err != nil {
-            return nil, err
-        }
-        updateRollupStatement += stmt
-
-        stmt, err = db.buildUpdateStorageNodeBandwidthRollupStatements(request.OrderLimit, request.Order, intervalStart)
-        if err != nil {
-            return nil, err
-        }
-        updateRollupStatement += stmt
-    }
-
-    _, err = tx.Exec(updateRollupStatement)
-    if err != nil {
-        return nil, errs.Wrap(err)
-    }
+        limit, order := request.OrderLimit, request.Order
+
+        bucketUpdates = append(bucketUpdates, bucketUpdate{
+            bucketID: bucketBySerial[limit.SerialNumber],
+            action:   limit.Action,
+            amount:   order.Amount,
+        })
+    }
+
+    sort.Slice(bucketUpdates, func(i, k int) bool {
+        compare := bytes.Compare(bucketUpdates[i].bucketID, bucketUpdates[k].bucketID)
+        if compare == 0 {
+            return bucketUpdates[i].action < bucketUpdates[k].action
+        }
+        return compare < 0
+    })
+
+    // do bucket updates
+    for _, update := range bucketUpdates {
+        projectID, bucketName, err := orders.SplitBucketID(update.bucketID)
+        if err != nil {
+            return nil, errs.Wrap(err)
+        }
+
+        _, err = tx.Exec(db.db.Rebind(`
+            INSERT INTO bucket_bandwidth_rollups
+                (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
+            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+            ON CONFLICT (bucket_name, project_id, interval_start, action)
+            DO UPDATE SET settled = bucket_bandwidth_rollups.settled + ?
+        `), bucketName, (*projectID)[:], intervalStart, defaultIntervalSeconds, update.action, 0, 0, update.amount, update.amount)
+        if err != nil {
+            return nil, Error.Wrap(err)
+        }
+    }
+
     for _, request := range requests {
-        _, rejected := rejectedRequests[request.OrderLimit.SerialNumber]
-        if !rejected {
-            r := &orders.ProcessOrderResponse{
+        if !rejected[request.OrderLimit.SerialNumber] {
+            responses = append(responses, &orders.ProcessOrderResponse{
                 SerialNumber: request.OrderLimit.SerialNumber,
                 Status:       pb.SettlementResponse_ACCEPTED,
-            }
-
-            responses = append(responses, r)
+            })
+        } else {
+            responses = append(responses, &orders.ProcessOrderResponse{
+                SerialNumber: request.OrderLimit.SerialNumber,
+                Status:       pb.SettlementResponse_REJECTED,
+            })
         }
     }
     return responses, nil
 }
 
-func (db *ordersDB) buildGetBucketIdsQuery(argCount int) string {
-    args := make([]string, argCount)
-    for i := 0; i < argCount; i++ {
-        args[i] = "?"
-    }
-    return fmt.Sprintf("SELECT serial_number, bucket_id FROM serial_numbers WHERE serial_number IN (%s);\n", strings.Join(args, ","))
-}
-
-func (db *ordersDB) buildUpdateBucketBandwidthRollupStatements(orderLimit *pb.OrderLimit, order *pb.Order, projectID []byte, bucketName []byte, intervalStart time.Time) (string, error) {
-    hexName, err := db.toHex(bucketName)
-    if err != nil {
-        return "", err
-    }
-    hexProjectID, err := db.toHex(projectID)
-    if err != nil {
-        return "", err
-    }
-    return fmt.Sprintf(`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
-        VALUES (%s, %s, '%s', %d, %d, %d, %d, %d)
-        ON CONFLICT(bucket_name, project_id, interval_start, action)
-        DO UPDATE SET settled = bucket_bandwidth_rollups.settled + %d;
-    `, hexName, hexProjectID, intervalStart.Format("2006-01-02 15:04:05+00:00"), defaultIntervalSeconds, orderLimit.Action, 0, 0, order.Amount, order.Amount), nil
-}
-
-func (db *ordersDB) buildUpdateStorageNodeBandwidthRollupStatements(orderLimit *pb.OrderLimit, order *pb.Order, intervalStart time.Time) (string, error) {
-    hexNodeID, err := db.toHex(orderLimit.StorageNodeId.Bytes())
-    if err != nil {
-        return "", err
-    }
-
-    return fmt.Sprintf(`INSERT INTO storagenode_bandwidth_rollups (storagenode_id, interval_start, interval_seconds, action, allocated, settled)
-        VALUES (%s, '%s', %d, %d, %d, %d)
-        ON CONFLICT(storagenode_id, interval_start, action)
-        DO UPDATE SET settled = storagenode_bandwidth_rollups.settled + %d;
-    `, hexNodeID, intervalStart.Format("2006-01-02 15:04:05+00:00"), defaultIntervalSeconds, orderLimit.Action, 0, order.Amount, order.Amount), nil
-}
-
-func (db *ordersDB) toHex(value []byte) (string, error) {
-    hexValue := hex.EncodeToString(value)
-    switch t := db.db.Driver().(type) {
-    case *sqlite3.SQLiteDriver:
-        return fmt.Sprintf("X'%v'", hexValue), nil
-    case *pq.Driver:
-        return fmt.Sprintf("decode('%v', 'hex')", hexValue), nil
-    default:
-        return "", errs.New("Unsupported DB type %q", t)
-    }
-}