satellite/satellitedb/orders: add multi row upserts to process orders
Change-Id: I00d8b55ee74b443fb328bd3a4378308cefa368e4
parent 409d4123bb
commit 955abd9293
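For orientation before the hunks below: instead of writing each settled order with its own tx.ReplaceNoReturn_ReportedSerial call, ProcessOrders now assembles one multi-row upsert per settlement batch and executes it once via tx.Tx.ExecContext. Reconstructed from the diff, the generated statement has roughly this shape (a sketch, not literal output; the actual placeholder numbers depend on how many values repeat between consecutive rows, see the note after the last hunk):

	INSERT INTO reported_serials ( expires_at, storage_node_id, bucket_id, action, serial_number, settled, observed_at )
	VALUES (...), (...)   -- one 7-column tuple per order in the batch
	ON CONFLICT ( expires_at, storage_node_id, bucket_id, action, serial_number )
	DO UPDATE SET expires_at = EXCLUDED.expires_at, storage_node_id = EXCLUDED.storage_node_id,
		bucket_id = EXCLUDED.bucket_id, action = EXCLUDED.action, serial_number = EXCLUDED.serial_number,
		settled = EXCLUDED.settled, observed_at = EXCLUDED.observed_at

On CockroachDB the same VALUES section is used with the native UPSERT INTO reported_serials ( ... ) VALUES ... form, so no ON CONFLICT clause is needed.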
@@ -281,3 +281,96 @@ func buildBenchmarkData(ctx context.Context, b *testing.B, db satellite.DB, stor
	}
	return requests
}

func TestProcessOrders(t *testing.T) {
	satellitedbtest.Run(t, func(t *testing.T, db satellite.DB) {
		ctx := testcontext.New(t)
		defer ctx.Cleanup()
		ordersDB := db.Orders()
		serialNum := storj.SerialNumber{1}
		serialNum2 := storj.SerialNumber{2}
		projectID, _ := uuid.New()

		// setup: create serial number records
		err := ordersDB.CreateSerialInfo(ctx, serialNum, []byte(projectID.String()+"/b"), time.Now().AddDate(0, 0, 1))
		require.NoError(t, err)
		err = ordersDB.CreateSerialInfo(ctx, serialNum2, []byte(projectID.String()+"/c"), time.Now().AddDate(0, 0, 1))
		require.NoError(t, err)

		requests := []*orders.ProcessOrderRequest{
			{
				Order: &pb.Order{
					SerialNumber: serialNum,
					Amount:       100,
				},
				OrderLimit: &pb.OrderLimit{
					SerialNumber:    serialNum,
					StorageNodeId:   storj.NodeID{1},
					Action:          pb.PieceAction_DELETE,
					OrderExpiration: time.Now().AddDate(0, 0, 3),
				},
			},
		}
		// test: process one order and confirm we get the correct response
		actualResponses, err := ordersDB.ProcessOrders(ctx, requests)
		require.NoError(t, err)
		expectedResponses := []*orders.ProcessOrderResponse{
			{
				SerialNumber: serialNum,
				Status:       pb.SettlementResponse_ACCEPTED,
			},
		}
		assert.Equal(t, expectedResponses, actualResponses)

		requests = append(requests, &orders.ProcessOrderRequest{
			Order: &pb.Order{
				SerialNumber: serialNum2,
				Amount:       200,
			},
			OrderLimit: &pb.OrderLimit{
				SerialNumber:    serialNum2,
				StorageNodeId:   storj.NodeID{2},
				Action:          pb.PieceAction_PUT,
				OrderExpiration: time.Now().AddDate(0, 0, 1)},
		})
		// test: process two orders from different storagenodes and confirm there is an error
		_, err = ordersDB.ProcessOrders(ctx, requests)
		require.Error(t, err, "different storage nodes")

		requests[0].OrderLimit.StorageNodeId = storj.NodeID{2}
		// test: process two orders from same storagenodes and confirm we get two responses
		actualResponses, err = ordersDB.ProcessOrders(ctx, requests)
		require.NoError(t, err)
		assert.Equal(t, 2, len(actualResponses))

		// test: confirm the correct data from processing orders was written to reported_serials table
		bbr, snr, err := ordersDB.GetBillableBandwidth(ctx, time.Now().AddDate(0, 0, 3))
		require.NoError(t, err)
		assert.Equal(t, 1, len(bbr))
		expected := []orders.BucketBandwidthRollup{
			{
				ProjectID:  *projectID,
				BucketName: "c",
				Action:     pb.PieceAction_PUT,
				Inline:     0,
				Allocated:  0,
				Settled:    200,
			},
		}
		assert.Equal(t, expected, bbr)
		assert.Equal(t, 1, len(snr))
		expectedRollup := []orders.StoragenodeBandwidthRollup{
			{
				NodeID:    storj.NodeID{2},
				Action:    pb.PieceAction_PUT,
				Allocated: 0,
				Settled:   200,
			},
		}
		assert.Equal(t, expectedRollup, snr)
		bbr, snr, err = ordersDB.GetBillableBandwidth(ctx, time.Now().AddDate(0, 0, 5))
		require.NoError(t, err)
		assert.Equal(t, 2, len(bbr))
		assert.Equal(t, 3, len(snr))
	})
}
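To exercise just this test locally, an invocation along these lines should work (the package path is assumed from the commit title, and the satellitedbtest harness additionally needs test database connection settings for Postgres or CockroachDB, which are not shown here):

	go test -run TestProcessOrders ./satellite/satellitedb/...

The test also confirms that a single ProcessOrders batch is expected to come from one storage node: mixing orders for two different nodes in one call is rejected with an error.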
@@ -17,6 +17,7 @@ import (

	"storj.io/common/pb"
	"storj.io/common/storj"
	"storj.io/storj/private/dbutil"
	"storj.io/storj/private/dbutil/pgutil"
	"storj.io/storj/satellite/orders"
	"storj.io/storj/satellite/satellitedb/dbx"
@@ -234,9 +235,36 @@ func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.Proces
		return nil, Error.Wrap(err)
	}

	// perform all of the upserts into accounted serials table
	// perform all of the upserts into reported serials table
	err = db.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
		now := time.Now()
		var stmt strings.Builder
		var stmtBegin, stmtEnd string
		switch db.db.implementation {
		case dbutil.Postgres:
			stmtBegin = `INSERT INTO reported_serials ( expires_at, storage_node_id, bucket_id, action, serial_number, settled, observed_at ) VALUES `
			stmtEnd = ` ON CONFLICT ( expires_at, storage_node_id, bucket_id, action, serial_number )
				DO UPDATE SET
					expires_at = EXCLUDED.expires_at,
					storage_node_id = EXCLUDED.storage_node_id,
					bucket_id = EXCLUDED.bucket_id,
					action = EXCLUDED.action,
					serial_number = EXCLUDED.serial_number,
					settled = EXCLUDED.settled,
					observed_at = EXCLUDED.observed_at`
		case dbutil.Cockroach:
			stmtBegin = `UPSERT INTO reported_serials ( expires_at, storage_node_id, bucket_id, action, serial_number, settled, observed_at ) VALUES `
		default:
			return errs.New("invalid dbType: %v", db.db.driver)
		}

		stmt.WriteString(stmtBegin)
		var expiresAt time.Time
		var bucketID []byte
		var serialNum storj.SerialNumber
		var action pb.PieceAction
		var expiresArgNum, bucketArgNum, serialArgNum, actionArgNum int
		var args []interface{}
		args = append(args, storageNodeID.Bytes(), time.Now().UTC())

		for i, request := range requests {
			if bucketIDs[i] == nil {
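The args slice above is seeded with the storage node ID and an observed_at timestamp, which every tuple references as $1 and $2. The remaining placeholders are handed out in the loop shown in the next hunk: a value is appended to args (and given a fresh $N) only when it differs from the previous row's value, otherwise the previously bound position is reused. A minimal, self-contained sketch of that placeholder-reuse pattern, using hypothetical names rather than code from this commit:

	package main

	import "fmt"

	func main() {
		args := []interface{}{"node-id", "observed-at"} // always bound as $1 and $2
		buckets := []string{"b1", "b1", "b2"}           // consecutive duplicates reuse a placeholder

		var lastBucket string
		var bucketArgNum int
		for i, b := range buckets {
			if i == 0 || b != lastBucket {
				lastBucket = b
				args = append(args, b)
				bucketArgNum = len(args) // 1-based position, i.e. the $N to reference
			}
			fmt.Printf("row %d -> bucket bound at $%d (args now %d long)\n", i, bucketArgNum, len(args))
		}
	}

In the real code the same bookkeeping is applied to expires_at, bucket_id, action, and serial_number through expiresArgNum, bucketArgNum, actionArgNum, and serialArgNum, while the settled amount is always appended, so its placeholder is simply len(args).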
@@ -247,24 +275,50 @@ func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.Proces
				continue
			}

			// TODO: put them all in a single query?
			err = tx.ReplaceNoReturn_ReportedSerial(ctx,
				dbx.ReportedSerial_ExpiresAt(roundToNextDay(request.OrderLimit.OrderExpiration)),
				dbx.ReportedSerial_StorageNodeId(storageNodeID.Bytes()),
				dbx.ReportedSerial_BucketId(bucketIDs[i]),
				dbx.ReportedSerial_Action(uint(request.OrderLimit.Action)),
				dbx.ReportedSerial_SerialNumber(request.Order.SerialNumber.Bytes()),
				dbx.ReportedSerial_Settled(uint64(request.Order.Amount)),
				dbx.ReportedSerial_ObservedAt(now))
			if err != nil {
				return Error.Wrap(err)
			if i > 0 {
				stmt.WriteString(",")
			}
			if expiresAt != roundToNextDay(request.OrderLimit.OrderExpiration) {
				expiresAt = roundToNextDay(request.OrderLimit.OrderExpiration)
				args = append(args, expiresAt)
				expiresArgNum = len(args)
			}
			if string(bucketID) != string(bucketIDs[i]) {
				bucketID = bucketIDs[i]
				args = append(args, bucketID)
				bucketArgNum = len(args)
			}
			if action != request.OrderLimit.Action {
				action = request.OrderLimit.Action
				args = append(args, action)
				actionArgNum = len(args)
			}
			if serialNum != request.Order.SerialNumber {
				serialNum = request.Order.SerialNumber
				args = append(args, serialNum.Bytes())
				serialArgNum = len(args)
			}

			args = append(args, request.Order.Amount)
			stmt.WriteString(fmt.Sprintf(
				"($%d,$1,$%d,$%d,$%d,$%d,$2)",
				expiresArgNum,
				bucketArgNum,
				actionArgNum,
				serialArgNum,
				len(args),
			))

			responses = append(responses, &orders.ProcessOrderResponse{
				SerialNumber: request.Order.SerialNumber,
				Status:       pb.SettlementResponse_ACCEPTED,
			})
		}
		stmt.WriteString(stmtEnd)
		_, err = tx.Tx.ExecContext(ctx, stmt.String(), args...)
		if err != nil {
			return Error.Wrap(err)
		}

		return nil
	})
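A hand-worked illustration of the resulting placeholders (derived from the code above, not captured output): for a two-order batch in which the expiration day, bucket, action, and serial number all change between the rows, as in the test's final ProcessOrders call, args holds the storage node ID, the observed_at timestamp, then five values for the first row and five for the second, so the "($%d,$1,$%d,$%d,$%d,$%d,$2)" format produces

	($3,$1,$4,$5,$6,$7,$2),($8,$1,$9,$10,$11,$12,$2)

in the column order ( expires_at, storage_node_id, bucket_id, action, serial_number, settled, observed_at ). When consecutive orders share a value, for example the same bucket and action, the earlier position is reused instead of re-bound, which keeps the argument list short for large settlement batches.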