satellite/payments/stripecoinpayments: remove usage price overriding

Project usage price overriding has been removed because it produces
incorrect results when tested. It should not be re-implemented until
the issues it causes are resolved.

Change-Id: Ic92eff374c9af4fea3bf32782a72303a7978b055
This commit is contained in:
Jeremy Wharton 2023-02-22 23:08:29 -06:00 committed by Maximillian von Briesen
parent 16d3fcde70
commit fb5997484e
8 changed files with 198 additions and 518 deletions

View File

@ -236,9 +236,6 @@ type ProjectAccounting interface {
GetProjectLimits(ctx context.Context, projectID uuid.UUID) (ProjectLimits, error)
// GetProjectTotal returns project usage summary for specified period of time.
GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (*ProjectUsage, error)
// GetProjectTotalByPartner retrieves project usage for a given period categorized by partner name.
// Unpartnered usage or usage for a partner not present in partnerNames is mapped to the empty string.
GetProjectTotalByPartner(ctx context.Context, projectID uuid.UUID, partnerNames []string, since, before time.Time) (usages map[string]ProjectUsage, err error)
// GetProjectObjectsSegments returns project objects and segments for specified period of time.
GetProjectObjectsSegments(ctx context.Context, projectID uuid.UUID) (*ProjectObjectsSegments, error)
// GetBucketUsageRollups returns usage rollup per each bucket for specified period of time.

View File

@ -3139,12 +3139,7 @@ func (payment Payments) Purchase(ctx context.Context, price int64, desc string,
func (payment Payments) GetProjectUsagePriceModel(ctx context.Context) (_ *payments.ProjectUsagePriceModel, err error) {
defer mon.Task()(&ctx)(&err)
user, err := GetUser(ctx)
if err != nil {
return nil, Error.Wrap(err)
}
model := payment.service.accounts.GetProjectUsagePriceModel(string(user.UserAgent))
model := payment.service.accounts.GetProjectUsagePriceModel()
return &model, nil
}

View File

@ -29,8 +29,8 @@ type Accounts interface {
// ProjectCharges returns how much money current user will be charged for each project.
ProjectCharges(ctx context.Context, userID uuid.UUID, since, before time.Time) ([]ProjectCharge, error)
// GetProjectUsagePriceModel returns the project usage price model for a partner name.
GetProjectUsagePriceModel(partner string) ProjectUsagePriceModel
// GetProjectUsagePriceModel returns the project usage price model.
GetProjectUsagePriceModel() ProjectUsagePriceModel
// CheckProjectInvoicingStatus returns error if for the given project there are outstanding project records and/or usage
// which have not been applied/invoiced yet (meaning sent over to stripe).

View File

@ -11,7 +11,6 @@ import (
"github.com/zeebo/errs"
"storj.io/common/uuid"
"storj.io/storj/satellite/accounting"
"storj.io/storj/satellite/payments"
)
@ -121,47 +120,28 @@ func (accounts *accounts) ProjectCharges(ctx context.Context, userID uuid.UUID,
}
for _, project := range projects {
totalUsage := accounting.ProjectUsage{Since: since, Before: before}
usages, err := accounts.service.usageDB.GetProjectTotalByPartner(ctx, project.ID, accounts.service.partnerNames, since, before)
usage, err := accounts.service.usageDB.GetProjectTotal(ctx, project.ID, since, before)
if err != nil {
return nil, Error.Wrap(err)
return charges, Error.Wrap(err)
}
var totalPrice projectUsagePrice
for partner, usage := range usages {
priceModel := accounts.GetProjectUsagePriceModel(partner)
price := accounts.service.calculateProjectUsagePrice(usage.Egress, usage.Storage, usage.SegmentCount, priceModel)
totalPrice.Egress = totalPrice.Egress.Add(price.Egress)
totalPrice.Segments = totalPrice.Segments.Add(price.Segments)
totalPrice.Storage = totalPrice.Storage.Add(price.Storage)
totalUsage.Egress += usage.Egress
totalUsage.ObjectCount += usage.ObjectCount
totalUsage.SegmentCount += usage.SegmentCount
totalUsage.Storage += usage.Storage
}
projectPrice := accounts.service.calculateProjectUsagePrice(usage.Egress, usage.Storage, usage.SegmentCount)
charges = append(charges, payments.ProjectCharge{
ProjectUsage: totalUsage,
ProjectUsage: *usage,
ProjectID: project.ID,
Egress: totalPrice.Egress.IntPart(),
SegmentCount: totalPrice.Segments.IntPart(),
StorageGbHrs: totalPrice.Storage.IntPart(),
Egress: projectPrice.Egress.IntPart(),
SegmentCount: projectPrice.Segments.IntPart(),
StorageGbHrs: projectPrice.Storage.IntPart(),
})
}
return charges, nil
}
// GetProjectUsagePriceModel returns the project usage price model for a partner name.
func (accounts *accounts) GetProjectUsagePriceModel(partner string) payments.ProjectUsagePriceModel {
if override, ok := accounts.service.usagePriceOverrides[partner]; ok {
return override
}
// GetProjectUsagePriceModel returns the project usage price model.
func (accounts *accounts) GetProjectUsagePriceModel() payments.ProjectUsagePriceModel {
return accounts.service.usagePrices
}

View File

@ -8,7 +8,6 @@ import (
"encoding/json"
"errors"
"fmt"
"sort"
"strconv"
"strings"
"time"
@ -487,12 +486,7 @@ func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName
return true, nil
}
usages, err := service.usageDB.GetProjectTotalByPartner(ctx, record.ProjectID, service.partnerNames, record.PeriodStart, record.PeriodEnd)
if err != nil {
return false, err
}
items := service.InvoiceItemsFromProjectUsage(projName, usages)
items := service.InvoiceItemsFromProjectRecord(projName, record)
for _, item := range items {
item.Currency = stripe.String(string(stripe.CurrencyUSD))
item.Customer = stripe.String(cusID)
@ -507,50 +501,28 @@ func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName
return false, nil
}
// InvoiceItemsFromProjectUsage calculates Stripe invoice item from project usage.
func (service *Service) InvoiceItemsFromProjectUsage(projName string, partnerUsages map[string]accounting.ProjectUsage) (result []*stripe.InvoiceItemParams) {
var partners []string
if len(partnerUsages) == 0 {
partners = []string{""}
partnerUsages = map[string]accounting.ProjectUsage{"": {}}
} else {
for partner := range partnerUsages {
partners = append(partners, partner)
}
sort.Strings(partners)
}
// InvoiceItemsFromProjectRecord calculates Stripe invoice item from project record.
func (service *Service) InvoiceItemsFromProjectRecord(projName string, record ProjectRecord) (result []*stripe.InvoiceItemParams) {
projectItem := &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(fmt.Sprintf("Project %s - Segment Storage (MB-Month)", projName))
projectItem.Quantity = stripe.Int64(storageMBMonthDecimal(record.Storage).IntPart())
storagePrice, _ := service.usagePrices.StorageMBMonthCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(storagePrice)
result = append(result, projectItem)
for _, partner := range partners {
usage := partnerUsages[partner]
priceModel := service.Accounts().GetProjectUsagePriceModel(partner)
prefix := "Project " + projName
if partner != "" {
prefix += " (" + partner + ")"
}
projectItem := &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(prefix + " - Segment Storage (MB-Month)")
projectItem.Quantity = stripe.Int64(storageMBMonthDecimal(usage.Storage).IntPart())
storagePrice, _ := priceModel.StorageMBMonthCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(storagePrice)
result = append(result, projectItem)
projectItem = &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(prefix + " - Egress Bandwidth (MB)")
projectItem.Quantity = stripe.Int64(egressMBDecimal(usage.Egress).IntPart())
egressPrice, _ := priceModel.EgressMBCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(egressPrice)
result = append(result, projectItem)
projectItem = &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(prefix + " - Segment Fee (Segment-Month)")
projectItem.Quantity = stripe.Int64(segmentMonthDecimal(usage.SegmentCount).IntPart())
segmentPrice, _ := priceModel.SegmentMonthCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(segmentPrice)
result = append(result, projectItem)
}
projectItem = &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(fmt.Sprintf("Project %s - Egress Bandwidth (MB)", projName))
projectItem.Quantity = stripe.Int64(egressMBDecimal(record.Egress).IntPart())
egressPrice, _ := service.usagePrices.EgressMBCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(egressPrice)
result = append(result, projectItem)
projectItem = &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(fmt.Sprintf("Project %s - Segment Fee (Segment-Month)", projName))
projectItem.Quantity = stripe.Int64(segmentMonthDecimal(record.Segments).IntPart())
segmentPrice, _ := service.usagePrices.SegmentMonthCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(segmentPrice)
result = append(result, projectItem)
service.log.Info("invoice items", zap.Any("result", result))
return result
@ -808,11 +780,11 @@ func (price projectUsagePrice) TotalInt64() int64 {
}
// calculateProjectUsagePrice calculate project usage price.
func (service *Service) calculateProjectUsagePrice(egress int64, storage, segments float64, pricing payments.ProjectUsagePriceModel) projectUsagePrice {
func (service *Service) calculateProjectUsagePrice(egress int64, storage, segments float64) projectUsagePrice {
return projectUsagePrice{
Storage: pricing.StorageMBMonthCents.Mul(storageMBMonthDecimal(storage)).Round(0),
Egress: pricing.EgressMBCents.Mul(egressMBDecimal(egress)).Round(0),
Segments: pricing.SegmentMonthCents.Mul(segmentMonthDecimal(segments)).Round(0),
Storage: service.usagePrices.StorageMBMonthCents.Mul(storageMBMonthDecimal(storage)).Round(0),
Egress: service.usagePrices.EgressMBCents.Mul(egressMBDecimal(egress)).Round(0),
Segments: service.usagePrices.SegmentMonthCents.Mul(segmentMonthDecimal(segments)).Round(0),
}
}

View File

@ -5,8 +5,6 @@ package stripecoinpayments_test
import (
"context"
"fmt"
"math"
"strconv"
"testing"
"time"
@ -18,7 +16,6 @@ import (
"storj.io/common/currency"
"storj.io/common/memory"
"storj.io/common/pb"
"storj.io/common/storj"
"storj.io/common/testcontext"
"storj.io/common/testrand"
"storj.io/common/uuid"
@ -28,9 +25,7 @@ import (
"storj.io/storj/satellite/accounting"
"storj.io/storj/satellite/console"
"storj.io/storj/satellite/metabase"
"storj.io/storj/satellite/payments"
"storj.io/storj/satellite/payments/billing"
"storj.io/storj/satellite/payments/paymentsconfig"
"storj.io/storj/satellite/payments/stripecoinpayments"
)
@ -227,106 +222,67 @@ func TestService_ProjectsWithMembers(t *testing.T) {
})
}
func TestService_InvoiceItemsFromProjectUsage(t *testing.T) {
const (
projectName = "my-project"
partnerName = "partner"
noOverridePartnerName = "no-override"
hoursPerMonth = 24 * 30
bytesPerMegabyte = int64(memory.MB / memory.B)
byteHoursPerMBMonth = hoursPerMonth * bytesPerMegabyte
)
var (
defaultPrice = paymentsconfig.ProjectUsagePrice{
StorageTB: "1",
EgressTB: "2",
Segment: "3",
}
partnerPrice = paymentsconfig.ProjectUsagePrice{
StorageTB: "4",
EgressTB: "5",
Segment: "6",
}
)
defaultModel, err := defaultPrice.ToModel()
require.NoError(t, err)
partnerModel, err := partnerPrice.ToModel()
require.NoError(t, err)
func TestService_InvoiceItemsFromProjectRecord(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
Reconfigure: testplanet.Reconfigure{
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
config.Payments.UsagePrice = defaultPrice
config.Payments.UsagePriceOverrides.SetMap(map[string]paymentsconfig.ProjectUsagePrice{
partnerName: partnerPrice,
})
},
},
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
usage := map[string]accounting.ProjectUsage{
"": {
Storage: 10000000000, // Byte-hours
Egress: 123 * memory.GB.Int64(), // Bytes
SegmentCount: 200000, // Segment-Hours
satellite := planet.Satellites[0]
// these numbers are fraction of cents, not of dollars.
expectedStoragePrice := 0.001
expectedEgressPrice := 0.0045
expectedSegmentPrice := 0.00022
type TestCase struct {
Storage float64
Egress int64
Segments float64
StorageQuantity int64
EgressQuantity int64
SegmentsQuantity int64
}
testCases := []TestCase{
{}, // all zeros
{
Storage: 10000000000, // Byte-Hours
// storage quantity is calculated to Megabyte-Months
// (10000000000 / 1000000) Byte-Hours to Megabytes-Hours
// round(10000 / 720) Megabytes-Hours to Megabyte-Months, 720 - hours in month
StorageQuantity: 14, // Megabyte-Months
},
partnerName: {
Storage: 20000000000,
Egress: 456 * memory.GB.Int64(),
SegmentCount: 400000,
{
Egress: 134 * memory.GB.Int64(), // Bytes
// egress quantity is calculated to Megabytes
// (134000000000 / 1000000) Bytes to Megabytes
EgressQuantity: 134000, // Megabytes
},
noOverridePartnerName: {
Storage: 30000000000,
Egress: 789 * memory.GB.Int64(),
SegmentCount: 600000,
{
Segments: 400000, // Segment-Hours
// object quantity is calculated to Segment-Months
// round(400000 / 720) Segment-Hours to Segment-Months, 720 - hours in month
SegmentsQuantity: 556, // Segment-Months
},
}
items := planet.Satellites[0].API.Payments.StripeService.InvoiceItemsFromProjectUsage(projectName, usage)
require.Len(t, items, len(usage)*3)
for _, tc := range testCases {
record := stripecoinpayments.ProjectRecord{
Storage: tc.Storage,
Egress: tc.Egress,
Segments: tc.Segments,
}
for i, tt := range []struct {
name string
partner string
priceModel payments.ProjectUsagePriceModel
}{
{"default pricing - no partner", "", defaultModel},
{"default pricing - no override for partner", noOverridePartnerName, defaultModel},
{"partner pricing", partnerName, partnerModel},
} {
t.Run(tt.name, func(t *testing.T) {
prefix := "Project " + projectName
if tt.partner != "" {
prefix += " (" + tt.partner + ")"
}
items := satellite.API.Payments.StripeService.InvoiceItemsFromProjectRecord("project name", record)
usage := usage[tt.partner]
expectedStorageQuantity := int64(math.Round(usage.Storage / float64(byteHoursPerMBMonth)))
expectedEgressQuantity := int64(math.Round(float64(usage.Egress) / float64(bytesPerMegabyte)))
expectedSegmentQuantity := int64(math.Round(usage.SegmentCount / hoursPerMonth))
require.Equal(t, tc.StorageQuantity, *items[0].Quantity)
require.Equal(t, expectedStoragePrice, *items[0].UnitAmountDecimal)
items := items[i*3 : (i*3)+3]
for _, item := range items {
require.NotNil(t, item)
}
require.Equal(t, tc.EgressQuantity, *items[1].Quantity)
require.Equal(t, expectedEgressPrice, *items[1].UnitAmountDecimal)
require.Equal(t, prefix+" - Segment Storage (MB-Month)", *items[0].Description)
require.Equal(t, expectedStorageQuantity, *items[0].Quantity)
storage, _ := tt.priceModel.StorageMBMonthCents.Float64()
require.Equal(t, storage, *items[0].UnitAmountDecimal)
require.Equal(t, prefix+" - Egress Bandwidth (MB)", *items[1].Description)
require.Equal(t, expectedEgressQuantity, *items[1].Quantity)
egress, _ := tt.priceModel.EgressMBCents.Float64()
require.Equal(t, egress, *items[1].UnitAmountDecimal)
require.Equal(t, prefix+" - Segment Fee (Segment-Month)", *items[2].Description)
require.Equal(t, expectedSegmentQuantity, *items[2].Quantity)
segment, _ := tt.priceModel.SegmentMonthCents.Float64()
require.Equal(t, segment, *items[2].UnitAmountDecimal)
})
require.Equal(t, tc.SegmentsQuantity, *items[2].Quantity)
require.Equal(t, expectedSegmentPrice, *items[2].UnitAmountDecimal)
}
})
}
@ -535,100 +491,6 @@ func generateProjectStorage(ctx context.Context, tb testing.TB, db satellite.DB,
require.NoError(tb, err)
}
func TestProjectUsagePrice(t *testing.T) {
var (
defaultPrice = paymentsconfig.ProjectUsagePrice{
StorageTB: "1",
EgressTB: "2",
Segment: "3",
}
partnerName = "partner"
partnerPrice = paymentsconfig.ProjectUsagePrice{
StorageTB: "4",
EgressTB: "5",
Segment: "6",
}
)
defaultModel, err := defaultPrice.ToModel()
require.NoError(t, err)
partnerModel, err := partnerPrice.ToModel()
require.NoError(t, err)
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
Reconfigure: testplanet.Reconfigure{
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
config.Payments.UsagePrice = defaultPrice
config.Payments.UsagePriceOverrides.SetMap(map[string]paymentsconfig.ProjectUsagePrice{
partnerName: partnerPrice,
})
},
},
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
sat := planet.Satellites[0]
// pick a specific date so that it doesn't fail if it's the last day of the month
// keep month + 1 because user needs to be created before calculation
period := time.Date(time.Now().Year(), time.Now().Month()+1, 20, 0, 0, 0, 0, time.UTC)
sat.API.Payments.StripeService.SetNow(func() time.Time {
return time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC)
})
for i, tt := range []struct {
name string
userAgent []byte
expectedPrice payments.ProjectUsagePriceModel
}{
{"default pricing", nil, defaultModel},
{"default pricing - user agent is not valid partner name", []byte("invalid/v0.0"), defaultModel},
{"partner pricing - user agent is partner name", []byte(partnerName), partnerModel},
{"partner pricing - user agent prefixed with partner name", []byte(partnerName + " invalid/v0.0"), partnerModel},
} {
t.Run(tt.name, func(t *testing.T) {
user, err := sat.AddUser(ctx, console.CreateUser{
FullName: "Test User",
Email: fmt.Sprintf("user%d@mail.test", i),
UserAgent: tt.userAgent,
}, 1)
require.NoError(t, err)
project, err := sat.AddProject(ctx, user.ID, "testproject")
require.NoError(t, err)
bucket, err := sat.DB.Buckets().CreateBucket(ctx, storj.Bucket{
ID: testrand.UUID(),
Name: testrand.BucketName(),
ProjectID: project.ID,
UserAgent: tt.userAgent,
})
require.NoError(t, err)
err = sat.DB.Orders().UpdateBucketBandwidthSettle(ctx, project.ID, []byte(bucket.Name),
pb.PieceAction_GET, memory.TB.Int64(), 0, period)
require.NoError(t, err)
err = sat.API.Payments.StripeService.PrepareInvoiceProjectRecords(ctx, period)
require.NoError(t, err)
err = sat.API.Payments.StripeService.InvoiceApplyProjectRecords(ctx, period)
require.NoError(t, err)
cusID, err := sat.DB.StripeCoinPayments().Customers().GetCustomerID(ctx, user.ID)
require.NoError(t, err)
items := getCustomerInvoiceItems(sat.API.Payments.StripeClient, cusID)
require.Len(t, items, 3)
storage, _ := tt.expectedPrice.StorageMBMonthCents.Float64()
require.Equal(t, storage, items[0].UnitAmountDecimal)
egress, _ := tt.expectedPrice.EgressMBCents.Float64()
require.Equal(t, egress, items[1].UnitAmountDecimal)
segment, _ := tt.expectedPrice.SegmentMonthCents.Float64()
require.Equal(t, segment, items[2].UnitAmountDecimal)
})
}
})
}
func TestPayInvoicesSkipDue(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,

View File

@ -16,7 +16,6 @@ import (
"storj.io/common/memory"
"storj.io/common/pb"
"storj.io/common/useragent"
"storj.io/common/uuid"
"storj.io/private/dbutil"
"storj.io/private/dbutil/pgutil"
@ -506,21 +505,7 @@ func (db *ProjectAccounting) GetProjectSegmentLimit(ctx context.Context, project
}
// GetProjectTotal retrieves project usage for a given period.
func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (_ *accounting.ProjectUsage, err error) {
defer mon.Task()(&ctx)(&err)
usages, err := db.GetProjectTotalByPartner(ctx, projectID, nil, since, before)
if err != nil {
return nil, err
}
if usage, ok := usages[""]; ok {
return &usage, nil
}
return &accounting.ProjectUsage{Since: since, Before: before}, nil
}
// GetProjectTotalByPartner retrieves project usage for a given period categorized by partner name.
// Unpartnered usage or usage for a partner not present in partnerNames is mapped to the empty string.
func (db *ProjectAccounting) GetProjectTotalByPartner(ctx context.Context, projectID uuid.UUID, partnerNames []string, since, before time.Time) (usages map[string]accounting.ProjectUsage, err error) {
func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (usage *accounting.ProjectUsage, err error) {
defer mon.Task()(&ctx)(&err)
since = timeTruncateDown(since)
bucketNames, err := db.getBucketsSinceAndBefore(ctx, projectID, since, before)
@ -546,54 +531,16 @@ func (db *ProjectAccounting) GetProjectTotalByPartner(ctx context.Context, proje
ORDER BY bucket_storage_tallies.interval_start DESC
`)
totalEgressQuery := db.db.Rebind(`
SELECT
COALESCE(SUM(settled) + SUM(inline), 0)
FROM
bucket_bandwidth_rollups
WHERE
bucket_name = ? AND
project_id = ? AND
interval_start >= ? AND
interval_start < ? AND
action = ?;
`)
usages = make(map[string]accounting.ProjectUsage)
bucketsTallies := make(map[string][]*accounting.BucketStorageTally)
for _, bucket := range bucketNames {
userAgentRow, err := db.db.Get_BucketMetainfo_UserAgent_By_ProjectId_And_Name(ctx,
dbx.BucketMetainfo_ProjectId(projectID[:]),
dbx.BucketMetainfo_Name([]byte(bucket)))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, err
}
var partner string
if userAgentRow != nil && userAgentRow.UserAgent != nil {
entries, err := useragent.ParseEntries(userAgentRow.UserAgent)
if err != nil {
return nil, err
}
for _, iterPartner := range partnerNames {
if entries[0].Product == iterPartner {
partner = iterPartner
break
}
}
}
if _, ok := usages[partner]; !ok {
usages[partner] = accounting.ProjectUsage{Since: since, Before: before}
}
usage := usages[partner]
storageTallies := make([]*accounting.BucketStorageTally, 0)
storageTalliesRows, err := db.db.QueryContext(ctx, storageQuery, projectID[:], []byte(bucket), since, before)
if err != nil {
return nil, err
}
var prevTally *accounting.BucketStorageTally
// generating tallies for each bucket name.
for storageTalliesRows.Next() {
tally := accounting.BucketStorageTally{}
@ -606,17 +553,8 @@ func (db *ProjectAccounting) GetProjectTotalByPartner(ctx context.Context, proje
tally.TotalBytes = inline + remote
}
if prevTally == nil {
prevTally = &tally
continue
}
hours := prevTally.IntervalStart.Sub(tally.IntervalStart).Hours()
usage.Storage += memory.Size(tally.TotalBytes).Float64() * hours
usage.SegmentCount += float64(tally.TotalSegmentCount) * hours
usage.ObjectCount += float64(tally.ObjectCount) * hours
prevTally = &tally
tally.BucketName = bucket
storageTallies = append(storageTallies, &tally)
}
err = errs.Combine(storageTalliesRows.Err(), storageTalliesRows.Close())
@ -624,21 +562,53 @@ func (db *ProjectAccounting) GetProjectTotalByPartner(ctx context.Context, proje
return nil, err
}
totalEgressRow := db.db.QueryRowContext(ctx, totalEgressQuery, []byte(bucket), projectID[:], since, before, pb.PieceAction_GET)
if err != nil {
return nil, err
}
var egress int64
if err = totalEgressRow.Scan(&egress); err != nil {
return nil, err
}
usage.Egress += egress
usages[partner] = usage
bucketsTallies[bucket] = storageTallies
}
return usages, nil
totalEgress, err := db.getTotalEgress(ctx, projectID, since, before)
if err != nil {
return nil, err
}
usage = new(accounting.ProjectUsage)
usage.Egress = memory.Size(totalEgress).Int64()
// sum up storage, objects, and segments
for _, tallies := range bucketsTallies {
for i := len(tallies) - 1; i > 0; i-- {
current := (tallies)[i]
hours := (tallies)[i-1].IntervalStart.Sub(current.IntervalStart).Hours()
usage.Storage += memory.Size(current.Bytes()).Float64() * hours
usage.SegmentCount += float64(current.TotalSegmentCount) * hours
usage.ObjectCount += float64(current.ObjectCount) * hours
}
}
usage.Since = since
usage.Before = before
return usage, nil
}
// getTotalEgress returns total egress (settled + inline) of each bucket_bandwidth_rollup
// in selected time period, project id.
// only process PieceAction_GET.
func (db *ProjectAccounting) getTotalEgress(ctx context.Context, projectID uuid.UUID, since, before time.Time) (totalEgress int64, err error) {
totalEgressQuery := db.db.Rebind(`
SELECT
COALESCE(SUM(settled) + SUM(inline), 0)
FROM
bucket_bandwidth_rollups
WHERE
project_id = ? AND
interval_start >= ? AND
interval_start < ? AND
action = ?;
`)
totalEgressRow := db.db.QueryRowContext(ctx, totalEgressQuery, projectID[:], since, before, pb.PieceAction_GET)
err = totalEgressRow.Scan(&totalEgress)
return totalEgress, err
}
// GetBucketUsageRollups retrieves summed usage rollups for every bucket of particular project for a given period.

View File

@ -11,7 +11,6 @@ import (
"storj.io/common/memory"
"storj.io/common/pb"
"storj.io/common/storj"
"storj.io/common/testcontext"
"storj.io/common/testrand"
"storj.io/storj/private/testplanet"
@ -175,169 +174,74 @@ func Test_GetSingleBucketRollup(t *testing.T) {
)
}
func Test_GetProjectTotalByPartner(t *testing.T) {
const (
epsilon = 1e-8
usagePeriod = time.Hour
tallyRollupCount = 2
)
since := time.Time{}
before := since.Add(2 * usagePeriod)
func Test_GetProjectTotal(t *testing.T) {
testplanet.Run(t, testplanet.Config{SatelliteCount: 1, StorageNodeCount: 1},
func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
sat := planet.Satellites[0]
bucketName := testrand.BucketName()
projectID := testrand.UUID()
user, err := sat.AddUser(ctx, console.CreateUser{
FullName: "Test User",
Email: "user@mail.test",
}, 1)
db := planet.Satellites[0].DB
// The 3rd tally is only present to prevent CreateStorageTally from skipping the 2nd.
var tallies []accounting.BucketStorageTally
for i := 0; i < 3; i++ {
tally := accounting.BucketStorageTally{
BucketName: bucketName,
ProjectID: projectID,
IntervalStart: time.Time{}.Add(time.Duration(i) * time.Hour),
TotalBytes: int64(testrand.Intn(1000)),
ObjectCount: int64(testrand.Intn(1000)),
TotalSegmentCount: int64(testrand.Intn(1000)),
}
tallies = append(tallies, tally)
require.NoError(t, db.ProjectAccounting().CreateStorageTally(ctx, tally))
}
var rollups []orders.BucketBandwidthRollup
var expectedEgress int64
for i := 0; i < 2; i++ {
rollup := orders.BucketBandwidthRollup{
ProjectID: projectID,
BucketName: bucketName,
Action: pb.PieceAction_GET,
IntervalStart: tallies[i].IntervalStart,
Inline: int64(testrand.Intn(1000)),
Settled: int64(testrand.Intn(1000)),
}
rollups = append(rollups, rollup)
expectedEgress += rollup.Inline + rollup.Settled
}
require.NoError(t, db.Orders().UpdateBandwidthBatch(ctx, rollups))
usage, err := db.ProjectAccounting().GetProjectTotal(ctx, projectID, tallies[0].IntervalStart, tallies[2].IntervalStart.Add(time.Minute))
require.NoError(t, err)
project, err := sat.AddProject(ctx, user.ID, "testproject")
const epsilon = 1e-8
require.InDelta(t, usage.Storage, float64(tallies[0].Bytes()+tallies[1].Bytes()), epsilon)
require.InDelta(t, usage.SegmentCount, float64(tallies[0].TotalSegmentCount+tallies[1].TotalSegmentCount), epsilon)
require.InDelta(t, usage.ObjectCount, float64(tallies[0].ObjectCount+tallies[1].ObjectCount), epsilon)
require.Equal(t, usage.Egress, expectedEgress)
require.Equal(t, usage.Since, tallies[0].IntervalStart)
require.Equal(t, usage.Before, tallies[2].IntervalStart.Add(time.Minute))
// Ensure that GetProjectTotal treats the 'before' arg as exclusive
usage, err = db.ProjectAccounting().GetProjectTotal(ctx, projectID, tallies[0].IntervalStart, tallies[2].IntervalStart)
require.NoError(t, err)
require.InDelta(t, usage.Storage, float64(tallies[0].Bytes()), epsilon)
require.InDelta(t, usage.SegmentCount, float64(tallies[0].TotalSegmentCount), epsilon)
require.InDelta(t, usage.ObjectCount, float64(tallies[0].ObjectCount), epsilon)
require.Equal(t, usage.Egress, expectedEgress)
require.Equal(t, usage.Since, tallies[0].IntervalStart)
require.Equal(t, usage.Before, tallies[2].IntervalStart)
type expectedTotal struct {
storage float64
segments float64
objects float64
egress int64
}
expectedTotals := make(map[string]expectedTotal)
var beforeTotal expectedTotal
requireTotal := func(t *testing.T, expected expectedTotal, actual accounting.ProjectUsage) {
require.InDelta(t, expected.storage, actual.Storage, epsilon)
require.InDelta(t, expected.segments, actual.SegmentCount, epsilon)
require.InDelta(t, expected.objects, actual.ObjectCount, epsilon)
require.Equal(t, expected.egress, actual.Egress)
require.Equal(t, since, actual.Since)
require.Equal(t, before, actual.Before)
}
partnerNames := []string{"", "partner1", "partner2"}
for _, name := range partnerNames {
total := expectedTotal{}
bucket := storj.Bucket{
ID: testrand.UUID(),
Name: testrand.BucketName(),
ProjectID: project.ID,
}
if name != "" {
bucket.UserAgent = []byte(name)
}
_, err := sat.DB.Buckets().CreateBucket(ctx, bucket)
require.NoError(t, err)
// We use multiple tallies and rollups to ensure that
// GetProjectTotalByPartner is capable of summing them.
for i := 0; i <= tallyRollupCount; i++ {
tally := accounting.BucketStorageTally{
BucketName: bucket.Name,
ProjectID: project.ID,
IntervalStart: since.Add(time.Duration(i) * usagePeriod / tallyRollupCount),
TotalBytes: int64(testrand.Intn(1000)),
ObjectCount: int64(testrand.Intn(1000)),
TotalSegmentCount: int64(testrand.Intn(1000)),
}
require.NoError(t, sat.DB.ProjectAccounting().CreateStorageTally(ctx, tally))
// The last tally's usage data is unused.
usageHours := (usagePeriod / tallyRollupCount).Hours()
if i < tallyRollupCount {
total.storage += float64(tally.Bytes()) * usageHours
total.objects += float64(tally.ObjectCount) * usageHours
total.segments += float64(tally.TotalSegmentCount) * usageHours
}
if i < tallyRollupCount-1 {
beforeTotal.storage += float64(tally.Bytes()) * usageHours
beforeTotal.objects += float64(tally.ObjectCount) * usageHours
beforeTotal.segments += float64(tally.TotalSegmentCount) * usageHours
}
}
var rollups []orders.BucketBandwidthRollup
for i := 0; i < tallyRollupCount; i++ {
rollup := orders.BucketBandwidthRollup{
BucketName: bucket.Name,
ProjectID: project.ID,
Action: pb.PieceAction_GET,
IntervalStart: since.Add(time.Duration(i) * usagePeriod / tallyRollupCount),
Inline: int64(testrand.Intn(1000)),
Settled: int64(testrand.Intn(1000)),
}
rollups = append(rollups, rollup)
total.egress += rollup.Inline + rollup.Settled
if i < tallyRollupCount {
beforeTotal.egress += rollup.Inline + rollup.Settled
}
}
require.NoError(t, sat.DB.Orders().UpdateBandwidthBatch(ctx, rollups))
expectedTotals[name] = total
}
t.Run("sum all partner usages", func(t *testing.T) {
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, nil, since, before)
require.NoError(t, err)
require.Len(t, usages, 1)
require.Contains(t, usages, "")
var summedTotal expectedTotal
for _, total := range expectedTotals {
summedTotal.storage += total.storage
summedTotal.segments += total.segments
summedTotal.objects += total.objects
summedTotal.egress += total.egress
}
requireTotal(t, summedTotal, usages[""])
})
t.Run("individual partner usages", func(t *testing.T) {
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, partnerNames, since, before)
require.NoError(t, err)
require.Len(t, usages, len(expectedTotals))
for _, name := range partnerNames {
require.Contains(t, usages, name)
}
for partner, usage := range usages {
requireTotal(t, expectedTotals[partner], usage)
}
})
t.Run("select one partner usage and sum remaining usages", func(t *testing.T) {
partner := partnerNames[len(partnerNames)-1]
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, []string{partner}, since, before)
require.NoError(t, err)
require.Len(t, usages, 2)
require.Contains(t, usages, "")
require.Contains(t, usages, partner)
var summedTotal expectedTotal
for _, partner := range partnerNames[:len(partnerNames)-1] {
summedTotal.storage += expectedTotals[partner].storage
summedTotal.segments += expectedTotals[partner].segments
summedTotal.objects += expectedTotals[partner].objects
summedTotal.egress += expectedTotals[partner].egress
}
requireTotal(t, expectedTotals[partner], usages[partner])
requireTotal(t, summedTotal, usages[""])
})
t.Run("ensure before is exclusive", func(t *testing.T) {
before = since.Add(usagePeriod)
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, nil, since, before)
require.NoError(t, err)
require.Len(t, usages, 1)
require.Contains(t, usages, "")
requireTotal(t, beforeTotal, usages[""])
})
usage, err = db.ProjectAccounting().GetProjectTotal(ctx, projectID, rollups[0].IntervalStart, rollups[1].IntervalStart)
require.NoError(t, err)
require.Zero(t, usage.Storage)
require.Zero(t, usage.SegmentCount)
require.Zero(t, usage.ObjectCount)
require.Equal(t, usage.Egress, rollups[0].Inline+rollups[0].Settled)
require.Equal(t, usage.Since, rollups[0].IntervalStart)
require.Equal(t, usage.Before, rollups[1].IntervalStart)
},
)
}