diff --git a/pkg/accounting/db_test.go b/pkg/accounting/db_test.go index ae32233d4..9fd384934 100644 --- a/pkg/accounting/db_test.go +++ b/pkg/accounting/db_test.go @@ -45,15 +45,13 @@ func createBucketStorageTallies(projectID uuid.UUID) (map[string]*accounting.Buc var expectedTallies []accounting.BucketTally for i := 0; i < 4; i++ { - bucketName := fmt.Sprintf("%s%d", "testbucket", i) bucketID := storj.JoinPaths(projectID.String(), bucketName) - bucketIDComponents := storj.SplitPath(bucketID) // Setup: The data in this tally should match the pointer that the uplink.upload created tally := accounting.BucketTally{ - BucketName: []byte(bucketIDComponents[1]), - ProjectID: []byte(bucketIDComponents[0]), + BucketName: []byte(bucketName), + ProjectID: projectID[:], InlineSegments: int64(1), RemoteSegments: int64(1), Files: int64(1), diff --git a/satellite/satellitedb/orders.go b/satellite/satellitedb/orders.go index 2e3539a1b..4653e2931 100644 --- a/satellite/satellitedb/orders.go +++ b/satellite/satellitedb/orders.go @@ -4,7 +4,6 @@ package satellitedb import ( - "bytes" "context" "database/sql" "sort" @@ -70,8 +69,10 @@ func (db *ordersDB) UseSerialNumber(ctx context.Context, serialNumber storj.Seri // UpdateBucketBandwidthAllocation updates 'allocated' bandwidth for given bucket func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) { defer mon.Task()(&ctx)(&err) - pathElements := bytes.Split(bucketID, []byte("/")) - bucketName, projectID := pathElements[1], pathElements[0] + projectID, bucketName, err := splitBucketID(bucketID) + if err != nil { + return err + } statement := db.db.Rebind( `INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled) VALUES (?, ?, ?, ?, ?, ?, ?, ?) @@ -79,7 +80,7 @@ func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketI DO UPDATE SET allocated = bucket_bandwidth_rollups.allocated + ?`, ) _, err = db.db.ExecContext(ctx, statement, - bucketName, projectID, intervalStart, defaultIntervalSeconds, action, 0, uint64(amount), 0, uint64(amount), + bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, 0, uint64(amount), 0, uint64(amount), ) if err != nil { return err @@ -91,8 +92,10 @@ func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketI // UpdateBucketBandwidthSettle updates 'settled' bandwidth for given bucket func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) { defer mon.Task()(&ctx)(&err) - pathElements := bytes.Split(bucketID, []byte("/")) - bucketName, projectID := pathElements[1], pathElements[0] + projectID, bucketName, err := splitBucketID(bucketID) + if err != nil { + return err + } statement := db.db.Rebind( `INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled) VALUES (?, ?, ?, ?, ?, ?, ?, ?) @@ -100,7 +103,7 @@ func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID [] DO UPDATE SET settled = bucket_bandwidth_rollups.settled + ?`, ) _, err = db.db.ExecContext(ctx, statement, - bucketName, projectID, intervalStart, defaultIntervalSeconds, action, 0, 0, uint64(amount), uint64(amount), + bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, 0, 0, uint64(amount), uint64(amount), ) if err != nil { return err @@ -111,8 +114,10 @@ func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID [] // UpdateBucketBandwidthInline updates 'inline' bandwidth for given bucket func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) { defer mon.Task()(&ctx)(&err) - pathElements := bytes.Split(bucketID, []byte("/")) - bucketName, projectID := pathElements[1], pathElements[0] + projectID, bucketName, err := splitBucketID(bucketID) + if err != nil { + return err + } statement := db.db.Rebind( `INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled) VALUES (?, ?, ?, ?, ?, ?, ?, ?) @@ -120,7 +125,7 @@ func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, bucketID [] DO UPDATE SET inline = bucket_bandwidth_rollups.inline + ?`, ) _, err = db.db.ExecContext(ctx, statement, - bucketName, projectID, intervalStart, defaultIntervalSeconds, action, uint64(amount), 0, 0, uint64(amount), + bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, uint64(amount), 0, 0, uint64(amount), ) if err != nil { return err @@ -191,11 +196,11 @@ func (db *ordersDB) UpdateStoragenodeBandwidthSettle(ctx context.Context, storag // GetBucketBandwidth gets total bucket bandwidth from period of time func (db *ordersDB) GetBucketBandwidth(ctx context.Context, bucketID []byte, from, to time.Time) (_ int64, err error) { defer mon.Task()(&ctx)(&err) - pathElements := bytes.Split(bucketID, []byte("/")) - bucketName, projectID := pathElements[1], pathElements[0] + projectID, bucketName, err := splitBucketID(bucketID) + var sum *int64 query := `SELECT SUM(settled) FROM bucket_bandwidth_rollups WHERE bucket_name = ? AND project_id = ? AND interval_start > ? AND interval_start <= ?` - err = db.db.QueryRow(db.db.Rebind(query), bucketName, projectID, from, to).Scan(&sum) + err = db.db.QueryRow(db.db.Rebind(query), bucketName, projectID[:], from, to).Scan(&sum) if err == sql.ErrNoRows || sum == nil { return 0, nil } diff --git a/satellite/satellitedb/projectaccounting.go b/satellite/satellitedb/projectaccounting.go index 171344883..e2dd92be3 100644 --- a/satellite/satellitedb/projectaccounting.go +++ b/satellite/satellitedb/projectaccounting.go @@ -4,7 +4,6 @@ package satellitedb import ( - "bytes" "context" "database/sql" "time" @@ -14,7 +13,6 @@ import ( "storj.io/storj/internal/memory" "storj.io/storj/pkg/accounting" "storj.io/storj/pkg/pb" - "storj.io/storj/pkg/storj" dbx "storj.io/storj/satellite/satellitedb/dbx" ) @@ -33,9 +31,12 @@ func (db *ProjectAccounting) SaveTallies(ctx context.Context, intervalStart time var result []accounting.BucketTally for bucketID, info := range bucketTallies { - bucketIDComponents := storj.SplitPath(bucketID) - bucketName := dbx.BucketStorageTally_BucketName([]byte(bucketIDComponents[1])) - projectID := dbx.BucketStorageTally_ProjectId([]byte(bucketIDComponents[0])) + pid, bn, err := splitBucketID([]byte(bucketID)) + if err != nil { + return nil, err + } + bucketName := dbx.BucketStorageTally_BucketName(bn) + projectID := dbx.BucketStorageTally_ProjectId(pid[:]) interval := dbx.BucketStorageTally_IntervalStart(intervalStart) inlineBytes := dbx.BucketStorageTally_Inline(uint64(info.InlineBytes)) remoteBytes := dbx.BucketStorageTally_Remote(uint64(info.RemoteBytes)) @@ -86,11 +87,13 @@ func (db *ProjectAccounting) CreateStorageTally(ctx context.Context, tally accou // GetAllocatedBandwidthTotal returns the sum of GET bandwidth usage allocated for a projectID for a time frame func (db *ProjectAccounting) GetAllocatedBandwidthTotal(ctx context.Context, bucketID []byte, from time.Time) (_ int64, err error) { defer mon.Task()(&ctx)(&err) - pathEl := bytes.Split(bucketID, []byte("/")) - _, projectID := pathEl[1], pathEl[0] + projectID, _, err := splitBucketID(bucketID) + if err != nil { + return 0, err + } var sum *int64 query := `SELECT SUM(allocated) FROM bucket_bandwidth_rollups WHERE project_id = ? AND action = ? AND interval_start > ?;` - err = db.db.QueryRow(db.db.Rebind(query), projectID, pb.PieceAction_GET, from).Scan(&sum) + err = db.db.QueryRow(db.db.Rebind(query), projectID[:], pb.PieceAction_GET, from).Scan(&sum) if err == sql.ErrNoRows || sum == nil { return 0, nil } diff --git a/satellite/satellitedb/usagerollups.go b/satellite/satellitedb/usagerollups.go index 7f1ec61a3..16d1a6d7a 100644 --- a/satellite/satellitedb/usagerollups.go +++ b/satellite/satellitedb/usagerollups.go @@ -34,7 +34,7 @@ func (db *usagerollups) GetProjectTotal(ctx context.Context, projectID uuid.UUID WHERE project_id = ? AND interval_start >= ? AND interval_start <= ? GROUP BY action`) - rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), since, before) + rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], since, before) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func (db *usagerollups) GetProjectTotal(ctx context.Context, projectID uuid.UUID bucketsTallies := make(map[string]*[]*dbx.BucketStorageTally) for _, bucket := range buckets { storageTallies, err := storageQuery(ctx, - dbx.BucketStorageTally_ProjectId([]byte(projectID.String())), + dbx.BucketStorageTally_ProjectId(projectID[:]), dbx.BucketStorageTally_BucketName([]byte(bucket)), dbx.BucketStorageTally_IntervalStart(since), dbx.BucketStorageTally_IntervalStart(before)) @@ -124,7 +124,7 @@ func (db *usagerollups) GetBucketUsageRollups(ctx context.Context, projectID uui } // get bucket_bandwidth_rollups - rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), []byte(bucket), since, before) + rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], []byte(bucket), since, before) if err != nil { return nil, err } @@ -153,7 +153,7 @@ func (db *usagerollups) GetBucketUsageRollups(ctx context.Context, projectID uui } bucketStorageTallies, err := storageQuery(ctx, - dbx.BucketStorageTally_ProjectId([]byte(projectID.String())), + dbx.BucketStorageTally_ProjectId(projectID[:]), dbx.BucketStorageTally_BucketName([]byte(bucket)), dbx.BucketStorageTally_IntervalStart(since), dbx.BucketStorageTally_IntervalStart(before)) @@ -210,7 +210,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID countRow := db.db.QueryRowContext(ctx, countQuery, - []byte(projectID.String()), + projectID[:], since, before, search) @@ -234,7 +234,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, - []byte(projectID.String()), + projectID[:], since, before, search, page.Limit, @@ -277,7 +277,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID } // get bucket_bandwidth_rollups - rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), []byte(bucket), since, before) + rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], []byte(bucket), since, before) if err != nil { return nil, err } @@ -301,7 +301,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID bucketUsage.Egress = memory.Size(totalEgress).GB() - storageRow := db.db.QueryRowContext(ctx, storageQuery, []byte(projectID.String()), []byte(bucket), since, before) + storageRow := db.db.QueryRowContext(ctx, storageQuery, projectID[:], []byte(bucket), since, before) if err != nil { return nil, err } @@ -338,7 +338,7 @@ func (db *usagerollups) getBuckets(ctx context.Context, projectID uuid.UUID, sin FROM bucket_bandwidth_rollups WHERE project_id = ? AND interval_start >= ? AND interval_start <= ?`) - bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, []byte(projectID.String()), since, before) + bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, projectID[:], since, before) if err != nil { return nil, err } diff --git a/satellite/satellitedb/utils.go b/satellite/satellitedb/utils.go index 187237339..71c491ea0 100644 --- a/satellite/satellitedb/utils.go +++ b/satellite/satellitedb/utils.go @@ -4,6 +4,7 @@ package satellitedb import ( + "bytes" "database/sql/driver" "github.com/skyrings/skyring-common/tools/uuid" @@ -24,6 +25,17 @@ func bytesToUUID(data []byte) (uuid.UUID, error) { return id, nil } +// splitBucketID takes a bucketID, splits on /, and returns a projectID and bucketName +func splitBucketID(bucketID []byte) (projectID *uuid.UUID, bucketName []byte, err error) { + pathElements := bytes.Split(bucketID, []byte("/")) + bucketName = pathElements[1] + projectID, err = uuid.Parse(string(pathElements[0])) + if err != nil { + return nil, nil, err + } + return projectID, bucketName, nil +} + type postgresNodeIDList storj.NodeIDList // Value converts a NodeIDList to a postgres array diff --git a/satellite/satellitedb/utils_test.go b/satellite/satellitedb/utils_test.go index ab137dc9f..4b61962c4 100644 --- a/satellite/satellitedb/utils_test.go +++ b/satellite/satellitedb/utils_test.go @@ -36,6 +36,32 @@ func TestBytesToUUID(t *testing.T) { }) } +func TestSpliteBucketID(t *testing.T) { + t.Run("Invalid input", func(t *testing.T) { + str := "not UUID string/bucket1" + bytes := []byte(str) + + _, _, err := splitBucketID(bytes) + + assert.NotNil(t, err) + assert.Error(t, err) + }) + + t.Run("Valid input", func(t *testing.T) { + expectedBucketID, err := uuid.Parse("bb6218e3-4b4a-4819-abbb-fa68538e33c0") + expectedBucketName := "bucket1" + assert.NoError(t, err) + + str := expectedBucketID.String() + "/" + expectedBucketName + + bucketID, bucketName, err := splitBucketID([]byte(str)) + + assert.NoError(t, err) + assert.Equal(t, bucketID, expectedBucketID) + assert.Equal(t, bucketName, []byte(expectedBucketName)) + }) +} + func TestPostgresNodeIDsArray(t *testing.T) { ids := make(storj.NodeIDList, 10) for i := range ids {