Fix the way project_id is stored in bucket_storage_tallies and bucket_bandwidth_rollups (#2283)
* fixing issues where projectID is stored as the byte representation of a UUID string, instead of bytes of the UUID * added test for spitBucketID
This commit is contained in:
parent
bfcfe39313
commit
4f2e893e68
@ -45,15 +45,13 @@ func createBucketStorageTallies(projectID uuid.UUID) (map[string]*accounting.Buc
|
||||
var expectedTallies []accounting.BucketTally
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
|
||||
bucketName := fmt.Sprintf("%s%d", "testbucket", i)
|
||||
bucketID := storj.JoinPaths(projectID.String(), bucketName)
|
||||
bucketIDComponents := storj.SplitPath(bucketID)
|
||||
|
||||
// Setup: The data in this tally should match the pointer that the uplink.upload created
|
||||
tally := accounting.BucketTally{
|
||||
BucketName: []byte(bucketIDComponents[1]),
|
||||
ProjectID: []byte(bucketIDComponents[0]),
|
||||
BucketName: []byte(bucketName),
|
||||
ProjectID: projectID[:],
|
||||
InlineSegments: int64(1),
|
||||
RemoteSegments: int64(1),
|
||||
Files: int64(1),
|
||||
|
@ -4,7 +4,6 @@
|
||||
package satellitedb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"sort"
|
||||
@ -70,8 +69,10 @@ func (db *ordersDB) UseSerialNumber(ctx context.Context, serialNumber storj.Seri
|
||||
// UpdateBucketBandwidthAllocation updates 'allocated' bandwidth for given bucket
|
||||
func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
pathElements := bytes.Split(bucketID, []byte("/"))
|
||||
bucketName, projectID := pathElements[1], pathElements[0]
|
||||
projectID, bucketName, err := splitBucketID(bucketID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
statement := db.db.Rebind(
|
||||
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
@ -79,7 +80,7 @@ func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketI
|
||||
DO UPDATE SET allocated = bucket_bandwidth_rollups.allocated + ?`,
|
||||
)
|
||||
_, err = db.db.ExecContext(ctx, statement,
|
||||
bucketName, projectID, intervalStart, defaultIntervalSeconds, action, 0, uint64(amount), 0, uint64(amount),
|
||||
bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, 0, uint64(amount), 0, uint64(amount),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -91,8 +92,10 @@ func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketI
|
||||
// UpdateBucketBandwidthSettle updates 'settled' bandwidth for given bucket
|
||||
func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
pathElements := bytes.Split(bucketID, []byte("/"))
|
||||
bucketName, projectID := pathElements[1], pathElements[0]
|
||||
projectID, bucketName, err := splitBucketID(bucketID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
statement := db.db.Rebind(
|
||||
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
@ -100,7 +103,7 @@ func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []
|
||||
DO UPDATE SET settled = bucket_bandwidth_rollups.settled + ?`,
|
||||
)
|
||||
_, err = db.db.ExecContext(ctx, statement,
|
||||
bucketName, projectID, intervalStart, defaultIntervalSeconds, action, 0, 0, uint64(amount), uint64(amount),
|
||||
bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, 0, 0, uint64(amount), uint64(amount),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -111,8 +114,10 @@ func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []
|
||||
// UpdateBucketBandwidthInline updates 'inline' bandwidth for given bucket
|
||||
func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
pathElements := bytes.Split(bucketID, []byte("/"))
|
||||
bucketName, projectID := pathElements[1], pathElements[0]
|
||||
projectID, bucketName, err := splitBucketID(bucketID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
statement := db.db.Rebind(
|
||||
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
@ -120,7 +125,7 @@ func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, bucketID []
|
||||
DO UPDATE SET inline = bucket_bandwidth_rollups.inline + ?`,
|
||||
)
|
||||
_, err = db.db.ExecContext(ctx, statement,
|
||||
bucketName, projectID, intervalStart, defaultIntervalSeconds, action, uint64(amount), 0, 0, uint64(amount),
|
||||
bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, uint64(amount), 0, 0, uint64(amount),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -191,11 +196,11 @@ func (db *ordersDB) UpdateStoragenodeBandwidthSettle(ctx context.Context, storag
|
||||
// GetBucketBandwidth gets total bucket bandwidth from period of time
|
||||
func (db *ordersDB) GetBucketBandwidth(ctx context.Context, bucketID []byte, from, to time.Time) (_ int64, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
pathElements := bytes.Split(bucketID, []byte("/"))
|
||||
bucketName, projectID := pathElements[1], pathElements[0]
|
||||
projectID, bucketName, err := splitBucketID(bucketID)
|
||||
|
||||
var sum *int64
|
||||
query := `SELECT SUM(settled) FROM bucket_bandwidth_rollups WHERE bucket_name = ? AND project_id = ? AND interval_start > ? AND interval_start <= ?`
|
||||
err = db.db.QueryRow(db.db.Rebind(query), bucketName, projectID, from, to).Scan(&sum)
|
||||
err = db.db.QueryRow(db.db.Rebind(query), bucketName, projectID[:], from, to).Scan(&sum)
|
||||
if err == sql.ErrNoRows || sum == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
@ -4,7 +4,6 @@
|
||||
package satellitedb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
@ -14,7 +13,6 @@ import (
|
||||
"storj.io/storj/internal/memory"
|
||||
"storj.io/storj/pkg/accounting"
|
||||
"storj.io/storj/pkg/pb"
|
||||
"storj.io/storj/pkg/storj"
|
||||
dbx "storj.io/storj/satellite/satellitedb/dbx"
|
||||
)
|
||||
|
||||
@ -33,9 +31,12 @@ func (db *ProjectAccounting) SaveTallies(ctx context.Context, intervalStart time
|
||||
var result []accounting.BucketTally
|
||||
|
||||
for bucketID, info := range bucketTallies {
|
||||
bucketIDComponents := storj.SplitPath(bucketID)
|
||||
bucketName := dbx.BucketStorageTally_BucketName([]byte(bucketIDComponents[1]))
|
||||
projectID := dbx.BucketStorageTally_ProjectId([]byte(bucketIDComponents[0]))
|
||||
pid, bn, err := splitBucketID([]byte(bucketID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bucketName := dbx.BucketStorageTally_BucketName(bn)
|
||||
projectID := dbx.BucketStorageTally_ProjectId(pid[:])
|
||||
interval := dbx.BucketStorageTally_IntervalStart(intervalStart)
|
||||
inlineBytes := dbx.BucketStorageTally_Inline(uint64(info.InlineBytes))
|
||||
remoteBytes := dbx.BucketStorageTally_Remote(uint64(info.RemoteBytes))
|
||||
@ -86,11 +87,13 @@ func (db *ProjectAccounting) CreateStorageTally(ctx context.Context, tally accou
|
||||
// GetAllocatedBandwidthTotal returns the sum of GET bandwidth usage allocated for a projectID for a time frame
|
||||
func (db *ProjectAccounting) GetAllocatedBandwidthTotal(ctx context.Context, bucketID []byte, from time.Time) (_ int64, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
pathEl := bytes.Split(bucketID, []byte("/"))
|
||||
_, projectID := pathEl[1], pathEl[0]
|
||||
projectID, _, err := splitBucketID(bucketID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var sum *int64
|
||||
query := `SELECT SUM(allocated) FROM bucket_bandwidth_rollups WHERE project_id = ? AND action = ? AND interval_start > ?;`
|
||||
err = db.db.QueryRow(db.db.Rebind(query), projectID, pb.PieceAction_GET, from).Scan(&sum)
|
||||
err = db.db.QueryRow(db.db.Rebind(query), projectID[:], pb.PieceAction_GET, from).Scan(&sum)
|
||||
if err == sql.ErrNoRows || sum == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ func (db *usagerollups) GetProjectTotal(ctx context.Context, projectID uuid.UUID
|
||||
WHERE project_id = ? AND interval_start >= ? AND interval_start <= ?
|
||||
GROUP BY action`)
|
||||
|
||||
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), since, before)
|
||||
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], since, before)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -64,7 +64,7 @@ func (db *usagerollups) GetProjectTotal(ctx context.Context, projectID uuid.UUID
|
||||
bucketsTallies := make(map[string]*[]*dbx.BucketStorageTally)
|
||||
for _, bucket := range buckets {
|
||||
storageTallies, err := storageQuery(ctx,
|
||||
dbx.BucketStorageTally_ProjectId([]byte(projectID.String())),
|
||||
dbx.BucketStorageTally_ProjectId(projectID[:]),
|
||||
dbx.BucketStorageTally_BucketName([]byte(bucket)),
|
||||
dbx.BucketStorageTally_IntervalStart(since),
|
||||
dbx.BucketStorageTally_IntervalStart(before))
|
||||
@ -124,7 +124,7 @@ func (db *usagerollups) GetBucketUsageRollups(ctx context.Context, projectID uui
|
||||
}
|
||||
|
||||
// get bucket_bandwidth_rollups
|
||||
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), []byte(bucket), since, before)
|
||||
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], []byte(bucket), since, before)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -153,7 +153,7 @@ func (db *usagerollups) GetBucketUsageRollups(ctx context.Context, projectID uui
|
||||
}
|
||||
|
||||
bucketStorageTallies, err := storageQuery(ctx,
|
||||
dbx.BucketStorageTally_ProjectId([]byte(projectID.String())),
|
||||
dbx.BucketStorageTally_ProjectId(projectID[:]),
|
||||
dbx.BucketStorageTally_BucketName([]byte(bucket)),
|
||||
dbx.BucketStorageTally_IntervalStart(since),
|
||||
dbx.BucketStorageTally_IntervalStart(before))
|
||||
@ -210,7 +210,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
|
||||
|
||||
countRow := db.db.QueryRowContext(ctx,
|
||||
countQuery,
|
||||
[]byte(projectID.String()),
|
||||
projectID[:],
|
||||
since, before,
|
||||
search)
|
||||
|
||||
@ -234,7 +234,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
|
||||
|
||||
bucketRows, err := db.db.QueryContext(ctx,
|
||||
bucketsQuery,
|
||||
[]byte(projectID.String()),
|
||||
projectID[:],
|
||||
since, before,
|
||||
search,
|
||||
page.Limit,
|
||||
@ -277,7 +277,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
|
||||
}
|
||||
|
||||
// get bucket_bandwidth_rollups
|
||||
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), []byte(bucket), since, before)
|
||||
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], []byte(bucket), since, before)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -301,7 +301,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
|
||||
|
||||
bucketUsage.Egress = memory.Size(totalEgress).GB()
|
||||
|
||||
storageRow := db.db.QueryRowContext(ctx, storageQuery, []byte(projectID.String()), []byte(bucket), since, before)
|
||||
storageRow := db.db.QueryRowContext(ctx, storageQuery, projectID[:], []byte(bucket), since, before)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -338,7 +338,7 @@ func (db *usagerollups) getBuckets(ctx context.Context, projectID uuid.UUID, sin
|
||||
FROM bucket_bandwidth_rollups
|
||||
WHERE project_id = ? AND interval_start >= ? AND interval_start <= ?`)
|
||||
|
||||
bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, []byte(projectID.String()), since, before)
|
||||
bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, projectID[:], since, before)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
package satellitedb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/skyrings/skyring-common/tools/uuid"
|
||||
@ -24,6 +25,17 @@ func bytesToUUID(data []byte) (uuid.UUID, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// splitBucketID takes a bucketID, splits on /, and returns a projectID and bucketName
|
||||
func splitBucketID(bucketID []byte) (projectID *uuid.UUID, bucketName []byte, err error) {
|
||||
pathElements := bytes.Split(bucketID, []byte("/"))
|
||||
bucketName = pathElements[1]
|
||||
projectID, err = uuid.Parse(string(pathElements[0]))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return projectID, bucketName, nil
|
||||
}
|
||||
|
||||
type postgresNodeIDList storj.NodeIDList
|
||||
|
||||
// Value converts a NodeIDList to a postgres array
|
||||
|
@ -36,6 +36,32 @@ func TestBytesToUUID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSpliteBucketID(t *testing.T) {
|
||||
t.Run("Invalid input", func(t *testing.T) {
|
||||
str := "not UUID string/bucket1"
|
||||
bytes := []byte(str)
|
||||
|
||||
_, _, err := splitBucketID(bytes)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Valid input", func(t *testing.T) {
|
||||
expectedBucketID, err := uuid.Parse("bb6218e3-4b4a-4819-abbb-fa68538e33c0")
|
||||
expectedBucketName := "bucket1"
|
||||
assert.NoError(t, err)
|
||||
|
||||
str := expectedBucketID.String() + "/" + expectedBucketName
|
||||
|
||||
bucketID, bucketName, err := splitBucketID([]byte(str))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bucketID, expectedBucketID)
|
||||
assert.Equal(t, bucketName, []byte(expectedBucketName))
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostgresNodeIDsArray(t *testing.T) {
|
||||
ids := make(storj.NodeIDList, 10)
|
||||
for i := range ids {
|
||||
|
Loading…
Reference in New Issue
Block a user