Fix the way project_id is stored in bucket_storage_tallies and bucket_bandwidth_rollups (#2283)

* fixing issues where projectID is stored as the byte representation of a UUID string, instead of bytes of the UUID

* added test for spitBucketID
This commit is contained in:
ethanadams 2019-06-21 11:38:37 -04:00 committed by GitHub
parent bfcfe39313
commit 4f2e893e68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 78 additions and 34 deletions

View File

@ -45,15 +45,13 @@ func createBucketStorageTallies(projectID uuid.UUID) (map[string]*accounting.Buc
var expectedTallies []accounting.BucketTally
for i := 0; i < 4; i++ {
bucketName := fmt.Sprintf("%s%d", "testbucket", i)
bucketID := storj.JoinPaths(projectID.String(), bucketName)
bucketIDComponents := storj.SplitPath(bucketID)
// Setup: The data in this tally should match the pointer that the uplink.upload created
tally := accounting.BucketTally{
BucketName: []byte(bucketIDComponents[1]),
ProjectID: []byte(bucketIDComponents[0]),
BucketName: []byte(bucketName),
ProjectID: projectID[:],
InlineSegments: int64(1),
RemoteSegments: int64(1),
Files: int64(1),

View File

@ -4,7 +4,6 @@
package satellitedb
import (
"bytes"
"context"
"database/sql"
"sort"
@ -70,8 +69,10 @@ func (db *ordersDB) UseSerialNumber(ctx context.Context, serialNumber storj.Seri
// UpdateBucketBandwidthAllocation updates 'allocated' bandwidth for given bucket
func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
pathElements := bytes.Split(bucketID, []byte("/"))
bucketName, projectID := pathElements[1], pathElements[0]
projectID, bucketName, err := splitBucketID(bucketID)
if err != nil {
return err
}
statement := db.db.Rebind(
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
@ -79,7 +80,7 @@ func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketI
DO UPDATE SET allocated = bucket_bandwidth_rollups.allocated + ?`,
)
_, err = db.db.ExecContext(ctx, statement,
bucketName, projectID, intervalStart, defaultIntervalSeconds, action, 0, uint64(amount), 0, uint64(amount),
bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, 0, uint64(amount), 0, uint64(amount),
)
if err != nil {
return err
@ -91,8 +92,10 @@ func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketI
// UpdateBucketBandwidthSettle updates 'settled' bandwidth for given bucket
func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
pathElements := bytes.Split(bucketID, []byte("/"))
bucketName, projectID := pathElements[1], pathElements[0]
projectID, bucketName, err := splitBucketID(bucketID)
if err != nil {
return err
}
statement := db.db.Rebind(
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
@ -100,7 +103,7 @@ func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []
DO UPDATE SET settled = bucket_bandwidth_rollups.settled + ?`,
)
_, err = db.db.ExecContext(ctx, statement,
bucketName, projectID, intervalStart, defaultIntervalSeconds, action, 0, 0, uint64(amount), uint64(amount),
bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, 0, 0, uint64(amount), uint64(amount),
)
if err != nil {
return err
@ -111,8 +114,10 @@ func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []
// UpdateBucketBandwidthInline updates 'inline' bandwidth for given bucket
func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
pathElements := bytes.Split(bucketID, []byte("/"))
bucketName, projectID := pathElements[1], pathElements[0]
projectID, bucketName, err := splitBucketID(bucketID)
if err != nil {
return err
}
statement := db.db.Rebind(
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
@ -120,7 +125,7 @@ func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, bucketID []
DO UPDATE SET inline = bucket_bandwidth_rollups.inline + ?`,
)
_, err = db.db.ExecContext(ctx, statement,
bucketName, projectID, intervalStart, defaultIntervalSeconds, action, uint64(amount), 0, 0, uint64(amount),
bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, uint64(amount), 0, 0, uint64(amount),
)
if err != nil {
return err
@ -191,11 +196,11 @@ func (db *ordersDB) UpdateStoragenodeBandwidthSettle(ctx context.Context, storag
// GetBucketBandwidth gets total bucket bandwidth from period of time
func (db *ordersDB) GetBucketBandwidth(ctx context.Context, bucketID []byte, from, to time.Time) (_ int64, err error) {
defer mon.Task()(&ctx)(&err)
pathElements := bytes.Split(bucketID, []byte("/"))
bucketName, projectID := pathElements[1], pathElements[0]
projectID, bucketName, err := splitBucketID(bucketID)
var sum *int64
query := `SELECT SUM(settled) FROM bucket_bandwidth_rollups WHERE bucket_name = ? AND project_id = ? AND interval_start > ? AND interval_start <= ?`
err = db.db.QueryRow(db.db.Rebind(query), bucketName, projectID, from, to).Scan(&sum)
err = db.db.QueryRow(db.db.Rebind(query), bucketName, projectID[:], from, to).Scan(&sum)
if err == sql.ErrNoRows || sum == nil {
return 0, nil
}

View File

@ -4,7 +4,6 @@
package satellitedb
import (
"bytes"
"context"
"database/sql"
"time"
@ -14,7 +13,6 @@ import (
"storj.io/storj/internal/memory"
"storj.io/storj/pkg/accounting"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -33,9 +31,12 @@ func (db *ProjectAccounting) SaveTallies(ctx context.Context, intervalStart time
var result []accounting.BucketTally
for bucketID, info := range bucketTallies {
bucketIDComponents := storj.SplitPath(bucketID)
bucketName := dbx.BucketStorageTally_BucketName([]byte(bucketIDComponents[1]))
projectID := dbx.BucketStorageTally_ProjectId([]byte(bucketIDComponents[0]))
pid, bn, err := splitBucketID([]byte(bucketID))
if err != nil {
return nil, err
}
bucketName := dbx.BucketStorageTally_BucketName(bn)
projectID := dbx.BucketStorageTally_ProjectId(pid[:])
interval := dbx.BucketStorageTally_IntervalStart(intervalStart)
inlineBytes := dbx.BucketStorageTally_Inline(uint64(info.InlineBytes))
remoteBytes := dbx.BucketStorageTally_Remote(uint64(info.RemoteBytes))
@ -86,11 +87,13 @@ func (db *ProjectAccounting) CreateStorageTally(ctx context.Context, tally accou
// GetAllocatedBandwidthTotal returns the sum of GET bandwidth usage allocated for a projectID for a time frame
func (db *ProjectAccounting) GetAllocatedBandwidthTotal(ctx context.Context, bucketID []byte, from time.Time) (_ int64, err error) {
defer mon.Task()(&ctx)(&err)
pathEl := bytes.Split(bucketID, []byte("/"))
_, projectID := pathEl[1], pathEl[0]
projectID, _, err := splitBucketID(bucketID)
if err != nil {
return 0, err
}
var sum *int64
query := `SELECT SUM(allocated) FROM bucket_bandwidth_rollups WHERE project_id = ? AND action = ? AND interval_start > ?;`
err = db.db.QueryRow(db.db.Rebind(query), projectID, pb.PieceAction_GET, from).Scan(&sum)
err = db.db.QueryRow(db.db.Rebind(query), projectID[:], pb.PieceAction_GET, from).Scan(&sum)
if err == sql.ErrNoRows || sum == nil {
return 0, nil
}

View File

@ -34,7 +34,7 @@ func (db *usagerollups) GetProjectTotal(ctx context.Context, projectID uuid.UUID
WHERE project_id = ? AND interval_start >= ? AND interval_start <= ?
GROUP BY action`)
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), since, before)
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], since, before)
if err != nil {
return nil, err
}
@ -64,7 +64,7 @@ func (db *usagerollups) GetProjectTotal(ctx context.Context, projectID uuid.UUID
bucketsTallies := make(map[string]*[]*dbx.BucketStorageTally)
for _, bucket := range buckets {
storageTallies, err := storageQuery(ctx,
dbx.BucketStorageTally_ProjectId([]byte(projectID.String())),
dbx.BucketStorageTally_ProjectId(projectID[:]),
dbx.BucketStorageTally_BucketName([]byte(bucket)),
dbx.BucketStorageTally_IntervalStart(since),
dbx.BucketStorageTally_IntervalStart(before))
@ -124,7 +124,7 @@ func (db *usagerollups) GetBucketUsageRollups(ctx context.Context, projectID uui
}
// get bucket_bandwidth_rollups
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), []byte(bucket), since, before)
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], []byte(bucket), since, before)
if err != nil {
return nil, err
}
@ -153,7 +153,7 @@ func (db *usagerollups) GetBucketUsageRollups(ctx context.Context, projectID uui
}
bucketStorageTallies, err := storageQuery(ctx,
dbx.BucketStorageTally_ProjectId([]byte(projectID.String())),
dbx.BucketStorageTally_ProjectId(projectID[:]),
dbx.BucketStorageTally_BucketName([]byte(bucket)),
dbx.BucketStorageTally_IntervalStart(since),
dbx.BucketStorageTally_IntervalStart(before))
@ -210,7 +210,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
countRow := db.db.QueryRowContext(ctx,
countQuery,
[]byte(projectID.String()),
projectID[:],
since, before,
search)
@ -234,7 +234,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
bucketRows, err := db.db.QueryContext(ctx,
bucketsQuery,
[]byte(projectID.String()),
projectID[:],
since, before,
search,
page.Limit,
@ -277,7 +277,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
}
// get bucket_bandwidth_rollups
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), []byte(bucket), since, before)
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], []byte(bucket), since, before)
if err != nil {
return nil, err
}
@ -301,7 +301,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
bucketUsage.Egress = memory.Size(totalEgress).GB()
storageRow := db.db.QueryRowContext(ctx, storageQuery, []byte(projectID.String()), []byte(bucket), since, before)
storageRow := db.db.QueryRowContext(ctx, storageQuery, projectID[:], []byte(bucket), since, before)
if err != nil {
return nil, err
}
@ -338,7 +338,7 @@ func (db *usagerollups) getBuckets(ctx context.Context, projectID uuid.UUID, sin
FROM bucket_bandwidth_rollups
WHERE project_id = ? AND interval_start >= ? AND interval_start <= ?`)
bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, []byte(projectID.String()), since, before)
bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, projectID[:], since, before)
if err != nil {
return nil, err
}

View File

@ -4,6 +4,7 @@
package satellitedb
import (
"bytes"
"database/sql/driver"
"github.com/skyrings/skyring-common/tools/uuid"
@ -24,6 +25,17 @@ func bytesToUUID(data []byte) (uuid.UUID, error) {
return id, nil
}
// splitBucketID takes a bucketID, splits on /, and returns a projectID and bucketName
func splitBucketID(bucketID []byte) (projectID *uuid.UUID, bucketName []byte, err error) {
pathElements := bytes.Split(bucketID, []byte("/"))
bucketName = pathElements[1]
projectID, err = uuid.Parse(string(pathElements[0]))
if err != nil {
return nil, nil, err
}
return projectID, bucketName, nil
}
type postgresNodeIDList storj.NodeIDList
// Value converts a NodeIDList to a postgres array

View File

@ -36,6 +36,32 @@ func TestBytesToUUID(t *testing.T) {
})
}
func TestSpliteBucketID(t *testing.T) {
t.Run("Invalid input", func(t *testing.T) {
str := "not UUID string/bucket1"
bytes := []byte(str)
_, _, err := splitBucketID(bytes)
assert.NotNil(t, err)
assert.Error(t, err)
})
t.Run("Valid input", func(t *testing.T) {
expectedBucketID, err := uuid.Parse("bb6218e3-4b4a-4819-abbb-fa68538e33c0")
expectedBucketName := "bucket1"
assert.NoError(t, err)
str := expectedBucketID.String() + "/" + expectedBucketName
bucketID, bucketName, err := splitBucketID([]byte(str))
assert.NoError(t, err)
assert.Equal(t, bucketID, expectedBucketID)
assert.Equal(t, bucketName, []byte(expectedBucketName))
})
}
func TestPostgresNodeIDsArray(t *testing.T) {
ids := make(storj.NodeIDList, 10)
for i := range ids {