satellite/satellitedb: ensure that we process orders in order (#2950)

When concurrent transactions update the same rows in different orders, there is a potential for a deadlock; a short sketch of the failure pattern follows the changed-files summary below.
Egon Elbre 2019-09-06 17:49:30 +03:00 committed by GitHub
parent c35ad5cbfc
commit a3e0955e16
3 changed files with 144 additions and 141 deletions
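
Editorial note: the deadlock behind this change comes from inconsistent lock ordering. Within a transaction, PostgreSQL acquires row locks as each statement executes, so two concurrent settlement transactions that insert/update the same serial and rollup rows in different orders can each end up waiting on a lock the other already holds. Sorting the requests (and, later, the bucket updates) gives every transaction the same lock-acquisition order. A minimal, hypothetical sketch of the idea — the serial type and values below are illustrative, not from this commit:

package main

import (
	"fmt"
	"sort"
)

// serial stands in for storj.SerialNumber: a fixed-size byte array that is
// compared lexicographically.
type serial [2]byte

func less(a, b serial) bool {
	for i := range a {
		if a[i] != b[i] {
			return a[i] < b[i]
		}
	}
	return false
}

func main() {
	// Two concurrent settlement batches arrive with the same serials in
	// opposite order. If each transaction upserted rows in arrival order,
	// tx1 would lock A and wait for B while tx2 locks B and waits for A.
	tx1 := []serial{{0xAA, 0x01}, {0xBB, 0x02}}
	tx2 := []serial{{0xBB, 0x02}, {0xAA, 0x01}}

	// The fix: sort every batch by serial number first, so all transactions
	// acquire row locks in the same lexicographic order.
	for _, batch := range [][]serial{tx1, tx2} {
		sort.Slice(batch, func(i, k int) bool { return less(batch[i], batch[k]) })
		fmt.Println(batch)
	}
}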

View File

@@ -110,12 +110,12 @@ func (id NodeID) IsZero() bool {
 // Bytes returns raw bytes of the id
 func (id NodeID) Bytes() []byte { return id[:] }
-// Less returns whether id is smaller than b in lexiographic order
-func (id NodeID) Less(b NodeID) bool {
+// Less returns whether id is smaller than other in lexicographic order.
+func (id NodeID) Less(other NodeID) bool {
 	for k, v := range id {
-		if v < b[k] {
+		if v < other[k] {
 			return true
-		} else if v > b[k] {
+		} else if v > other[k] {
 			return false
 		}
 	}

View File

@@ -44,6 +44,18 @@ func (id SerialNumber) IsZero() bool {
 	return id == SerialNumber{}
 }
+// Less returns whether id is smaller than other in lexicographic order.
+func (id SerialNumber) Less(other SerialNumber) bool {
+	for k, v := range id {
+		if v < other[k] {
+			return true
+		} else if v > other[k] {
+			return false
+		}
+	}
+	return false
+}
+
 // String representation of the serial number
 func (id SerialNumber) String() string { return serialNumberEncoding.EncodeToString(id.Bytes()) }
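
For reference, the added Less is a plain lexicographic comparison over the fixed-size array, exactly what sort.Slice needs as an ordering function. A small self-contained check — SerialNumber here is a local stand-in for the storj type, not the real definition — showing that it agrees with bytes.Compare:

package main

import (
	"bytes"
	"fmt"
	"sort"
)

// SerialNumber is a stand-in for the storj type: a fixed-size byte array.
type SerialNumber [16]byte

// Less mirrors the method added in this commit.
func (id SerialNumber) Less(other SerialNumber) bool {
	for k, v := range id {
		if v < other[k] {
			return true
		} else if v > other[k] {
			return false
		}
	}
	return false
}

func main() {
	a := SerialNumber{1, 2, 3}
	b := SerialNumber{1, 2, 4}

	// Less agrees with bytes.Compare on the underlying slices.
	fmt.Println(a.Less(b), bytes.Compare(a[:], b[:]) < 0) // true true
	fmt.Println(b.Less(a), bytes.Compare(b[:], a[:]) < 0) // false false
	fmt.Println(a.Less(a), bytes.Compare(a[:], a[:]) < 0) // false false

	// This is the ordering ProcessOrders uses to put requests in a fixed order.
	serials := []SerialNumber{b, a}
	sort.Slice(serials, func(i, k int) bool { return serials[i].Less(serials[k]) })
	fmt.Println(serials[0] == a) // true
}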

View File

@@ -4,12 +4,10 @@
 package satellitedb
 import (
+	"bytes"
 	"context"
 	"database/sql"
-	"encoding/hex"
-	"fmt"
+	"sort"
-	"strings"
 	"time"
 	"github.com/lib/pq"
@@ -27,6 +25,11 @@ import (
 const defaultIntervalSeconds = int(time.Hour / time.Second)
+var (
+	// ErrDifferentStorageNodes is returned when ProcessOrders gets orders from different storage nodes.
+	ErrDifferentStorageNodes = errs.Class("different storage nodes")
+)
 type ordersDB struct {
 	db *dbx.DB
 }
@@ -229,7 +232,9 @@ func (db *ordersDB) UnuseSerialNumber(ctx context.Context, serialNumber storj.Se
 	return err
 }
-// ProcessOrders take a list of order requests and "settles" them in one transaction
+// ProcessOrders take a list of order requests and "settles" them in one transaction.
+//
+// ProcessOrders requires that all orders come from the same storage node.
 func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.ProcessOrderRequest) (responses []*orders.ProcessOrderResponse, err error) {
 	defer mon.Task()(&ctx)(&err)
@@ -237,6 +242,19 @@ func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.Proces
 		return nil, err
 	}
+	// check that all requests are from the same storage node
+	storageNodeID := requests[0].OrderLimit.StorageNodeId
+	for _, req := range requests[1:] {
+		if req.OrderLimit.StorageNodeId != storageNodeID {
+			return nil, ErrDifferentStorageNodes.New("requests from different storage nodes %v and %v", storageNodeID, req.OrderLimit.StorageNodeId)
+		}
+	}
+
+	// sort requests by serial number, all of them should be from the same storage node
+	sort.Slice(requests, func(i, k int) bool {
+		return requests[i].OrderLimit.SerialNumber.Less(requests[k].OrderLimit.SerialNumber)
+	})
+
 	tx, err := db.db.Begin()
 	if err != nil {
 		return nil, errs.Wrap(err)
@@ -252,169 +270,142 @@ func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.Proces
 	now := time.Now().UTC()
 	intervalStart := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
-	rejectedRequests := make(map[storj.SerialNumber]bool)
-	reject := func(serialNumber storj.SerialNumber) {
-		r := &orders.ProcessOrderResponse{
-			SerialNumber: serialNumber,
-			Status:       pb.SettlementResponse_REJECTED,
-		}
-		rejectedRequests[serialNumber] = true
-		responses = append(responses, r)
-	}
+	rejected := make(map[storj.SerialNumber]bool)
+	bucketBySerial := make(map[storj.SerialNumber][]byte)
-	// processes the insert to used serials table individually so we can handle
-	// the case where the order has already been processed. Duplicates and previously
-	// processed orders are rejected
+	// load the bucket id and insert into used serials table
 	for _, request := range requests {
-		// avoid the PG error "current transaction is aborted, commands ignored until end of transaction block" if the below insert fails due any constraint.
-		// see https://www.postgresql.org/message-id/13131805-BCBB-42DF-953B-27EE36AAF213%40yahoo.com
-		_, err = tx.Exec("savepoint sp")
-		if err != nil {
-			return nil, Error.Wrap(err)
-		}
+		row := tx.QueryRow(db.db.Rebind(`
+			SELECT id, bucket_id
+			FROM serial_numbers
+			WHERE serial_number = ?
+		`), request.OrderLimit.SerialNumber)
-		insert := "INSERT INTO used_serials (serial_number_id, storage_node_id) SELECT id, ? FROM serial_numbers WHERE serial_number = ?"
-		_, err = tx.Exec(db.db.Rebind(insert), request.OrderLimit.StorageNodeId.Bytes(), request.OrderLimit.SerialNumber.Bytes())
-		if err != nil {
-			if pgutil.IsConstraintError(err) || sqliteutil.IsConstraintError(err) {
-				reject(request.OrderLimit.SerialNumber)
-				// rollback to the savepoint before the insert failed
-				_, err = tx.Exec("rollback to savepoint sp")
-				if err != nil {
-					return nil, Error.Wrap(err)
-				}
-			} else {
-				_, rerr := tx.Exec("rollback to savepoint sp")
-				return nil, errs.Combine(Error.Wrap(err), Error.Wrap(rerr))
-			}
-		}
-		_, err = tx.Exec("release savepoint sp")
-		if err != nil {
-			return nil, Error.Wrap(err)
-		}
-	}
-	// call to get all the bucket IDs
-	query := db.buildGetBucketIdsQuery(len(requests))
-	statement := db.db.Rebind(query)
-	args := make([]interface{}, len(requests))
-	for i, request := range requests {
-		args[i] = request.OrderLimit.SerialNumber.Bytes()
-	}
-	rows, err := tx.Query(statement, args...)
-	if err != nil {
-		return nil, errs.Wrap(err)
-	}
-	bucketMap := make(map[storj.SerialNumber][]byte)
-	for rows.Next() {
-		var serialNumber, bucketID []byte
-		err := rows.Scan(&serialNumber, &bucketID)
-		if err != nil {
-			return nil, errs.Wrap(err)
-		}
-		sn, err := storj.SerialNumberFromBytes(serialNumber)
-		if err != nil {
-			return nil, errs.Wrap(err)
-		}
-		bucketMap[sn] = bucketID
-	}
-	// build all the bandwidth updates into one sql statement
-	var updateRollupStatement string
-	for _, request := range requests {
-		_, rejected := rejectedRequests[request.OrderLimit.SerialNumber]
-		if rejected {
+		var serialNumberID int64
+		var bucketID []byte
+		if err := row.Scan(&serialNumberID, &bucketID); err != nil {
+			rejected[request.OrderLimit.SerialNumber] = true
 			continue
 		}
-		bucketID, ok := bucketMap[request.OrderLimit.SerialNumber]
-		if !ok {
-			reject(request.OrderLimit.SerialNumber)
+		var result sql.Result
+		var count int64
+		// try to insert the serial number
+		result, err = tx.Exec(db.db.Rebind(`
+			INSERT INTO used_serials(serial_number_id, storage_node_id)
+			VALUES (?, ?)
+			ON CONFLICT DO NOTHING
+		`), serialNumberID, storageNodeID)
+		if err != nil {
+			return nil, Error.Wrap(err)
+		}
+		// if we didn't update any rows, then it must already exist
+		count, err = result.RowsAffected()
+		if err != nil {
+			return nil, Error.Wrap(err)
+		}
+		if count == 0 {
+			rejected[request.OrderLimit.SerialNumber] = true
 			continue
 		}
-		projectID, bucketName, err := orders.SplitBucketID(bucketID)
-		if err != nil {
-			return nil, errs.Wrap(err)
+		bucketBySerial[request.OrderLimit.SerialNumber] = bucketID
+	}
-		stmt, err := db.buildUpdateBucketBandwidthRollupStatements(request.OrderLimit, request.Order, projectID[:], bucketName, intervalStart)
-		if err != nil {
-			return nil, err
-		}
-		updateRollupStatement += stmt
-		stmt, err = db.buildUpdateStorageNodeBandwidthRollupStatements(request.OrderLimit, request.Order, intervalStart)
-		if err != nil {
-			return nil, err
-		}
-		updateRollupStatement += stmt
-	}
-	_, err = tx.Exec(updateRollupStatement)
-	if err != nil {
-		return nil, errs.Wrap(err)
-	}
+	// add up amount by action
+	var largestAction pb.PieceAction
+	amountByAction := map[pb.PieceAction]int64{}
 	for _, request := range requests {
-		_, rejected := rejectedRequests[request.OrderLimit.SerialNumber]
-		if !rejected {
-			r := &orders.ProcessOrderResponse{
+		if rejected[request.OrderLimit.SerialNumber] {
+			continue
+		}
+		limit, order := request.OrderLimit, request.Order
+		amountByAction[limit.Action] += order.Amount
+		if largestAction < limit.Action {
+			largestAction = limit.Action
+		}
+	}
+	// do action updates for storage node
+	for action := pb.PieceAction(0); action <= largestAction; action++ {
+		amount := amountByAction[action]
+		if amount == 0 {
+			continue
+		}
+		_, err := tx.Exec(db.db.Rebind(`
+			INSERT INTO storagenode_bandwidth_rollups
+				(storagenode_id, interval_start, interval_seconds, action, allocated, settled)
+			VALUES (?, ?, ?, ?, ?, ?)
+			ON CONFLICT (storagenode_id, interval_start, action)
+			DO UPDATE SET settled = storagenode_bandwidth_rollups.settled + ?
+		`), storageNodeID, intervalStart, defaultIntervalSeconds, action, 0, amount, amount)
+		if err != nil {
+			return nil, Error.Wrap(err)
+		}
+	}
+	// sort bucket updates
+	type bucketUpdate struct {
+		bucketID []byte
+		action   pb.PieceAction
+		amount   int64
+	}
+	var bucketUpdates []bucketUpdate
+	for _, request := range requests {
+		if rejected[request.OrderLimit.SerialNumber] {
+			continue
+		}
+		limit, order := request.OrderLimit, request.Order
+		bucketUpdates = append(bucketUpdates, bucketUpdate{
+			bucketID: bucketBySerial[limit.SerialNumber],
+			action:   limit.Action,
+			amount:   order.Amount,
+		})
+	}
+	sort.Slice(bucketUpdates, func(i, k int) bool {
+		compare := bytes.Compare(bucketUpdates[i].bucketID, bucketUpdates[k].bucketID)
+		if compare == 0 {
+			return bucketUpdates[i].action < bucketUpdates[k].action
+		}
+		return compare < 0
+	})
+	// do bucket updates
+	for _, update := range bucketUpdates {
+		projectID, bucketName, err := orders.SplitBucketID(update.bucketID)
+		if err != nil {
+			return nil, errs.Wrap(err)
+		}
+		_, err = tx.Exec(db.db.Rebind(`
+			INSERT INTO bucket_bandwidth_rollups
+				(bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
+			VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+			ON CONFLICT (bucket_name, project_id, interval_start, action)
+			DO UPDATE SET settled = bucket_bandwidth_rollups.settled + ?
+		`), bucketName, (*projectID)[:], intervalStart, defaultIntervalSeconds, update.action, 0, 0, update.amount, update.amount)
+		if err != nil {
+			return nil, Error.Wrap(err)
+		}
+	}
+	for _, request := range requests {
+		if !rejected[request.OrderLimit.SerialNumber] {
+			responses = append(responses, &orders.ProcessOrderResponse{
 				SerialNumber: request.OrderLimit.SerialNumber,
 				Status:       pb.SettlementResponse_ACCEPTED,
-			}
-			responses = append(responses, r)
+			})
+		} else {
+			responses = append(responses, &orders.ProcessOrderResponse{
+				SerialNumber: request.OrderLimit.SerialNumber,
+				Status:       pb.SettlementResponse_REJECTED,
+			})
 		}
 	}
 	return responses, nil
 }
-func (db *ordersDB) buildGetBucketIdsQuery(argCount int) string {
-	args := make([]string, argCount)
-	for i := 0; i < argCount; i++ {
-		args[i] = "?"
-	}
-	return fmt.Sprintf("SELECT serial_number, bucket_id FROM serial_numbers WHERE serial_number IN (%s);\n", strings.Join(args, ","))
-}
-func (db *ordersDB) buildUpdateBucketBandwidthRollupStatements(orderLimit *pb.OrderLimit, order *pb.Order, projectID []byte, bucketName []byte, intervalStart time.Time) (string, error) {
-	hexName, err := db.toHex(bucketName)
-	if err != nil {
-		return "", err
-	}
-	hexProjectID, err := db.toHex(projectID)
-	if err != nil {
-		return "", err
-	}
-	return fmt.Sprintf(`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
-		VALUES (%s, %s, '%s', %d, %d, %d, %d, %d)
-		ON CONFLICT(bucket_name, project_id, interval_start, action)
-		DO UPDATE SET settled = bucket_bandwidth_rollups.settled + %d;
-	`, hexName, hexProjectID, intervalStart.Format("2006-01-02 15:04:05+00:00"), defaultIntervalSeconds, orderLimit.Action, 0, 0, order.Amount, order.Amount), nil
-}
-func (db *ordersDB) buildUpdateStorageNodeBandwidthRollupStatements(orderLimit *pb.OrderLimit, order *pb.Order, intervalStart time.Time) (string, error) {
-	hexNodeID, err := db.toHex(orderLimit.StorageNodeId.Bytes())
-	if err != nil {
-		return "", err
-	}
-	return fmt.Sprintf(`INSERT INTO storagenode_bandwidth_rollups (storagenode_id, interval_start, interval_seconds, action, allocated, settled)
-		VALUES (%s, '%s', %d, %d, %d, %d)
-		ON CONFLICT(storagenode_id, interval_start, action)
-		DO UPDATE SET settled = storagenode_bandwidth_rollups.settled + %d;
-	`, hexNodeID, intervalStart.Format("2006-01-02 15:04:05+00:00"), defaultIntervalSeconds, orderLimit.Action, 0, order.Amount, order.Amount), nil
-}
-func (db *ordersDB) toHex(value []byte) (string, error) {
-	hexValue := hex.EncodeToString(value)
-	switch t := db.db.Driver().(type) {
-	case *sqlite3.SQLiteDriver:
-		return fmt.Sprintf("X'%v'", hexValue), nil
-	case *pq.Driver:
-		return fmt.Sprintf("decode('%v', 'hex')", hexValue), nil
-	default:
-		return "", errs.New("Unsupported DB type %q", t)
-	}
-}
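
A note on the removed helpers above: they assembled rollup SQL with fmt.Sprintf and driver-specific hex literals, and duplicate serials were detected by catching constraint errors behind a savepoint. The new code relies on parameterized upserts instead, and detects an already-used serial by checking how many rows an INSERT ... ON CONFLICT DO NOTHING touched. A rough, self-contained sketch of that pattern — function and argument names are illustrative, and the $n placeholders assume PostgreSQL (SQLite 3.24+ accepts the same ON CONFLICT syntax):

package example

import "database/sql"

// markSerialUsed inserts into used_serials and reports whether this call
// actually consumed the serial. A return of (false, nil) means the row
// already existed, i.e. the order was settled before and should be rejected.
func markSerialUsed(tx *sql.Tx, serialNumberID int64, storageNodeID []byte) (bool, error) {
	result, err := tx.Exec(`
		INSERT INTO used_serials (serial_number_id, storage_node_id)
		VALUES ($1, $2)
		ON CONFLICT DO NOTHING
	`, serialNumberID, storageNodeID)
	if err != nil {
		return false, err
	}

	// Zero affected rows means the conflict branch fired.
	count, err := result.RowsAffected()
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

The bandwidth rollups follow the same idea with ON CONFLICT ... DO UPDATE SET settled = settled + ?, which is why the string-building helpers and toHex could be deleted.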