Fix the way project_id is stored in bucket_storage_tallies and bucket_bandwidth_rollups (#2283)

* fixing issues where projectID is stored as the byte representation of a UUID string, instead of bytes of the UUID

* added test for spitBucketID
This commit is contained in:
ethanadams 2019-06-21 11:38:37 -04:00 committed by GitHub
parent bfcfe39313
commit 4f2e893e68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 78 additions and 34 deletions

View File

@ -45,15 +45,13 @@ func createBucketStorageTallies(projectID uuid.UUID) (map[string]*accounting.Buc
var expectedTallies []accounting.BucketTally var expectedTallies []accounting.BucketTally
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
bucketName := fmt.Sprintf("%s%d", "testbucket", i) bucketName := fmt.Sprintf("%s%d", "testbucket", i)
bucketID := storj.JoinPaths(projectID.String(), bucketName) bucketID := storj.JoinPaths(projectID.String(), bucketName)
bucketIDComponents := storj.SplitPath(bucketID)
// Setup: The data in this tally should match the pointer that the uplink.upload created // Setup: The data in this tally should match the pointer that the uplink.upload created
tally := accounting.BucketTally{ tally := accounting.BucketTally{
BucketName: []byte(bucketIDComponents[1]), BucketName: []byte(bucketName),
ProjectID: []byte(bucketIDComponents[0]), ProjectID: projectID[:],
InlineSegments: int64(1), InlineSegments: int64(1),
RemoteSegments: int64(1), RemoteSegments: int64(1),
Files: int64(1), Files: int64(1),

View File

@ -4,7 +4,6 @@
package satellitedb package satellitedb
import ( import (
"bytes"
"context" "context"
"database/sql" "database/sql"
"sort" "sort"
@ -70,8 +69,10 @@ func (db *ordersDB) UseSerialNumber(ctx context.Context, serialNumber storj.Seri
// UpdateBucketBandwidthAllocation updates 'allocated' bandwidth for given bucket // UpdateBucketBandwidthAllocation updates 'allocated' bandwidth for given bucket
func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) { func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
pathElements := bytes.Split(bucketID, []byte("/")) projectID, bucketName, err := splitBucketID(bucketID)
bucketName, projectID := pathElements[1], pathElements[0] if err != nil {
return err
}
statement := db.db.Rebind( statement := db.db.Rebind(
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled) `INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
@ -79,7 +80,7 @@ func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketI
DO UPDATE SET allocated = bucket_bandwidth_rollups.allocated + ?`, DO UPDATE SET allocated = bucket_bandwidth_rollups.allocated + ?`,
) )
_, err = db.db.ExecContext(ctx, statement, _, err = db.db.ExecContext(ctx, statement,
bucketName, projectID, intervalStart, defaultIntervalSeconds, action, 0, uint64(amount), 0, uint64(amount), bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, 0, uint64(amount), 0, uint64(amount),
) )
if err != nil { if err != nil {
return err return err
@ -91,8 +92,10 @@ func (db *ordersDB) UpdateBucketBandwidthAllocation(ctx context.Context, bucketI
// UpdateBucketBandwidthSettle updates 'settled' bandwidth for given bucket // UpdateBucketBandwidthSettle updates 'settled' bandwidth for given bucket
func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) { func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
pathElements := bytes.Split(bucketID, []byte("/")) projectID, bucketName, err := splitBucketID(bucketID)
bucketName, projectID := pathElements[1], pathElements[0] if err != nil {
return err
}
statement := db.db.Rebind( statement := db.db.Rebind(
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled) `INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
@ -100,7 +103,7 @@ func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []
DO UPDATE SET settled = bucket_bandwidth_rollups.settled + ?`, DO UPDATE SET settled = bucket_bandwidth_rollups.settled + ?`,
) )
_, err = db.db.ExecContext(ctx, statement, _, err = db.db.ExecContext(ctx, statement,
bucketName, projectID, intervalStart, defaultIntervalSeconds, action, 0, 0, uint64(amount), uint64(amount), bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, 0, 0, uint64(amount), uint64(amount),
) )
if err != nil { if err != nil {
return err return err
@ -111,8 +114,10 @@ func (db *ordersDB) UpdateBucketBandwidthSettle(ctx context.Context, bucketID []
// UpdateBucketBandwidthInline updates 'inline' bandwidth for given bucket // UpdateBucketBandwidthInline updates 'inline' bandwidth for given bucket
func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) { func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, bucketID []byte, action pb.PieceAction, amount int64, intervalStart time.Time) (err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
pathElements := bytes.Split(bucketID, []byte("/")) projectID, bucketName, err := splitBucketID(bucketID)
bucketName, projectID := pathElements[1], pathElements[0] if err != nil {
return err
}
statement := db.db.Rebind( statement := db.db.Rebind(
`INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled) `INSERT INTO bucket_bandwidth_rollups (bucket_name, project_id, interval_start, interval_seconds, action, inline, allocated, settled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
@ -120,7 +125,7 @@ func (db *ordersDB) UpdateBucketBandwidthInline(ctx context.Context, bucketID []
DO UPDATE SET inline = bucket_bandwidth_rollups.inline + ?`, DO UPDATE SET inline = bucket_bandwidth_rollups.inline + ?`,
) )
_, err = db.db.ExecContext(ctx, statement, _, err = db.db.ExecContext(ctx, statement,
bucketName, projectID, intervalStart, defaultIntervalSeconds, action, uint64(amount), 0, 0, uint64(amount), bucketName, projectID[:], intervalStart, defaultIntervalSeconds, action, uint64(amount), 0, 0, uint64(amount),
) )
if err != nil { if err != nil {
return err return err
@ -191,11 +196,11 @@ func (db *ordersDB) UpdateStoragenodeBandwidthSettle(ctx context.Context, storag
// GetBucketBandwidth gets total bucket bandwidth from period of time // GetBucketBandwidth gets total bucket bandwidth from period of time
func (db *ordersDB) GetBucketBandwidth(ctx context.Context, bucketID []byte, from, to time.Time) (_ int64, err error) { func (db *ordersDB) GetBucketBandwidth(ctx context.Context, bucketID []byte, from, to time.Time) (_ int64, err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
pathElements := bytes.Split(bucketID, []byte("/")) projectID, bucketName, err := splitBucketID(bucketID)
bucketName, projectID := pathElements[1], pathElements[0]
var sum *int64 var sum *int64
query := `SELECT SUM(settled) FROM bucket_bandwidth_rollups WHERE bucket_name = ? AND project_id = ? AND interval_start > ? AND interval_start <= ?` query := `SELECT SUM(settled) FROM bucket_bandwidth_rollups WHERE bucket_name = ? AND project_id = ? AND interval_start > ? AND interval_start <= ?`
err = db.db.QueryRow(db.db.Rebind(query), bucketName, projectID, from, to).Scan(&sum) err = db.db.QueryRow(db.db.Rebind(query), bucketName, projectID[:], from, to).Scan(&sum)
if err == sql.ErrNoRows || sum == nil { if err == sql.ErrNoRows || sum == nil {
return 0, nil return 0, nil
} }

View File

@ -4,7 +4,6 @@
package satellitedb package satellitedb
import ( import (
"bytes"
"context" "context"
"database/sql" "database/sql"
"time" "time"
@ -14,7 +13,6 @@ import (
"storj.io/storj/internal/memory" "storj.io/storj/internal/memory"
"storj.io/storj/pkg/accounting" "storj.io/storj/pkg/accounting"
"storj.io/storj/pkg/pb" "storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
dbx "storj.io/storj/satellite/satellitedb/dbx" dbx "storj.io/storj/satellite/satellitedb/dbx"
) )
@ -33,9 +31,12 @@ func (db *ProjectAccounting) SaveTallies(ctx context.Context, intervalStart time
var result []accounting.BucketTally var result []accounting.BucketTally
for bucketID, info := range bucketTallies { for bucketID, info := range bucketTallies {
bucketIDComponents := storj.SplitPath(bucketID) pid, bn, err := splitBucketID([]byte(bucketID))
bucketName := dbx.BucketStorageTally_BucketName([]byte(bucketIDComponents[1])) if err != nil {
projectID := dbx.BucketStorageTally_ProjectId([]byte(bucketIDComponents[0])) return nil, err
}
bucketName := dbx.BucketStorageTally_BucketName(bn)
projectID := dbx.BucketStorageTally_ProjectId(pid[:])
interval := dbx.BucketStorageTally_IntervalStart(intervalStart) interval := dbx.BucketStorageTally_IntervalStart(intervalStart)
inlineBytes := dbx.BucketStorageTally_Inline(uint64(info.InlineBytes)) inlineBytes := dbx.BucketStorageTally_Inline(uint64(info.InlineBytes))
remoteBytes := dbx.BucketStorageTally_Remote(uint64(info.RemoteBytes)) remoteBytes := dbx.BucketStorageTally_Remote(uint64(info.RemoteBytes))
@ -86,11 +87,13 @@ func (db *ProjectAccounting) CreateStorageTally(ctx context.Context, tally accou
// GetAllocatedBandwidthTotal returns the sum of GET bandwidth usage allocated for a projectID for a time frame // GetAllocatedBandwidthTotal returns the sum of GET bandwidth usage allocated for a projectID for a time frame
func (db *ProjectAccounting) GetAllocatedBandwidthTotal(ctx context.Context, bucketID []byte, from time.Time) (_ int64, err error) { func (db *ProjectAccounting) GetAllocatedBandwidthTotal(ctx context.Context, bucketID []byte, from time.Time) (_ int64, err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
pathEl := bytes.Split(bucketID, []byte("/")) projectID, _, err := splitBucketID(bucketID)
_, projectID := pathEl[1], pathEl[0] if err != nil {
return 0, err
}
var sum *int64 var sum *int64
query := `SELECT SUM(allocated) FROM bucket_bandwidth_rollups WHERE project_id = ? AND action = ? AND interval_start > ?;` query := `SELECT SUM(allocated) FROM bucket_bandwidth_rollups WHERE project_id = ? AND action = ? AND interval_start > ?;`
err = db.db.QueryRow(db.db.Rebind(query), projectID, pb.PieceAction_GET, from).Scan(&sum) err = db.db.QueryRow(db.db.Rebind(query), projectID[:], pb.PieceAction_GET, from).Scan(&sum)
if err == sql.ErrNoRows || sum == nil { if err == sql.ErrNoRows || sum == nil {
return 0, nil return 0, nil
} }

View File

@ -34,7 +34,7 @@ func (db *usagerollups) GetProjectTotal(ctx context.Context, projectID uuid.UUID
WHERE project_id = ? AND interval_start >= ? AND interval_start <= ? WHERE project_id = ? AND interval_start >= ? AND interval_start <= ?
GROUP BY action`) GROUP BY action`)
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), since, before) rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], since, before)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -64,7 +64,7 @@ func (db *usagerollups) GetProjectTotal(ctx context.Context, projectID uuid.UUID
bucketsTallies := make(map[string]*[]*dbx.BucketStorageTally) bucketsTallies := make(map[string]*[]*dbx.BucketStorageTally)
for _, bucket := range buckets { for _, bucket := range buckets {
storageTallies, err := storageQuery(ctx, storageTallies, err := storageQuery(ctx,
dbx.BucketStorageTally_ProjectId([]byte(projectID.String())), dbx.BucketStorageTally_ProjectId(projectID[:]),
dbx.BucketStorageTally_BucketName([]byte(bucket)), dbx.BucketStorageTally_BucketName([]byte(bucket)),
dbx.BucketStorageTally_IntervalStart(since), dbx.BucketStorageTally_IntervalStart(since),
dbx.BucketStorageTally_IntervalStart(before)) dbx.BucketStorageTally_IntervalStart(before))
@ -124,7 +124,7 @@ func (db *usagerollups) GetBucketUsageRollups(ctx context.Context, projectID uui
} }
// get bucket_bandwidth_rollups // get bucket_bandwidth_rollups
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), []byte(bucket), since, before) rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], []byte(bucket), since, before)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -153,7 +153,7 @@ func (db *usagerollups) GetBucketUsageRollups(ctx context.Context, projectID uui
} }
bucketStorageTallies, err := storageQuery(ctx, bucketStorageTallies, err := storageQuery(ctx,
dbx.BucketStorageTally_ProjectId([]byte(projectID.String())), dbx.BucketStorageTally_ProjectId(projectID[:]),
dbx.BucketStorageTally_BucketName([]byte(bucket)), dbx.BucketStorageTally_BucketName([]byte(bucket)),
dbx.BucketStorageTally_IntervalStart(since), dbx.BucketStorageTally_IntervalStart(since),
dbx.BucketStorageTally_IntervalStart(before)) dbx.BucketStorageTally_IntervalStart(before))
@ -210,7 +210,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
countRow := db.db.QueryRowContext(ctx, countRow := db.db.QueryRowContext(ctx,
countQuery, countQuery,
[]byte(projectID.String()), projectID[:],
since, before, since, before,
search) search)
@ -234,7 +234,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
bucketRows, err := db.db.QueryContext(ctx, bucketRows, err := db.db.QueryContext(ctx,
bucketsQuery, bucketsQuery,
[]byte(projectID.String()), projectID[:],
since, before, since, before,
search, search,
page.Limit, page.Limit,
@ -277,7 +277,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
} }
// get bucket_bandwidth_rollups // get bucket_bandwidth_rollups
rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, []byte(projectID.String()), []byte(bucket), since, before) rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], []byte(bucket), since, before)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -301,7 +301,7 @@ func (db *usagerollups) GetBucketTotals(ctx context.Context, projectID uuid.UUID
bucketUsage.Egress = memory.Size(totalEgress).GB() bucketUsage.Egress = memory.Size(totalEgress).GB()
storageRow := db.db.QueryRowContext(ctx, storageQuery, []byte(projectID.String()), []byte(bucket), since, before) storageRow := db.db.QueryRowContext(ctx, storageQuery, projectID[:], []byte(bucket), since, before)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -338,7 +338,7 @@ func (db *usagerollups) getBuckets(ctx context.Context, projectID uuid.UUID, sin
FROM bucket_bandwidth_rollups FROM bucket_bandwidth_rollups
WHERE project_id = ? AND interval_start >= ? AND interval_start <= ?`) WHERE project_id = ? AND interval_start >= ? AND interval_start <= ?`)
bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, []byte(projectID.String()), since, before) bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, projectID[:], since, before)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -4,6 +4,7 @@
package satellitedb package satellitedb
import ( import (
"bytes"
"database/sql/driver" "database/sql/driver"
"github.com/skyrings/skyring-common/tools/uuid" "github.com/skyrings/skyring-common/tools/uuid"
@ -24,6 +25,17 @@ func bytesToUUID(data []byte) (uuid.UUID, error) {
return id, nil return id, nil
} }
// splitBucketID takes a bucketID, splits on /, and returns a projectID and bucketName
func splitBucketID(bucketID []byte) (projectID *uuid.UUID, bucketName []byte, err error) {
pathElements := bytes.Split(bucketID, []byte("/"))
bucketName = pathElements[1]
projectID, err = uuid.Parse(string(pathElements[0]))
if err != nil {
return nil, nil, err
}
return projectID, bucketName, nil
}
type postgresNodeIDList storj.NodeIDList type postgresNodeIDList storj.NodeIDList
// Value converts a NodeIDList to a postgres array // Value converts a NodeIDList to a postgres array

View File

@ -36,6 +36,32 @@ func TestBytesToUUID(t *testing.T) {
}) })
} }
func TestSpliteBucketID(t *testing.T) {
t.Run("Invalid input", func(t *testing.T) {
str := "not UUID string/bucket1"
bytes := []byte(str)
_, _, err := splitBucketID(bytes)
assert.NotNil(t, err)
assert.Error(t, err)
})
t.Run("Valid input", func(t *testing.T) {
expectedBucketID, err := uuid.Parse("bb6218e3-4b4a-4819-abbb-fa68538e33c0")
expectedBucketName := "bucket1"
assert.NoError(t, err)
str := expectedBucketID.String() + "/" + expectedBucketName
bucketID, bucketName, err := splitBucketID([]byte(str))
assert.NoError(t, err)
assert.Equal(t, bucketID, expectedBucketID)
assert.Equal(t, bucketName, []byte(expectedBucketName))
})
}
func TestPostgresNodeIDsArray(t *testing.T) { func TestPostgresNodeIDsArray(t *testing.T) {
ids := make(storj.NodeIDList, 10) ids := make(storj.NodeIDList, 10)
for i := range ids { for i := range ids {