storj/satellite/satellitedb/orders.go
paul cannon a0a94a9ac7 satellite/satellitedb: insert into reported_serials w/ arrays
Change-Id: Icb682de09ded3e3159e3590594dcf13f2e7f40f0
2020-01-24 18:36:21 -06:00

575 lines
19 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package satellitedb
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/lib/pq"
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"go.uber.org/zap"
"storj.io/common/pb"
"storj.io/common/storj"
"storj.io/storj/private/dbutil"
"storj.io/storj/private/dbutil/pgutil"
"storj.io/storj/satellite/orders"
"storj.io/storj/satellite/satellitedb/dbx"
)
const defaultIntervalSeconds = int(time.Hour / time.Second)
var (
// ErrDifferentStorageNodes is returned when ProcessOrders gets orders from different storage nodes.
ErrDifferentStorageNodes = errs.Class("different storage nodes")
)
type ordersDB struct {
db *satelliteDB
reportedRollupsReadBatchSize int
}
// CreateSerialInfo creates serial number entry in database
func (db *ordersDB) CreateSerialInfo(ctx context.Context, serialNumber storj.SerialNumber, bucketID []byte, limitExpiration time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
return db.db.CreateNoReturn_SerialNumber(
ctx,
dbx.SerialNumber_SerialNumber(serialNumber.Bytes()),
dbx.SerialNumber_BucketId(bucketID),
dbx.SerialNumber_ExpiresAt(limitExpiration),
)
}
// DeleteExpiredSerials deletes all expired serials in serial_number and used_serials table.
func (db *ordersDB) DeleteExpiredSerials(ctx context.Context, now time.Time) (_ int, err error) {
defer mon.Task()(&ctx)(&err)
count, err := db.db.Delete_SerialNumber_By_ExpiresAt_LessOrEqual(ctx, dbx.SerialNumber_ExpiresAt(now))
if err != nil {
return 0, err
}
return int(count), nil
}
// UseSerialNumber creates serial number entry in database
func (db *ordersDB) UseSerialNumber(ctx context.Context, serialNumber storj.SerialNumber, storageNodeID storj.NodeID) (_ []byte, err error) {
defer mon.Task()(&ctx)(&err)
statement := db.db.Rebind(
`INSERT INTO used_serials (serial_number_id, storage_node_id)
SELECT id, ? FROM serial_numbers WHERE serial_number = ?`,
)
_, err = db.db.ExecContext(ctx, statement, storageNodeID.Bytes(), serialNumber.Bytes())
if err != nil {
if pgutil.IsConstraintError(err) {
return nil, orders.ErrUsingSerialNumber.New("serial number already used")
}
return nil, err
}
dbxSerialNumber, err := db.db.Find_SerialNumber_By_SerialNumber(
ctx,
dbx.SerialNumber_SerialNumber(serialNumber.Bytes()),
)
if err != nil {
return nil, err
}
if dbxSerialNumber == nil {
return nil, orders.ErrUsingSerialNumber.New("serial number not found")
}
return dbxSerialNumber.BucketId, nil
}
// UpdateBucketBandwidthAllocation updates 'allocated' bandwidth for given bucket
func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, projectID uuid.UUID, bucketName []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
statement := db.db.Rebind(
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(bucket_name, project_id, interval_start, action)
DO UPDATE SET allocated = bucket_bandwidth_rollups.allocated + ?`,
)
_, err = db.db.ExecContext(ctx, statement,
bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, 0, uint64(amount), 0, uint64(amount),
)
if err != nil {
return err
}
return nil
}
// UpdateBucketBandwidthSettle updates 'settled' bandwidth for given bucket
func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, projectID uuid.UUID, bucketName []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
statement := db.db.Rebind(
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(bucket_name, project_id, interval_start, action)
DO UPDATE SET settled = bucket_bandwidth_rollups.settled + ?`,
)
_, err = db.db.ExecContext(ctx, statement,
bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, 0, 0, uint64(amount), uint64(amount),
)
if err != nil {
return err
}
return nil
}
// UpdateBucketBandwidthInline updates 'inline' bandwidth for given bucket
func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, projectID uuid.UUID, bucketName []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
statement := db.db.Rebind(
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(bucket_name, project_id, interval_start, action)
DO UPDATE SET inline = bucket_bandwidth_rollups.inline + ?`,
)
_, err = db.db.ExecContext(ctx, statement,
bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, uint64(amount), 0, 0, uint64(amount),
)
if err != nil {
return err
}
return nil
}
// UpdateStoragenodeBandwidthSettle updates 'settled' bandwidth for given storage node for the given intervalStart time
func (db *ordersDB) UpdateStoragenodeBandwidthSettle(ctx context.Context, storageNode storj.NodeID, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
statement := db.db.Rebind(
`INSERT INTO storagenode_bandwidth_rollups (storagenode_id, interval_start, interval_seconds, action, settled)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(storagenode_id, interval_start, action)
DO UPDATE SET settled = storagenode_bandwidth_rollups.settled + ?`,
)
_, err = db.db.ExecContext(ctx, statement,
storageNode.Bytes(), intervalStart, defaultIntervalSeconds, action, uint64(amount), uint64(amount),
)
if err != nil {
return err
}
return nil
}
// GetBucketBandwidth gets total bucket bandwidth from period of time
func (db *ordersDB) GetBucketBandwidth(ctx context.Context, projectID uuid.UUID, bucketName []byte, from, to time.Time) (_ int64, err error) {
defer mon.Task()(&ctx)(&err)
var sum *int64
query := `SELECT SUM(settled) FROM bucket_bandwidth_rollups WHERE bucket_name = ? AND project_id = ? AND interval_start > ? AND interval_start <= ?`
err = db.db.QueryRow(ctx, db.db.Rebind(query), bucketName, projectID[:], from, to).Scan(&sum)
if err == sql.ErrNoRows || sum == nil {
return 0, nil
}
return *sum, Error.Wrap(err)
}
// GetStorageNodeBandwidth gets total storage node bandwidth from period of time
func (db *ordersDB) GetStorageNodeBandwidth(ctx context.Context, nodeID storj.NodeID, from, to time.Time) (_ int64, err error) {
defer mon.Task()(&ctx)(&err)
var sum *int64
query := `SELECT SUM(settled) FROM storagenode_bandwidth_rollups WHERE storagenode_id = ? AND interval_start > ? AND interval_start <= ?`
err = db.db.QueryRow(ctx, db.db.Rebind(query), nodeID.Bytes(), from, to).Scan(&sum)
if err == sql.ErrNoRows || sum == nil {
return 0, nil
}
return *sum, err
}
// UnuseSerialNumber removes pair serial number -> storage node id from database
func (db *ordersDB) UnuseSerialNumber(ctx context.Context, serialNumber storj.SerialNumber, storageNodeID storj.NodeID) (err error) {
defer mon.Task()(&ctx)(&err)
statement := `DELETE FROM used_serials WHERE storage_node_id = ? AND
serial_number_id IN (SELECT id FROM serial_numbers WHERE serial_number = ?)`
_, err = db.db.ExecContext(ctx, db.db.Rebind(statement), storageNodeID.Bytes(), serialNumber.Bytes())
return err
}
// ProcessOrders take a list of order requests and "settles" them in one transaction.
//
// ProcessOrders requires that all orders come from the same storage node.
func (db *ordersDB) ProcessOrders(ctx context.Context, requests []*orders.ProcessOrderRequest, observedAt time.Time) (responses []*orders.ProcessOrderResponse, err error) {
defer mon.Task()(&ctx)(&err)
if len(requests) == 0 {
return nil, nil
}
// check that all requests are from the same storage node
storageNodeID := requests[0].OrderLimit.StorageNodeId
for _, req := range requests[1:] {
if req.OrderLimit.StorageNodeId != storageNodeID {
return nil, ErrDifferentStorageNodes.New("requests from different storage nodes %v and %v", storageNodeID, req.OrderLimit.StorageNodeId)
}
}
// Do a read first to get all the project id/bucket ids. We could combine this with the
// upsert below by doing a join, but there isn't really any need for special consistency
// semantics between these two queries, and it should make things easier on the database
// (particularly cockroachDB) to have the freedom to perform them separately.
//
// We don't expect the serial_number -> bucket_id relationship ever to change, as long as a
// serial_number exists. There is a possibility of a serial_number being deleted between
// this query and the next, but that is ok too (rows in reported_serials may end up having
// serial numbers that no longer exist in serial_numbers, but that shouldn't break
// anything.)
bucketIDs, err := func() (bucketIDs [][]byte, err error) {
bucketIDs = make([][]byte, len(requests))
serialNums := make([][]byte, len(requests))
for i, request := range requests {
serialNums[i] = request.Order.SerialNumber.Bytes()
}
rows, err := db.db.QueryContext(ctx, `
SELECT request.i, sn.bucket_id
FROM
serial_numbers sn,
unnest($1::bytea[]) WITH ORDINALITY AS request(serial_number, i)
WHERE request.serial_number = sn.serial_number
`, pq.ByteaArray(serialNums))
if err != nil {
return nil, Error.Wrap(err)
}
defer func() { err = errs.Combine(err, rows.Close(), rows.Err()) }()
for rows.Next() {
var index int
var bucketID []byte
err = rows.Scan(&index, &bucketID)
if err != nil {
return nil, Error.Wrap(err)
}
bucketIDs[index-1] = bucketID
}
return bucketIDs, nil
}()
if err != nil {
return nil, Error.Wrap(err)
}
// perform all of the upserts into reported serials table
expiresAtArray := make([]time.Time, 0, len(requests))
bucketIDArray := make([][]byte, 0, len(requests))
actionArray := make([]pb.PieceAction, 0, len(requests))
serialNumArray := make([][]byte, 0, len(requests))
settledArray := make([]int64, 0, len(requests))
for i, request := range requests {
if bucketIDs[i] == nil {
responses = append(responses, &orders.ProcessOrderResponse{
SerialNumber: request.Order.SerialNumber,
Status: pb.SettlementResponse_REJECTED,
})
continue
}
expiresAtArray = append(expiresAtArray, roundToNextDay(request.OrderLimit.OrderExpiration))
bucketIDArray = append(bucketIDArray, bucketIDs[i])
actionArray = append(actionArray, request.OrderLimit.Action)
serialNumCopy := request.Order.SerialNumber
serialNumArray = append(serialNumArray, serialNumCopy[:])
settledArray = append(settledArray, request.Order.Amount)
responses = append(responses, &orders.ProcessOrderResponse{
SerialNumber: request.Order.SerialNumber,
Status: pb.SettlementResponse_ACCEPTED,
})
}
var stmt string
switch db.db.implementation {
case dbutil.Postgres:
stmt = `
INSERT INTO reported_serials (
expires_at, storage_node_id, bucket_id, action, serial_number, settled, observed_at
)
SELECT unnest($1::timestamptz[]), $2::bytea, unnest($3::bytea[]), unnest($4::integer[]), unnest($5::bytea[]), unnest($6::bigint[]), $7::timestamptz
ON CONFLICT ( expires_at, storage_node_id, bucket_id, action, serial_number )
DO UPDATE SET
settled = EXCLUDED.settled,
observed_at = EXCLUDED.observed_at
`
case dbutil.Cockroach:
stmt = `
UPSERT INTO reported_serials (
expires_at, storage_node_id, bucket_id, action, serial_number, settled, observed_at
)
SELECT unnest($1::timestamptz[]), $2::bytea, unnest($3::bytea[]), unnest($4::integer[]), unnest($5::bytea[]), unnest($6::bigint[]), $7::timestamptz
`
default:
return nil, Error.New("invalid dbType: %v", db.db.driver)
}
_, err = db.db.ExecContext(ctx, stmt,
pq.Array(expiresAtArray),
storageNodeID.Bytes(),
pq.ByteaArray(bucketIDArray),
pq.Array(actionArray),
pq.ByteaArray(serialNumArray),
pq.Array(settledArray),
observedAt.UTC(),
)
if err != nil {
return nil, Error.Wrap(err)
}
return responses, nil
}
func roundToNextDay(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()).AddDate(0, 0, 1)
}
// GetBillableBandwidth gets total billable (expired consumed serial) bandwidth for nodes and buckets for all actions.
func (db *ordersDB) GetBillableBandwidth(ctx context.Context, now time.Time) (
bucketRollups []orders.BucketBandwidthRollup, storagenodeRollups []orders.StoragenodeBandwidthRollup, err error) {
defer mon.Task()(&ctx)(&err)
batchSize := db.reportedRollupsReadBatchSize
if batchSize <= 0 {
batchSize = 1000
}
type storagenodeKey struct {
nodeID storj.NodeID
action pb.PieceAction
}
byStoragenode := make(map[storagenodeKey]uint64)
type bucketKey struct {
projectID uuid.UUID
bucketName string
action pb.PieceAction
}
byBucket := make(map[bucketKey]uint64)
var token *dbx.Paged_ReportedSerial_By_ExpiresAt_LessOrEqual_Continuation
var rows []*dbx.ReportedSerial
for {
// We explicitly use a new transaction each time because we don't need the guarantees and
// because we don't want a transaction reading for 1000 years.
rows, token, err = db.db.Paged_ReportedSerial_By_ExpiresAt_LessOrEqual(ctx,
dbx.ReportedSerial_ExpiresAt(now), batchSize, token)
if err != nil {
return nil, nil, Error.Wrap(err)
}
for _, row := range rows {
nodeID, err := storj.NodeIDFromBytes(row.StorageNodeId)
if err != nil {
db.db.log.Error("bad row inserted into reported serials",
zap.Binary("storagenode_id", row.StorageNodeId))
continue
}
projectID, bucketName, err := orders.SplitBucketID(row.BucketId)
if err != nil {
db.db.log.Error("bad row inserted into reported serials",
zap.Binary("bucket_id", row.BucketId))
continue
}
action := pb.PieceAction(row.Action)
settled := row.Settled
byStoragenode[storagenodeKey{
nodeID: nodeID,
action: action,
}] += settled
byBucket[bucketKey{
projectID: *projectID,
bucketName: string(bucketName),
action: action,
}] += settled
}
if token == nil {
break
}
}
for key, settled := range byBucket {
bucketRollups = append(bucketRollups, orders.BucketBandwidthRollup{
ProjectID: key.projectID,
BucketName: key.bucketName,
Action: key.action,
Settled: int64(settled),
})
}
for key, settled := range byStoragenode {
storagenodeRollups = append(storagenodeRollups, orders.StoragenodeBandwidthRollup{
NodeID: key.nodeID,
Action: key.action,
Settled: int64(settled),
})
}
return bucketRollups, storagenodeRollups, nil
}
//
// transaction/batch methods
//
type ordersDBTx struct {
tx *dbx.Tx
log *zap.Logger
}
func (db *ordersDB) WithTransaction(ctx context.Context, cb func(ctx context.Context, tx orders.Transaction) error) (err error) {
defer mon.Task()(&ctx)(&err)
return db.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
return cb(ctx, &ordersDBTx{tx: tx, log: db.db.log})
})
}
func (tx *ordersDBTx) UpdateBucketBandwidthBatch(ctx context.Context, intervalStart time.Time, rollups []orders.BucketBandwidthRollup) (err error) {
defer mon.Task()(&ctx)(&err)
if len(rollups) == 0 {
return nil
}
orders.SortBucketBandwidthRollups(rollups)
const stmtBegin = `
INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES
`
const stmtEnd = `
ON CONFLICT(bucket_name, project_id, interval_start, action)
DO UPDATE SET
allocated = bucket_bandwidth_rollups.allocated + EXCLUDED.allocated,
inline = bucket_bandwidth_rollups.inline + EXCLUDED.inline,
settled = bucket_bandwidth_rollups.settled + EXCLUDED.settled
`
intervalStart = intervalStart.UTC()
intervalStart = time.Date(intervalStart.Year(), intervalStart.Month(), intervalStart.Day(), intervalStart.Hour(), 0, 0, 0, time.UTC)
var lastProjectID uuid.UUID
var lastBucketName string
var projectIDArgNum int
var bucketNameArgNum int
var args []interface{}
var stmt strings.Builder
stmt.WriteString(stmtBegin)
args = append(args, intervalStart)
for i, rollup := range rollups {
if i > 0 {
stmt.WriteString(",")
}
if lastProjectID != rollup.ProjectID {
lastProjectID = rollup.ProjectID
// Take the slice over a copy of the value so that we don't mutate
// the underlying value for different range iterations. :grrcox:
project := rollup.ProjectID
args = append(args, project[:])
projectIDArgNum = len(args)
}
if lastBucketName != rollup.BucketName {
lastBucketName = rollup.BucketName
args = append(args, lastBucketName)
bucketNameArgNum = len(args)
}
args = append(args, rollup.Action, rollup.Inline, rollup.Allocated, rollup.Settled)
stmt.WriteString(fmt.Sprintf(
"($%d,$%d,$1,%d,$%d,$%d,$%d,$%d)",
bucketNameArgNum,
projectIDArgNum,
defaultIntervalSeconds,
len(args)-3,
len(args)-2,
len(args)-1,
len(args),
))
}
stmt.WriteString(stmtEnd)
_, err = tx.tx.Tx.ExecContext(ctx, stmt.String(), args...)
if err != nil {
tx.log.Error("Bucket bandwidth rollup batch flush failed.", zap.Error(err))
}
return err
}
func (tx *ordersDBTx) UpdateStoragenodeBandwidthBatch(ctx context.Context, intervalStart time.Time, rollups []orders.StoragenodeBandwidthRollup) (err error) {
defer mon.Task()(&ctx)(&err)
if len(rollups) == 0 {
return nil
}
orders.SortStoragenodeBandwidthRollups(rollups)
const stmtBegin = `
INSERT INTO storagenode_bandwidth_rollups (storagenode_id, interval_start, interval_seconds, action, allocated, settled)
VALUES
`
const stmtEnd = `
ON CONFLICT(storagenode_id, interval_start, action)
DO UPDATE SET
allocated = storagenode_bandwidth_rollups.allocated + EXCLUDED.allocated,
settled = storagenode_bandwidth_rollups.settled + EXCLUDED.settled
`
intervalStart = intervalStart.UTC()
intervalStart = time.Date(intervalStart.Year(), intervalStart.Month(), intervalStart.Day(), intervalStart.Hour(), 0, 0, 0, time.UTC)
var lastNodeID storj.NodeID
var nodeIDArgNum int
var args []interface{}
var stmt strings.Builder
stmt.WriteString(stmtBegin)
args = append(args, intervalStart)
for i, rollup := range rollups {
if i > 0 {
stmt.WriteString(",")
}
if lastNodeID != rollup.NodeID {
lastNodeID = rollup.NodeID
// take the slice over rollup.ProjectID, because it is going to stay
// the same up to the ExecContext call, whereas lastProjectID is likely
// to be overwritten
args = append(args, rollup.NodeID.Bytes())
nodeIDArgNum = len(args)
}
args = append(args, rollup.Action, rollup.Allocated, rollup.Settled)
stmt.WriteString(fmt.Sprintf(
"($%d,$1,%d,$%d,$%d,$%d)",
nodeIDArgNum,
defaultIntervalSeconds,
len(args)-2,
len(args)-1,
len(args),
))
}
stmt.WriteString(stmtEnd)
_, err = tx.tx.Tx.ExecContext(ctx, stmt.String(), args...)
if err != nil {
tx.log.Error("Storagenode bandwidth rollup batch flush failed.", zap.Error(err))
}
return err
}
// DeleteExpiredReportedSerials deletes any expired reported serials as of expiredThreshold.
func (tx *ordersDBTx) DeleteExpiredReportedSerials(ctx context.Context, expiredThreshold time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
_, err = tx.tx.Delete_ReportedSerial_By_ExpiresAt_LessOrEqual(ctx,
dbx.ReportedSerial_ExpiresAt(expiredThreshold))
return Error.Wrap(err)
}