satellite/payments/stripecoinpayments: undo price override removal

Commit fb59974 disabled usage price overrides because of a failing
test. This change reenables it while resolving the issue that caused
the test to fail.

The previous version of the test passed Gerrit verification and was
merged, but it failed for the primary Jenkins pipeline after merge.

This is due to a difference in how the Jenkins build runs Cockroach
and Postgres for each pipeline.

This commit rewrites the test to be safe for concurrent execution by
ensuring any mutable variables are defined within each test so that
shared state across tests is reduced.

Change-Id: Ia4566c9cd2d698afdb2caa4b7e2808b17e18de4e
This commit is contained in:
Jeremy Wharton 2023-02-23 10:27:37 -06:00 committed by Storj Robot
parent 17ec326fd4
commit cbbd5ab1ef
8 changed files with 543 additions and 152 deletions

View File

@ -236,6 +236,9 @@ type ProjectAccounting interface {
GetProjectLimits(ctx context.Context, projectID uuid.UUID) (ProjectLimits, error) GetProjectLimits(ctx context.Context, projectID uuid.UUID) (ProjectLimits, error)
// GetProjectTotal returns project usage summary for specified period of time. // GetProjectTotal returns project usage summary for specified period of time.
GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (*ProjectUsage, error) GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (*ProjectUsage, error)
// GetProjectTotalByPartner retrieves project usage for a given period categorized by partner name.
// Unpartnered usage or usage for a partner not present in partnerNames is mapped to the empty string.
GetProjectTotalByPartner(ctx context.Context, projectID uuid.UUID, partnerNames []string, since, before time.Time) (usages map[string]ProjectUsage, err error)
// GetProjectObjectsSegments returns project objects and segments for specified period of time. // GetProjectObjectsSegments returns project objects and segments for specified period of time.
GetProjectObjectsSegments(ctx context.Context, projectID uuid.UUID) (*ProjectObjectsSegments, error) GetProjectObjectsSegments(ctx context.Context, projectID uuid.UUID) (*ProjectObjectsSegments, error)
// GetBucketUsageRollups returns usage rollup per each bucket for specified period of time. // GetBucketUsageRollups returns usage rollup per each bucket for specified period of time.

View File

@ -3139,7 +3139,12 @@ func (payment Payments) Purchase(ctx context.Context, price int64, desc string,
func (payment Payments) GetProjectUsagePriceModel(ctx context.Context) (_ *payments.ProjectUsagePriceModel, err error) { func (payment Payments) GetProjectUsagePriceModel(ctx context.Context) (_ *payments.ProjectUsagePriceModel, err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
model := payment.service.accounts.GetProjectUsagePriceModel() user, err := GetUser(ctx)
if err != nil {
return nil, Error.Wrap(err)
}
model := payment.service.accounts.GetProjectUsagePriceModel(string(user.UserAgent))
return &model, nil return &model, nil
} }

View File

@ -29,8 +29,8 @@ type Accounts interface {
// ProjectCharges returns how much money current user will be charged for each project. // ProjectCharges returns how much money current user will be charged for each project.
ProjectCharges(ctx context.Context, userID uuid.UUID, since, before time.Time) ([]ProjectCharge, error) ProjectCharges(ctx context.Context, userID uuid.UUID, since, before time.Time) ([]ProjectCharge, error)
// GetProjectUsagePriceModel returns the project usage price model. // GetProjectUsagePriceModel returns the project usage price model for a partner name.
GetProjectUsagePriceModel() ProjectUsagePriceModel GetProjectUsagePriceModel(partner string) ProjectUsagePriceModel
// CheckProjectInvoicingStatus returns error if for the given project there are outstanding project records and/or usage // CheckProjectInvoicingStatus returns error if for the given project there are outstanding project records and/or usage
// which have not been applied/invoiced yet (meaning sent over to stripe). // which have not been applied/invoiced yet (meaning sent over to stripe).

View File

@ -11,6 +11,7 @@ import (
"github.com/zeebo/errs" "github.com/zeebo/errs"
"storj.io/common/uuid" "storj.io/common/uuid"
"storj.io/storj/satellite/accounting"
"storj.io/storj/satellite/payments" "storj.io/storj/satellite/payments"
) )
@ -120,28 +121,47 @@ func (accounts *accounts) ProjectCharges(ctx context.Context, userID uuid.UUID,
} }
for _, project := range projects { for _, project := range projects {
usage, err := accounts.service.usageDB.GetProjectTotal(ctx, project.ID, since, before) totalUsage := accounting.ProjectUsage{Since: since, Before: before}
usages, err := accounts.service.usageDB.GetProjectTotalByPartner(ctx, project.ID, accounts.service.partnerNames, since, before)
if err != nil { if err != nil {
return charges, Error.Wrap(err) return nil, Error.Wrap(err)
} }
projectPrice := accounts.service.calculateProjectUsagePrice(usage.Egress, usage.Storage, usage.SegmentCount) var totalPrice projectUsagePrice
for partner, usage := range usages {
priceModel := accounts.GetProjectUsagePriceModel(partner)
price := accounts.service.calculateProjectUsagePrice(usage.Egress, usage.Storage, usage.SegmentCount, priceModel)
totalPrice.Egress = totalPrice.Egress.Add(price.Egress)
totalPrice.Segments = totalPrice.Segments.Add(price.Segments)
totalPrice.Storage = totalPrice.Storage.Add(price.Storage)
totalUsage.Egress += usage.Egress
totalUsage.ObjectCount += usage.ObjectCount
totalUsage.SegmentCount += usage.SegmentCount
totalUsage.Storage += usage.Storage
}
charges = append(charges, payments.ProjectCharge{ charges = append(charges, payments.ProjectCharge{
ProjectUsage: *usage, ProjectUsage: totalUsage,
ProjectID: project.ID, ProjectID: project.ID,
Egress: projectPrice.Egress.IntPart(), Egress: totalPrice.Egress.IntPart(),
SegmentCount: projectPrice.Segments.IntPart(), SegmentCount: totalPrice.Segments.IntPart(),
StorageGbHrs: projectPrice.Storage.IntPart(), StorageGbHrs: totalPrice.Storage.IntPart(),
}) })
} }
return charges, nil return charges, nil
} }
// GetProjectUsagePriceModel returns the project usage price model. // GetProjectUsagePriceModel returns the project usage price model for a partner name.
func (accounts *accounts) GetProjectUsagePriceModel() payments.ProjectUsagePriceModel { func (accounts *accounts) GetProjectUsagePriceModel(partner string) payments.ProjectUsagePriceModel {
if override, ok := accounts.service.usagePriceOverrides[partner]; ok {
return override
}
return accounts.service.usagePrices return accounts.service.usagePrices
} }

View File

@ -8,6 +8,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -486,7 +487,12 @@ func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName
return true, nil return true, nil
} }
items := service.InvoiceItemsFromProjectRecord(projName, record) usages, err := service.usageDB.GetProjectTotalByPartner(ctx, record.ProjectID, service.partnerNames, record.PeriodStart, record.PeriodEnd)
if err != nil {
return false, err
}
items := service.InvoiceItemsFromProjectUsage(projName, usages)
for _, item := range items { for _, item := range items {
item.Currency = stripe.String(string(stripe.CurrencyUSD)) item.Currency = stripe.String(string(stripe.CurrencyUSD))
item.Customer = stripe.String(cusID) item.Customer = stripe.String(cusID)
@ -501,28 +507,50 @@ func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName
return false, nil return false, nil
} }
// InvoiceItemsFromProjectRecord calculates Stripe invoice item from project record. // InvoiceItemsFromProjectUsage calculates Stripe invoice item from project usage.
func (service *Service) InvoiceItemsFromProjectRecord(projName string, record ProjectRecord) (result []*stripe.InvoiceItemParams) { func (service *Service) InvoiceItemsFromProjectUsage(projName string, partnerUsages map[string]accounting.ProjectUsage) (result []*stripe.InvoiceItemParams) {
projectItem := &stripe.InvoiceItemParams{} var partners []string
projectItem.Description = stripe.String(fmt.Sprintf("Project %s - Segment Storage (MB-Month)", projName)) if len(partnerUsages) == 0 {
projectItem.Quantity = stripe.Int64(storageMBMonthDecimal(record.Storage).IntPart()) partners = []string{""}
storagePrice, _ := service.usagePrices.StorageMBMonthCents.Float64() partnerUsages = map[string]accounting.ProjectUsage{"": {}}
projectItem.UnitAmountDecimal = stripe.Float64(storagePrice) } else {
result = append(result, projectItem) for partner := range partnerUsages {
partners = append(partners, partner)
}
sort.Strings(partners)
}
projectItem = &stripe.InvoiceItemParams{} for _, partner := range partners {
projectItem.Description = stripe.String(fmt.Sprintf("Project %s - Egress Bandwidth (MB)", projName)) usage := partnerUsages[partner]
projectItem.Quantity = stripe.Int64(egressMBDecimal(record.Egress).IntPart()) priceModel := service.Accounts().GetProjectUsagePriceModel(partner)
egressPrice, _ := service.usagePrices.EgressMBCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(egressPrice) prefix := "Project " + projName
result = append(result, projectItem) if partner != "" {
prefix += " (" + partner + ")"
}
projectItem := &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(prefix + " - Segment Storage (MB-Month)")
projectItem.Quantity = stripe.Int64(storageMBMonthDecimal(usage.Storage).IntPart())
storagePrice, _ := priceModel.StorageMBMonthCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(storagePrice)
result = append(result, projectItem)
projectItem = &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(prefix + " - Egress Bandwidth (MB)")
projectItem.Quantity = stripe.Int64(egressMBDecimal(usage.Egress).IntPart())
egressPrice, _ := priceModel.EgressMBCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(egressPrice)
result = append(result, projectItem)
projectItem = &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(prefix + " - Segment Fee (Segment-Month)")
projectItem.Quantity = stripe.Int64(segmentMonthDecimal(usage.SegmentCount).IntPart())
segmentPrice, _ := priceModel.SegmentMonthCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(segmentPrice)
result = append(result, projectItem)
}
projectItem = &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(fmt.Sprintf("Project %s - Segment Fee (Segment-Month)", projName))
projectItem.Quantity = stripe.Int64(segmentMonthDecimal(record.Segments).IntPart())
segmentPrice, _ := service.usagePrices.SegmentMonthCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(segmentPrice)
result = append(result, projectItem)
service.log.Info("invoice items", zap.Any("result", result)) service.log.Info("invoice items", zap.Any("result", result))
return result return result
@ -780,11 +808,11 @@ func (price projectUsagePrice) TotalInt64() int64 {
} }
// calculateProjectUsagePrice calculate project usage price. // calculateProjectUsagePrice calculate project usage price.
func (service *Service) calculateProjectUsagePrice(egress int64, storage, segments float64) projectUsagePrice { func (service *Service) calculateProjectUsagePrice(egress int64, storage, segments float64, pricing payments.ProjectUsagePriceModel) projectUsagePrice {
return projectUsagePrice{ return projectUsagePrice{
Storage: service.usagePrices.StorageMBMonthCents.Mul(storageMBMonthDecimal(storage)).Round(0), Storage: pricing.StorageMBMonthCents.Mul(storageMBMonthDecimal(storage)).Round(0),
Egress: service.usagePrices.EgressMBCents.Mul(egressMBDecimal(egress)).Round(0), Egress: pricing.EgressMBCents.Mul(egressMBDecimal(egress)).Round(0),
Segments: service.usagePrices.SegmentMonthCents.Mul(segmentMonthDecimal(segments)).Round(0), Segments: pricing.SegmentMonthCents.Mul(segmentMonthDecimal(segments)).Round(0),
} }
} }

View File

@ -5,6 +5,8 @@ package stripecoinpayments_test
import ( import (
"context" "context"
"fmt"
"math"
"strconv" "strconv"
"testing" "testing"
"time" "time"
@ -16,6 +18,7 @@ import (
"storj.io/common/currency" "storj.io/common/currency"
"storj.io/common/memory" "storj.io/common/memory"
"storj.io/common/pb" "storj.io/common/pb"
"storj.io/common/storj"
"storj.io/common/testcontext" "storj.io/common/testcontext"
"storj.io/common/testrand" "storj.io/common/testrand"
"storj.io/common/uuid" "storj.io/common/uuid"
@ -25,7 +28,9 @@ import (
"storj.io/storj/satellite/accounting" "storj.io/storj/satellite/accounting"
"storj.io/storj/satellite/console" "storj.io/storj/satellite/console"
"storj.io/storj/satellite/metabase" "storj.io/storj/satellite/metabase"
"storj.io/storj/satellite/payments"
"storj.io/storj/satellite/payments/billing" "storj.io/storj/satellite/payments/billing"
"storj.io/storj/satellite/payments/paymentsconfig"
"storj.io/storj/satellite/payments/stripecoinpayments" "storj.io/storj/satellite/payments/stripecoinpayments"
) )
@ -222,67 +227,106 @@ func TestService_ProjectsWithMembers(t *testing.T) {
}) })
} }
func TestService_InvoiceItemsFromProjectRecord(t *testing.T) { func TestService_InvoiceItemsFromProjectUsage(t *testing.T) {
const (
projectName = "my-project"
partnerName = "partner"
noOverridePartnerName = "no-override"
hoursPerMonth = 24 * 30
bytesPerMegabyte = int64(memory.MB / memory.B)
byteHoursPerMBMonth = hoursPerMonth * bytesPerMegabyte
)
var (
defaultPrice = paymentsconfig.ProjectUsagePrice{
StorageTB: "1",
EgressTB: "2",
Segment: "3",
}
partnerPrice = paymentsconfig.ProjectUsagePrice{
StorageTB: "4",
EgressTB: "5",
Segment: "6",
}
)
defaultModel, err := defaultPrice.ToModel()
require.NoError(t, err)
partnerModel, err := partnerPrice.ToModel()
require.NoError(t, err)
testplanet.Run(t, testplanet.Config{ testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
Reconfigure: testplanet.Reconfigure{
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
config.Payments.UsagePrice = defaultPrice
config.Payments.UsagePriceOverrides.SetMap(map[string]paymentsconfig.ProjectUsagePrice{
partnerName: partnerPrice,
})
},
},
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
satellite := planet.Satellites[0] usage := map[string]accounting.ProjectUsage{
"": {
// these numbers are fraction of cents, not of dollars. Storage: 10000000000, // Byte-hours
expectedStoragePrice := 0.001 Egress: 123 * memory.GB.Int64(), // Bytes
expectedEgressPrice := 0.0045 SegmentCount: 200000, // Segment-Hours
expectedSegmentPrice := 0.00022
type TestCase struct {
Storage float64
Egress int64
Segments float64
StorageQuantity int64
EgressQuantity int64
SegmentsQuantity int64
}
testCases := []TestCase{
{}, // all zeros
{
Storage: 10000000000, // Byte-Hours
// storage quantity is calculated to Megabyte-Months
// (10000000000 / 1000000) Byte-Hours to Megabytes-Hours
// round(10000 / 720) Megabytes-Hours to Megabyte-Months, 720 - hours in month
StorageQuantity: 14, // Megabyte-Months
}, },
{ partnerName: {
Egress: 134 * memory.GB.Int64(), // Bytes Storage: 20000000000,
// egress quantity is calculated to Megabytes Egress: 456 * memory.GB.Int64(),
// (134000000000 / 1000000) Bytes to Megabytes SegmentCount: 400000,
EgressQuantity: 134000, // Megabytes
}, },
{ noOverridePartnerName: {
Segments: 400000, // Segment-Hours Storage: 30000000000,
// object quantity is calculated to Segment-Months Egress: 789 * memory.GB.Int64(),
// round(400000 / 720) Segment-Hours to Segment-Months, 720 - hours in month SegmentCount: 600000,
SegmentsQuantity: 556, // Segment-Months
}, },
} }
for _, tc := range testCases { items := planet.Satellites[0].API.Payments.StripeService.InvoiceItemsFromProjectUsage(projectName, usage)
record := stripecoinpayments.ProjectRecord{ require.Len(t, items, len(usage)*3)
Storage: tc.Storage,
Egress: tc.Egress,
Segments: tc.Segments,
}
items := satellite.API.Payments.StripeService.InvoiceItemsFromProjectRecord("project name", record) for i, tt := range []struct {
name string
partner string
priceModel payments.ProjectUsagePriceModel
}{
{"default pricing - no partner", "", defaultModel},
{"default pricing - no override for partner", noOverridePartnerName, defaultModel},
{"partner pricing", partnerName, partnerModel},
} {
t.Run(tt.name, func(t *testing.T) {
prefix := "Project " + projectName
if tt.partner != "" {
prefix += " (" + tt.partner + ")"
}
require.Equal(t, tc.StorageQuantity, *items[0].Quantity) usage := usage[tt.partner]
require.Equal(t, expectedStoragePrice, *items[0].UnitAmountDecimal) expectedStorageQuantity := int64(math.Round(usage.Storage / float64(byteHoursPerMBMonth)))
expectedEgressQuantity := int64(math.Round(float64(usage.Egress) / float64(bytesPerMegabyte)))
expectedSegmentQuantity := int64(math.Round(usage.SegmentCount / hoursPerMonth))
require.Equal(t, tc.EgressQuantity, *items[1].Quantity) items := items[i*3 : (i*3)+3]
require.Equal(t, expectedEgressPrice, *items[1].UnitAmountDecimal) for _, item := range items {
require.NotNil(t, item)
}
require.Equal(t, tc.SegmentsQuantity, *items[2].Quantity) require.Equal(t, prefix+" - Segment Storage (MB-Month)", *items[0].Description)
require.Equal(t, expectedSegmentPrice, *items[2].UnitAmountDecimal) require.Equal(t, expectedStorageQuantity, *items[0].Quantity)
storage, _ := tt.priceModel.StorageMBMonthCents.Float64()
require.Equal(t, storage, *items[0].UnitAmountDecimal)
require.Equal(t, prefix+" - Egress Bandwidth (MB)", *items[1].Description)
require.Equal(t, expectedEgressQuantity, *items[1].Quantity)
egress, _ := tt.priceModel.EgressMBCents.Float64()
require.Equal(t, egress, *items[1].UnitAmountDecimal)
require.Equal(t, prefix+" - Segment Fee (Segment-Month)", *items[2].Description)
require.Equal(t, expectedSegmentQuantity, *items[2].Quantity)
segment, _ := tt.priceModel.SegmentMonthCents.Float64()
require.Equal(t, segment, *items[2].UnitAmountDecimal)
})
} }
}) })
} }
@ -491,6 +535,100 @@ func generateProjectStorage(ctx context.Context, tb testing.TB, db satellite.DB,
require.NoError(tb, err) require.NoError(tb, err)
} }
func TestProjectUsagePrice(t *testing.T) {
var (
defaultPrice = paymentsconfig.ProjectUsagePrice{
StorageTB: "1",
EgressTB: "2",
Segment: "3",
}
partnerName = "partner"
partnerPrice = paymentsconfig.ProjectUsagePrice{
StorageTB: "4",
EgressTB: "5",
Segment: "6",
}
)
defaultModel, err := defaultPrice.ToModel()
require.NoError(t, err)
partnerModel, err := partnerPrice.ToModel()
require.NoError(t, err)
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
Reconfigure: testplanet.Reconfigure{
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
config.Payments.UsagePrice = defaultPrice
config.Payments.UsagePriceOverrides.SetMap(map[string]paymentsconfig.ProjectUsagePrice{
partnerName: partnerPrice,
})
},
},
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
sat := planet.Satellites[0]
// pick a specific date so that it doesn't fail if it's the last day of the month
// keep month + 1 because user needs to be created before calculation
period := time.Date(time.Now().Year(), time.Now().Month()+1, 20, 0, 0, 0, 0, time.UTC)
sat.API.Payments.StripeService.SetNow(func() time.Time {
return time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC)
})
for i, tt := range []struct {
name string
userAgent []byte
expectedPrice payments.ProjectUsagePriceModel
}{
{"default pricing", nil, defaultModel},
{"default pricing - user agent is not valid partner name", []byte("invalid/v0.0"), defaultModel},
{"partner pricing - user agent is partner name", []byte(partnerName), partnerModel},
{"partner pricing - user agent prefixed with partner name", []byte(partnerName + " invalid/v0.0"), partnerModel},
} {
t.Run(tt.name, func(t *testing.T) {
user, err := sat.AddUser(ctx, console.CreateUser{
FullName: "Test User",
Email: fmt.Sprintf("user%d@mail.test", i),
UserAgent: tt.userAgent,
}, 1)
require.NoError(t, err)
project, err := sat.AddProject(ctx, user.ID, "testproject")
require.NoError(t, err)
bucket, err := sat.DB.Buckets().CreateBucket(ctx, storj.Bucket{
ID: testrand.UUID(),
Name: testrand.BucketName(),
ProjectID: project.ID,
UserAgent: tt.userAgent,
})
require.NoError(t, err)
err = sat.DB.Orders().UpdateBucketBandwidthSettle(ctx, project.ID, []byte(bucket.Name),
pb.PieceAction_GET, memory.TB.Int64(), 0, period)
require.NoError(t, err)
err = sat.API.Payments.StripeService.PrepareInvoiceProjectRecords(ctx, period)
require.NoError(t, err)
err = sat.API.Payments.StripeService.InvoiceApplyProjectRecords(ctx, period)
require.NoError(t, err)
cusID, err := sat.DB.StripeCoinPayments().Customers().GetCustomerID(ctx, user.ID)
require.NoError(t, err)
items := getCustomerInvoiceItems(sat.API.Payments.StripeClient, cusID)
require.Len(t, items, 3)
storage, _ := tt.expectedPrice.StorageMBMonthCents.Float64()
require.Equal(t, storage, items[0].UnitAmountDecimal)
egress, _ := tt.expectedPrice.EgressMBCents.Float64()
require.Equal(t, egress, items[1].UnitAmountDecimal)
segment, _ := tt.expectedPrice.SegmentMonthCents.Float64()
require.Equal(t, segment, items[2].UnitAmountDecimal)
})
}
})
}
func TestPayInvoicesSkipDue(t *testing.T) { func TestPayInvoicesSkipDue(t *testing.T) {
testplanet.Run(t, testplanet.Config{ testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,

View File

@ -16,6 +16,7 @@ import (
"storj.io/common/memory" "storj.io/common/memory"
"storj.io/common/pb" "storj.io/common/pb"
"storj.io/common/useragent"
"storj.io/common/uuid" "storj.io/common/uuid"
"storj.io/private/dbutil" "storj.io/private/dbutil"
"storj.io/private/dbutil/pgutil" "storj.io/private/dbutil/pgutil"
@ -505,7 +506,21 @@ func (db *ProjectAccounting) GetProjectSegmentLimit(ctx context.Context, project
} }
// GetProjectTotal retrieves project usage for a given period. // GetProjectTotal retrieves project usage for a given period.
func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (usage *accounting.ProjectUsage, err error) { func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (_ *accounting.ProjectUsage, err error) {
defer mon.Task()(&ctx)(&err)
usages, err := db.GetProjectTotalByPartner(ctx, projectID, nil, since, before)
if err != nil {
return nil, err
}
if usage, ok := usages[""]; ok {
return &usage, nil
}
return &accounting.ProjectUsage{Since: since, Before: before}, nil
}
// GetProjectTotalByPartner retrieves project usage for a given period categorized by partner name.
// Unpartnered usage or usage for a partner not present in partnerNames is mapped to the empty string.
func (db *ProjectAccounting) GetProjectTotalByPartner(ctx context.Context, projectID uuid.UUID, partnerNames []string, since, before time.Time) (usages map[string]accounting.ProjectUsage, err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
since = timeTruncateDown(since) since = timeTruncateDown(since)
bucketNames, err := db.getBucketsSinceAndBefore(ctx, projectID, since, before) bucketNames, err := db.getBucketsSinceAndBefore(ctx, projectID, since, before)
@ -531,16 +546,54 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid
ORDER BY bucket_storage_tallies.interval_start DESC ORDER BY bucket_storage_tallies.interval_start DESC
`) `)
bucketsTallies := make(map[string][]*accounting.BucketStorageTally) totalEgressQuery := db.db.Rebind(`
SELECT
COALESCE(SUM(settled) + SUM(inline), 0)
FROM
bucket_bandwidth_rollups
WHERE
bucket_name = ? AND
project_id = ? AND
interval_start >= ? AND
interval_start < ? AND
action = ?;
`)
usages = make(map[string]accounting.ProjectUsage)
for _, bucket := range bucketNames { for _, bucket := range bucketNames {
storageTallies := make([]*accounting.BucketStorageTally, 0) userAgentRow, err := db.db.Get_BucketMetainfo_UserAgent_By_ProjectId_And_Name(ctx,
dbx.BucketMetainfo_ProjectId(projectID[:]),
dbx.BucketMetainfo_Name([]byte(bucket)))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, err
}
var partner string
if userAgentRow != nil && userAgentRow.UserAgent != nil {
entries, err := useragent.ParseEntries(userAgentRow.UserAgent)
if err != nil {
return nil, err
}
for _, iterPartner := range partnerNames {
if entries[0].Product == iterPartner {
partner = iterPartner
break
}
}
}
if _, ok := usages[partner]; !ok {
usages[partner] = accounting.ProjectUsage{Since: since, Before: before}
}
usage := usages[partner]
storageTalliesRows, err := db.db.QueryContext(ctx, storageQuery, projectID[:], []byte(bucket), since, before) storageTalliesRows, err := db.db.QueryContext(ctx, storageQuery, projectID[:], []byte(bucket), since, before)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// generating tallies for each bucket name.
var prevTally *accounting.BucketStorageTally
for storageTalliesRows.Next() { for storageTalliesRows.Next() {
tally := accounting.BucketStorageTally{} tally := accounting.BucketStorageTally{}
@ -553,8 +606,17 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid
tally.TotalBytes = inline + remote tally.TotalBytes = inline + remote
} }
tally.BucketName = bucket if prevTally == nil {
storageTallies = append(storageTallies, &tally) prevTally = &tally
continue
}
hours := prevTally.IntervalStart.Sub(tally.IntervalStart).Hours()
usage.Storage += memory.Size(tally.TotalBytes).Float64() * hours
usage.SegmentCount += float64(tally.TotalSegmentCount) * hours
usage.ObjectCount += float64(tally.ObjectCount) * hours
prevTally = &tally
} }
err = errs.Combine(storageTalliesRows.Err(), storageTalliesRows.Close()) err = errs.Combine(storageTalliesRows.Err(), storageTalliesRows.Close())
@ -562,53 +624,21 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid
return nil, err return nil, err
} }
bucketsTallies[bucket] = storageTallies totalEgressRow := db.db.QueryRowContext(ctx, totalEgressQuery, []byte(bucket), projectID[:], since, before, pb.PieceAction_GET)
} if err != nil {
return nil, err
totalEgress, err := db.getTotalEgress(ctx, projectID, since, before)
if err != nil {
return nil, err
}
usage = new(accounting.ProjectUsage)
usage.Egress = memory.Size(totalEgress).Int64()
// sum up storage, objects, and segments
for _, tallies := range bucketsTallies {
for i := len(tallies) - 1; i > 0; i-- {
current := (tallies)[i]
hours := (tallies)[i-1].IntervalStart.Sub(current.IntervalStart).Hours()
usage.Storage += memory.Size(current.Bytes()).Float64() * hours
usage.SegmentCount += float64(current.TotalSegmentCount) * hours
usage.ObjectCount += float64(current.ObjectCount) * hours
} }
var egress int64
if err = totalEgressRow.Scan(&egress); err != nil {
return nil, err
}
usage.Egress += egress
usages[partner] = usage
} }
usage.Since = since return usages, nil
usage.Before = before
return usage, nil
}
// getTotalEgress returns total egress (settled + inline) of each bucket_bandwidth_rollup
// in selected time period, project id.
// only process PieceAction_GET.
func (db *ProjectAccounting) getTotalEgress(ctx context.Context, projectID uuid.UUID, since, before time.Time) (totalEgress int64, err error) {
totalEgressQuery := db.db.Rebind(`
SELECT
COALESCE(SUM(settled) + SUM(inline), 0)
FROM
bucket_bandwidth_rollups
WHERE
project_id = ? AND
interval_start >= ? AND
interval_start < ? AND
action = ?;
`)
totalEgressRow := db.db.QueryRowContext(ctx, totalEgressQuery, projectID[:], since, before, pb.PieceAction_GET)
err = totalEgressRow.Scan(&totalEgress)
return totalEgress, err
} }
// GetBucketUsageRollups retrieves summed usage rollups for every bucket of particular project for a given period. // GetBucketUsageRollups retrieves summed usage rollups for every bucket of particular project for a given period.

View File

@ -11,8 +11,10 @@ import (
"storj.io/common/memory" "storj.io/common/memory"
"storj.io/common/pb" "storj.io/common/pb"
"storj.io/common/storj"
"storj.io/common/testcontext" "storj.io/common/testcontext"
"storj.io/common/testrand" "storj.io/common/testrand"
"storj.io/common/uuid"
"storj.io/storj/private/testplanet" "storj.io/storj/private/testplanet"
"storj.io/storj/satellite/accounting" "storj.io/storj/satellite/accounting"
"storj.io/storj/satellite/console" "storj.io/storj/satellite/console"
@ -185,14 +187,7 @@ func Test_GetProjectTotal(t *testing.T) {
// The 3rd tally is only present to prevent CreateStorageTally from skipping the 2nd. // The 3rd tally is only present to prevent CreateStorageTally from skipping the 2nd.
var tallies []accounting.BucketStorageTally var tallies []accounting.BucketStorageTally
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
tally := accounting.BucketStorageTally{ tally := randTally(bucketName, projectID, time.Time{}.Add(time.Duration(i)*time.Hour))
BucketName: bucketName,
ProjectID: projectID,
IntervalStart: time.Time{}.Add(time.Duration(i) * time.Hour),
TotalBytes: int64(testrand.Intn(1000)),
ObjectCount: int64(testrand.Intn(1000)),
TotalSegmentCount: int64(testrand.Intn(1000)),
}
tallies = append(tallies, tally) tallies = append(tallies, tally)
require.NoError(t, db.ProjectAccounting().CreateStorageTally(ctx, tally)) require.NoError(t, db.ProjectAccounting().CreateStorageTally(ctx, tally))
} }
@ -200,14 +195,7 @@ func Test_GetProjectTotal(t *testing.T) {
var rollups []orders.BucketBandwidthRollup var rollups []orders.BucketBandwidthRollup
var expectedEgress int64 var expectedEgress int64
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
rollup := orders.BucketBandwidthRollup{ rollup := randRollup(bucketName, projectID, tallies[i].IntervalStart)
ProjectID: projectID,
BucketName: bucketName,
Action: pb.PieceAction_GET,
IntervalStart: tallies[i].IntervalStart,
Inline: int64(testrand.Intn(1000)),
Settled: int64(testrand.Intn(1000)),
}
rollups = append(rollups, rollup) rollups = append(rollups, rollup)
expectedEgress += rollup.Inline + rollup.Settled expectedEgress += rollup.Inline + rollup.Settled
} }
@ -245,3 +233,182 @@ func Test_GetProjectTotal(t *testing.T) {
}, },
) )
} }
func Test_GetProjectTotalByPartner(t *testing.T) {
const (
epsilon = 1e-8
usagePeriod = time.Hour
tallyRollupCount = 2
)
since := time.Time{}
before := since.Add(2 * usagePeriod)
testplanet.Run(t, testplanet.Config{SatelliteCount: 1, StorageNodeCount: 1},
func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
sat := planet.Satellites[0]
user, err := sat.AddUser(ctx, console.CreateUser{
FullName: "Test User",
Email: "user@mail.test",
}, 1)
require.NoError(t, err)
project, err := sat.AddProject(ctx, user.ID, "testproject")
require.NoError(t, err)
type expectedTotal struct {
storage float64
segments float64
objects float64
egress int64
}
expectedTotals := make(map[string]expectedTotal)
var beforeTotal expectedTotal
requireTotal := func(t *testing.T, expected expectedTotal, expectedSince, expectedBefore time.Time, actual accounting.ProjectUsage) {
require.InDelta(t, expected.storage, actual.Storage, epsilon)
require.InDelta(t, expected.segments, actual.SegmentCount, epsilon)
require.InDelta(t, expected.objects, actual.ObjectCount, epsilon)
require.Equal(t, expected.egress, actual.Egress)
require.Equal(t, expectedSince, actual.Since)
require.Equal(t, expectedBefore, actual.Before)
}
partnerNames := []string{"", "partner1", "partner2"}
for _, name := range partnerNames {
total := expectedTotal{}
bucket := storj.Bucket{
ID: testrand.UUID(),
Name: testrand.BucketName(),
ProjectID: project.ID,
}
if name != "" {
bucket.UserAgent = []byte(name)
}
_, err := sat.DB.Buckets().CreateBucket(ctx, bucket)
require.NoError(t, err)
// We use multiple tallies and rollups to ensure that
// GetProjectTotalByPartner is capable of summing them.
for i := 0; i <= tallyRollupCount; i++ {
tally := randTally(bucket.Name, project.ID, since.Add(time.Duration(i)*usagePeriod/tallyRollupCount))
require.NoError(t, sat.DB.ProjectAccounting().CreateStorageTally(ctx, tally))
// The last tally's usage data is unused.
usageHours := (usagePeriod / tallyRollupCount).Hours()
if i < tallyRollupCount {
total.storage += float64(tally.Bytes()) * usageHours
total.objects += float64(tally.ObjectCount) * usageHours
total.segments += float64(tally.TotalSegmentCount) * usageHours
}
if i < tallyRollupCount-1 {
beforeTotal.storage += float64(tally.Bytes()) * usageHours
beforeTotal.objects += float64(tally.ObjectCount) * usageHours
beforeTotal.segments += float64(tally.TotalSegmentCount) * usageHours
}
}
var rollups []orders.BucketBandwidthRollup
for i := 0; i < tallyRollupCount; i++ {
rollup := randRollup(bucket.Name, project.ID, since.Add(time.Duration(i)*usagePeriod/tallyRollupCount))
rollups = append(rollups, rollup)
total.egress += rollup.Inline + rollup.Settled
if i < tallyRollupCount {
beforeTotal.egress += rollup.Inline + rollup.Settled
}
}
require.NoError(t, sat.DB.Orders().UpdateBandwidthBatch(ctx, rollups))
expectedTotals[name] = total
}
t.Run("sum all partner usages", func(t *testing.T) {
ctx := testcontext.New(t)
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, nil, since, before)
require.NoError(t, err)
require.Len(t, usages, 1)
require.Contains(t, usages, "")
var summedTotal expectedTotal
for _, total := range expectedTotals {
summedTotal.storage += total.storage
summedTotal.segments += total.segments
summedTotal.objects += total.objects
summedTotal.egress += total.egress
}
requireTotal(t, summedTotal, since, before, usages[""])
})
t.Run("individual partner usages", func(t *testing.T) {
ctx := testcontext.New(t)
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, partnerNames, since, before)
require.NoError(t, err)
require.Len(t, usages, len(expectedTotals))
for _, name := range partnerNames {
require.Contains(t, usages, name)
}
for partner, usage := range usages {
requireTotal(t, expectedTotals[partner], since, before, usage)
}
})
t.Run("select one partner usage and sum remaining usages", func(t *testing.T) {
ctx := testcontext.New(t)
partner := partnerNames[len(partnerNames)-1]
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, []string{partner}, since, before)
require.NoError(t, err)
require.Len(t, usages, 2)
require.Contains(t, usages, "")
require.Contains(t, usages, partner)
var summedTotal expectedTotal
for _, partner := range partnerNames[:len(partnerNames)-1] {
summedTotal.storage += expectedTotals[partner].storage
summedTotal.segments += expectedTotals[partner].segments
summedTotal.objects += expectedTotals[partner].objects
summedTotal.egress += expectedTotals[partner].egress
}
requireTotal(t, expectedTotals[partner], since, before, usages[partner])
requireTotal(t, summedTotal, since, before, usages[""])
})
t.Run("ensure the 'before' arg is exclusive", func(t *testing.T) {
ctx := testcontext.New(t)
before := since.Add(usagePeriod)
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, nil, since, before)
require.NoError(t, err)
require.Len(t, usages, 1)
require.Contains(t, usages, "")
requireTotal(t, beforeTotal, since, before, usages[""])
})
},
)
}
func randTally(bucketName string, projectID uuid.UUID, intervalStart time.Time) accounting.BucketStorageTally {
return accounting.BucketStorageTally{
BucketName: bucketName,
ProjectID: projectID,
IntervalStart: intervalStart,
TotalBytes: int64(testrand.Intn(1000)),
ObjectCount: int64(testrand.Intn(1000)),
TotalSegmentCount: int64(testrand.Intn(1000)),
}
}
func randRollup(bucketName string, projectID uuid.UUID, intervalStart time.Time) orders.BucketBandwidthRollup {
return orders.BucketBandwidthRollup{
ProjectID: projectID,
BucketName: bucketName,
IntervalStart: intervalStart,
Action: pb.PieceAction_GET,
Inline: int64(testrand.Intn(1000)),
Settled: int64(testrand.Intn(1000)),
}
}