satellite/payments/stripecoinpayments: undo price override removal
Commit fb59974
disabled usage price overrides because of a failing
test. This change reenables it while resolving the issue that caused
the test to fail.
The previous version of the test passed Gerrit verification and was
merged, but it failed for the primary Jenkins pipeline after merge.
This is due to a difference in how the Jenkins build runs Cockroach
and Postgres for each pipeline.
This commit rewrites the test to be safe for concurrent execution by
ensuring any mutable variables are defined within each test so that
shared state across tests is reduced.
Change-Id: Ia4566c9cd2d698afdb2caa4b7e2808b17e18de4e
This commit is contained in:
parent
17ec326fd4
commit
cbbd5ab1ef
@ -236,6 +236,9 @@ type ProjectAccounting interface {
|
||||
GetProjectLimits(ctx context.Context, projectID uuid.UUID) (ProjectLimits, error)
|
||||
// GetProjectTotal returns project usage summary for specified period of time.
|
||||
GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (*ProjectUsage, error)
|
||||
// GetProjectTotalByPartner retrieves project usage for a given period categorized by partner name.
|
||||
// Unpartnered usage or usage for a partner not present in partnerNames is mapped to the empty string.
|
||||
GetProjectTotalByPartner(ctx context.Context, projectID uuid.UUID, partnerNames []string, since, before time.Time) (usages map[string]ProjectUsage, err error)
|
||||
// GetProjectObjectsSegments returns project objects and segments for specified period of time.
|
||||
GetProjectObjectsSegments(ctx context.Context, projectID uuid.UUID) (*ProjectObjectsSegments, error)
|
||||
// GetBucketUsageRollups returns usage rollup per each bucket for specified period of time.
|
||||
|
@ -3139,7 +3139,12 @@ func (payment Payments) Purchase(ctx context.Context, price int64, desc string,
|
||||
func (payment Payments) GetProjectUsagePriceModel(ctx context.Context) (_ *payments.ProjectUsagePriceModel, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
model := payment.service.accounts.GetProjectUsagePriceModel()
|
||||
user, err := GetUser(ctx)
|
||||
if err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
|
||||
model := payment.service.accounts.GetProjectUsagePriceModel(string(user.UserAgent))
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
|
@ -29,8 +29,8 @@ type Accounts interface {
|
||||
// ProjectCharges returns how much money current user will be charged for each project.
|
||||
ProjectCharges(ctx context.Context, userID uuid.UUID, since, before time.Time) ([]ProjectCharge, error)
|
||||
|
||||
// GetProjectUsagePriceModel returns the project usage price model.
|
||||
GetProjectUsagePriceModel() ProjectUsagePriceModel
|
||||
// GetProjectUsagePriceModel returns the project usage price model for a partner name.
|
||||
GetProjectUsagePriceModel(partner string) ProjectUsagePriceModel
|
||||
|
||||
// CheckProjectInvoicingStatus returns error if for the given project there are outstanding project records and/or usage
|
||||
// which have not been applied/invoiced yet (meaning sent over to stripe).
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
"storj.io/common/uuid"
|
||||
"storj.io/storj/satellite/accounting"
|
||||
"storj.io/storj/satellite/payments"
|
||||
)
|
||||
|
||||
@ -120,28 +121,47 @@ func (accounts *accounts) ProjectCharges(ctx context.Context, userID uuid.UUID,
|
||||
}
|
||||
|
||||
for _, project := range projects {
|
||||
usage, err := accounts.service.usageDB.GetProjectTotal(ctx, project.ID, since, before)
|
||||
totalUsage := accounting.ProjectUsage{Since: since, Before: before}
|
||||
|
||||
usages, err := accounts.service.usageDB.GetProjectTotalByPartner(ctx, project.ID, accounts.service.partnerNames, since, before)
|
||||
if err != nil {
|
||||
return charges, Error.Wrap(err)
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
|
||||
projectPrice := accounts.service.calculateProjectUsagePrice(usage.Egress, usage.Storage, usage.SegmentCount)
|
||||
var totalPrice projectUsagePrice
|
||||
|
||||
for partner, usage := range usages {
|
||||
priceModel := accounts.GetProjectUsagePriceModel(partner)
|
||||
price := accounts.service.calculateProjectUsagePrice(usage.Egress, usage.Storage, usage.SegmentCount, priceModel)
|
||||
|
||||
totalPrice.Egress = totalPrice.Egress.Add(price.Egress)
|
||||
totalPrice.Segments = totalPrice.Segments.Add(price.Segments)
|
||||
totalPrice.Storage = totalPrice.Storage.Add(price.Storage)
|
||||
|
||||
totalUsage.Egress += usage.Egress
|
||||
totalUsage.ObjectCount += usage.ObjectCount
|
||||
totalUsage.SegmentCount += usage.SegmentCount
|
||||
totalUsage.Storage += usage.Storage
|
||||
}
|
||||
|
||||
charges = append(charges, payments.ProjectCharge{
|
||||
ProjectUsage: *usage,
|
||||
ProjectUsage: totalUsage,
|
||||
|
||||
ProjectID: project.ID,
|
||||
Egress: projectPrice.Egress.IntPart(),
|
||||
SegmentCount: projectPrice.Segments.IntPart(),
|
||||
StorageGbHrs: projectPrice.Storage.IntPart(),
|
||||
Egress: totalPrice.Egress.IntPart(),
|
||||
SegmentCount: totalPrice.Segments.IntPart(),
|
||||
StorageGbHrs: totalPrice.Storage.IntPart(),
|
||||
})
|
||||
}
|
||||
|
||||
return charges, nil
|
||||
}
|
||||
|
||||
// GetProjectUsagePriceModel returns the project usage price model.
|
||||
func (accounts *accounts) GetProjectUsagePriceModel() payments.ProjectUsagePriceModel {
|
||||
// GetProjectUsagePriceModel returns the project usage price model for a partner name.
|
||||
func (accounts *accounts) GetProjectUsagePriceModel(partner string) payments.ProjectUsagePriceModel {
|
||||
if override, ok := accounts.service.usagePriceOverrides[partner]; ok {
|
||||
return override
|
||||
}
|
||||
return accounts.service.usagePrices
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -486,7 +487,12 @@ func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName
|
||||
return true, nil
|
||||
}
|
||||
|
||||
items := service.InvoiceItemsFromProjectRecord(projName, record)
|
||||
usages, err := service.usageDB.GetProjectTotalByPartner(ctx, record.ProjectID, service.partnerNames, record.PeriodStart, record.PeriodEnd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
items := service.InvoiceItemsFromProjectUsage(projName, usages)
|
||||
for _, item := range items {
|
||||
item.Currency = stripe.String(string(stripe.CurrencyUSD))
|
||||
item.Customer = stripe.String(cusID)
|
||||
@ -501,28 +507,50 @@ func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// InvoiceItemsFromProjectRecord calculates Stripe invoice item from project record.
|
||||
func (service *Service) InvoiceItemsFromProjectRecord(projName string, record ProjectRecord) (result []*stripe.InvoiceItemParams) {
|
||||
projectItem := &stripe.InvoiceItemParams{}
|
||||
projectItem.Description = stripe.String(fmt.Sprintf("Project %s - Segment Storage (MB-Month)", projName))
|
||||
projectItem.Quantity = stripe.Int64(storageMBMonthDecimal(record.Storage).IntPart())
|
||||
storagePrice, _ := service.usagePrices.StorageMBMonthCents.Float64()
|
||||
projectItem.UnitAmountDecimal = stripe.Float64(storagePrice)
|
||||
result = append(result, projectItem)
|
||||
// InvoiceItemsFromProjectUsage calculates Stripe invoice item from project usage.
|
||||
func (service *Service) InvoiceItemsFromProjectUsage(projName string, partnerUsages map[string]accounting.ProjectUsage) (result []*stripe.InvoiceItemParams) {
|
||||
var partners []string
|
||||
if len(partnerUsages) == 0 {
|
||||
partners = []string{""}
|
||||
partnerUsages = map[string]accounting.ProjectUsage{"": {}}
|
||||
} else {
|
||||
for partner := range partnerUsages {
|
||||
partners = append(partners, partner)
|
||||
}
|
||||
sort.Strings(partners)
|
||||
}
|
||||
|
||||
projectItem = &stripe.InvoiceItemParams{}
|
||||
projectItem.Description = stripe.String(fmt.Sprintf("Project %s - Egress Bandwidth (MB)", projName))
|
||||
projectItem.Quantity = stripe.Int64(egressMBDecimal(record.Egress).IntPart())
|
||||
egressPrice, _ := service.usagePrices.EgressMBCents.Float64()
|
||||
projectItem.UnitAmountDecimal = stripe.Float64(egressPrice)
|
||||
result = append(result, projectItem)
|
||||
for _, partner := range partners {
|
||||
usage := partnerUsages[partner]
|
||||
priceModel := service.Accounts().GetProjectUsagePriceModel(partner)
|
||||
|
||||
prefix := "Project " + projName
|
||||
if partner != "" {
|
||||
prefix += " (" + partner + ")"
|
||||
}
|
||||
|
||||
projectItem := &stripe.InvoiceItemParams{}
|
||||
projectItem.Description = stripe.String(prefix + " - Segment Storage (MB-Month)")
|
||||
projectItem.Quantity = stripe.Int64(storageMBMonthDecimal(usage.Storage).IntPart())
|
||||
storagePrice, _ := priceModel.StorageMBMonthCents.Float64()
|
||||
projectItem.UnitAmountDecimal = stripe.Float64(storagePrice)
|
||||
result = append(result, projectItem)
|
||||
|
||||
projectItem = &stripe.InvoiceItemParams{}
|
||||
projectItem.Description = stripe.String(prefix + " - Egress Bandwidth (MB)")
|
||||
projectItem.Quantity = stripe.Int64(egressMBDecimal(usage.Egress).IntPart())
|
||||
egressPrice, _ := priceModel.EgressMBCents.Float64()
|
||||
projectItem.UnitAmountDecimal = stripe.Float64(egressPrice)
|
||||
result = append(result, projectItem)
|
||||
|
||||
projectItem = &stripe.InvoiceItemParams{}
|
||||
projectItem.Description = stripe.String(prefix + " - Segment Fee (Segment-Month)")
|
||||
projectItem.Quantity = stripe.Int64(segmentMonthDecimal(usage.SegmentCount).IntPart())
|
||||
segmentPrice, _ := priceModel.SegmentMonthCents.Float64()
|
||||
projectItem.UnitAmountDecimal = stripe.Float64(segmentPrice)
|
||||
result = append(result, projectItem)
|
||||
}
|
||||
|
||||
projectItem = &stripe.InvoiceItemParams{}
|
||||
projectItem.Description = stripe.String(fmt.Sprintf("Project %s - Segment Fee (Segment-Month)", projName))
|
||||
projectItem.Quantity = stripe.Int64(segmentMonthDecimal(record.Segments).IntPart())
|
||||
segmentPrice, _ := service.usagePrices.SegmentMonthCents.Float64()
|
||||
projectItem.UnitAmountDecimal = stripe.Float64(segmentPrice)
|
||||
result = append(result, projectItem)
|
||||
service.log.Info("invoice items", zap.Any("result", result))
|
||||
|
||||
return result
|
||||
@ -780,11 +808,11 @@ func (price projectUsagePrice) TotalInt64() int64 {
|
||||
}
|
||||
|
||||
// calculateProjectUsagePrice calculate project usage price.
|
||||
func (service *Service) calculateProjectUsagePrice(egress int64, storage, segments float64) projectUsagePrice {
|
||||
func (service *Service) calculateProjectUsagePrice(egress int64, storage, segments float64, pricing payments.ProjectUsagePriceModel) projectUsagePrice {
|
||||
return projectUsagePrice{
|
||||
Storage: service.usagePrices.StorageMBMonthCents.Mul(storageMBMonthDecimal(storage)).Round(0),
|
||||
Egress: service.usagePrices.EgressMBCents.Mul(egressMBDecimal(egress)).Round(0),
|
||||
Segments: service.usagePrices.SegmentMonthCents.Mul(segmentMonthDecimal(segments)).Round(0),
|
||||
Storage: pricing.StorageMBMonthCents.Mul(storageMBMonthDecimal(storage)).Round(0),
|
||||
Egress: pricing.EgressMBCents.Mul(egressMBDecimal(egress)).Round(0),
|
||||
Segments: pricing.SegmentMonthCents.Mul(segmentMonthDecimal(segments)).Round(0),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,8 @@ package stripecoinpayments_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
@ -16,6 +18,7 @@ import (
|
||||
"storj.io/common/currency"
|
||||
"storj.io/common/memory"
|
||||
"storj.io/common/pb"
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/common/testrand"
|
||||
"storj.io/common/uuid"
|
||||
@ -25,7 +28,9 @@ import (
|
||||
"storj.io/storj/satellite/accounting"
|
||||
"storj.io/storj/satellite/console"
|
||||
"storj.io/storj/satellite/metabase"
|
||||
"storj.io/storj/satellite/payments"
|
||||
"storj.io/storj/satellite/payments/billing"
|
||||
"storj.io/storj/satellite/payments/paymentsconfig"
|
||||
"storj.io/storj/satellite/payments/stripecoinpayments"
|
||||
)
|
||||
|
||||
@ -222,67 +227,106 @@ func TestService_ProjectsWithMembers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestService_InvoiceItemsFromProjectRecord(t *testing.T) {
|
||||
func TestService_InvoiceItemsFromProjectUsage(t *testing.T) {
|
||||
const (
|
||||
projectName = "my-project"
|
||||
partnerName = "partner"
|
||||
noOverridePartnerName = "no-override"
|
||||
|
||||
hoursPerMonth = 24 * 30
|
||||
bytesPerMegabyte = int64(memory.MB / memory.B)
|
||||
byteHoursPerMBMonth = hoursPerMonth * bytesPerMegabyte
|
||||
)
|
||||
|
||||
var (
|
||||
defaultPrice = paymentsconfig.ProjectUsagePrice{
|
||||
StorageTB: "1",
|
||||
EgressTB: "2",
|
||||
Segment: "3",
|
||||
}
|
||||
partnerPrice = paymentsconfig.ProjectUsagePrice{
|
||||
StorageTB: "4",
|
||||
EgressTB: "5",
|
||||
Segment: "6",
|
||||
}
|
||||
)
|
||||
defaultModel, err := defaultPrice.ToModel()
|
||||
require.NoError(t, err)
|
||||
partnerModel, err := partnerPrice.ToModel()
|
||||
require.NoError(t, err)
|
||||
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
|
||||
Reconfigure: testplanet.Reconfigure{
|
||||
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
|
||||
config.Payments.UsagePrice = defaultPrice
|
||||
config.Payments.UsagePriceOverrides.SetMap(map[string]paymentsconfig.ProjectUsagePrice{
|
||||
partnerName: partnerPrice,
|
||||
})
|
||||
},
|
||||
},
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
satellite := planet.Satellites[0]
|
||||
|
||||
// these numbers are fraction of cents, not of dollars.
|
||||
expectedStoragePrice := 0.001
|
||||
expectedEgressPrice := 0.0045
|
||||
expectedSegmentPrice := 0.00022
|
||||
|
||||
type TestCase struct {
|
||||
Storage float64
|
||||
Egress int64
|
||||
Segments float64
|
||||
|
||||
StorageQuantity int64
|
||||
EgressQuantity int64
|
||||
SegmentsQuantity int64
|
||||
}
|
||||
|
||||
testCases := []TestCase{
|
||||
{}, // all zeros
|
||||
{
|
||||
Storage: 10000000000, // Byte-Hours
|
||||
// storage quantity is calculated to Megabyte-Months
|
||||
// (10000000000 / 1000000) Byte-Hours to Megabytes-Hours
|
||||
// round(10000 / 720) Megabytes-Hours to Megabyte-Months, 720 - hours in month
|
||||
StorageQuantity: 14, // Megabyte-Months
|
||||
usage := map[string]accounting.ProjectUsage{
|
||||
"": {
|
||||
Storage: 10000000000, // Byte-hours
|
||||
Egress: 123 * memory.GB.Int64(), // Bytes
|
||||
SegmentCount: 200000, // Segment-Hours
|
||||
},
|
||||
{
|
||||
Egress: 134 * memory.GB.Int64(), // Bytes
|
||||
// egress quantity is calculated to Megabytes
|
||||
// (134000000000 / 1000000) Bytes to Megabytes
|
||||
EgressQuantity: 134000, // Megabytes
|
||||
partnerName: {
|
||||
Storage: 20000000000,
|
||||
Egress: 456 * memory.GB.Int64(),
|
||||
SegmentCount: 400000,
|
||||
},
|
||||
{
|
||||
Segments: 400000, // Segment-Hours
|
||||
// object quantity is calculated to Segment-Months
|
||||
// round(400000 / 720) Segment-Hours to Segment-Months, 720 - hours in month
|
||||
SegmentsQuantity: 556, // Segment-Months
|
||||
noOverridePartnerName: {
|
||||
Storage: 30000000000,
|
||||
Egress: 789 * memory.GB.Int64(),
|
||||
SegmentCount: 600000,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
record := stripecoinpayments.ProjectRecord{
|
||||
Storage: tc.Storage,
|
||||
Egress: tc.Egress,
|
||||
Segments: tc.Segments,
|
||||
}
|
||||
items := planet.Satellites[0].API.Payments.StripeService.InvoiceItemsFromProjectUsage(projectName, usage)
|
||||
require.Len(t, items, len(usage)*3)
|
||||
|
||||
items := satellite.API.Payments.StripeService.InvoiceItemsFromProjectRecord("project name", record)
|
||||
for i, tt := range []struct {
|
||||
name string
|
||||
partner string
|
||||
priceModel payments.ProjectUsagePriceModel
|
||||
}{
|
||||
{"default pricing - no partner", "", defaultModel},
|
||||
{"default pricing - no override for partner", noOverridePartnerName, defaultModel},
|
||||
{"partner pricing", partnerName, partnerModel},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prefix := "Project " + projectName
|
||||
if tt.partner != "" {
|
||||
prefix += " (" + tt.partner + ")"
|
||||
}
|
||||
|
||||
require.Equal(t, tc.StorageQuantity, *items[0].Quantity)
|
||||
require.Equal(t, expectedStoragePrice, *items[0].UnitAmountDecimal)
|
||||
usage := usage[tt.partner]
|
||||
expectedStorageQuantity := int64(math.Round(usage.Storage / float64(byteHoursPerMBMonth)))
|
||||
expectedEgressQuantity := int64(math.Round(float64(usage.Egress) / float64(bytesPerMegabyte)))
|
||||
expectedSegmentQuantity := int64(math.Round(usage.SegmentCount / hoursPerMonth))
|
||||
|
||||
require.Equal(t, tc.EgressQuantity, *items[1].Quantity)
|
||||
require.Equal(t, expectedEgressPrice, *items[1].UnitAmountDecimal)
|
||||
items := items[i*3 : (i*3)+3]
|
||||
for _, item := range items {
|
||||
require.NotNil(t, item)
|
||||
}
|
||||
|
||||
require.Equal(t, tc.SegmentsQuantity, *items[2].Quantity)
|
||||
require.Equal(t, expectedSegmentPrice, *items[2].UnitAmountDecimal)
|
||||
require.Equal(t, prefix+" - Segment Storage (MB-Month)", *items[0].Description)
|
||||
require.Equal(t, expectedStorageQuantity, *items[0].Quantity)
|
||||
storage, _ := tt.priceModel.StorageMBMonthCents.Float64()
|
||||
require.Equal(t, storage, *items[0].UnitAmountDecimal)
|
||||
|
||||
require.Equal(t, prefix+" - Egress Bandwidth (MB)", *items[1].Description)
|
||||
require.Equal(t, expectedEgressQuantity, *items[1].Quantity)
|
||||
egress, _ := tt.priceModel.EgressMBCents.Float64()
|
||||
require.Equal(t, egress, *items[1].UnitAmountDecimal)
|
||||
|
||||
require.Equal(t, prefix+" - Segment Fee (Segment-Month)", *items[2].Description)
|
||||
require.Equal(t, expectedSegmentQuantity, *items[2].Quantity)
|
||||
segment, _ := tt.priceModel.SegmentMonthCents.Float64()
|
||||
require.Equal(t, segment, *items[2].UnitAmountDecimal)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -491,6 +535,100 @@ func generateProjectStorage(ctx context.Context, tb testing.TB, db satellite.DB,
|
||||
require.NoError(tb, err)
|
||||
}
|
||||
|
||||
func TestProjectUsagePrice(t *testing.T) {
|
||||
var (
|
||||
defaultPrice = paymentsconfig.ProjectUsagePrice{
|
||||
StorageTB: "1",
|
||||
EgressTB: "2",
|
||||
Segment: "3",
|
||||
}
|
||||
partnerName = "partner"
|
||||
partnerPrice = paymentsconfig.ProjectUsagePrice{
|
||||
StorageTB: "4",
|
||||
EgressTB: "5",
|
||||
Segment: "6",
|
||||
}
|
||||
)
|
||||
defaultModel, err := defaultPrice.ToModel()
|
||||
require.NoError(t, err)
|
||||
partnerModel, err := partnerPrice.ToModel()
|
||||
require.NoError(t, err)
|
||||
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
|
||||
Reconfigure: testplanet.Reconfigure{
|
||||
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
|
||||
config.Payments.UsagePrice = defaultPrice
|
||||
config.Payments.UsagePriceOverrides.SetMap(map[string]paymentsconfig.ProjectUsagePrice{
|
||||
partnerName: partnerPrice,
|
||||
})
|
||||
},
|
||||
},
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
sat := planet.Satellites[0]
|
||||
|
||||
// pick a specific date so that it doesn't fail if it's the last day of the month
|
||||
// keep month + 1 because user needs to be created before calculation
|
||||
period := time.Date(time.Now().Year(), time.Now().Month()+1, 20, 0, 0, 0, 0, time.UTC)
|
||||
sat.API.Payments.StripeService.SetNow(func() time.Time {
|
||||
return time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC)
|
||||
})
|
||||
|
||||
for i, tt := range []struct {
|
||||
name string
|
||||
userAgent []byte
|
||||
expectedPrice payments.ProjectUsagePriceModel
|
||||
}{
|
||||
{"default pricing", nil, defaultModel},
|
||||
{"default pricing - user agent is not valid partner name", []byte("invalid/v0.0"), defaultModel},
|
||||
{"partner pricing - user agent is partner name", []byte(partnerName), partnerModel},
|
||||
{"partner pricing - user agent prefixed with partner name", []byte(partnerName + " invalid/v0.0"), partnerModel},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user, err := sat.AddUser(ctx, console.CreateUser{
|
||||
FullName: "Test User",
|
||||
Email: fmt.Sprintf("user%d@mail.test", i),
|
||||
UserAgent: tt.userAgent,
|
||||
}, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
project, err := sat.AddProject(ctx, user.ID, "testproject")
|
||||
require.NoError(t, err)
|
||||
|
||||
bucket, err := sat.DB.Buckets().CreateBucket(ctx, storj.Bucket{
|
||||
ID: testrand.UUID(),
|
||||
Name: testrand.BucketName(),
|
||||
ProjectID: project.ID,
|
||||
UserAgent: tt.userAgent,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sat.DB.Orders().UpdateBucketBandwidthSettle(ctx, project.ID, []byte(bucket.Name),
|
||||
pb.PieceAction_GET, memory.TB.Int64(), 0, period)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sat.API.Payments.StripeService.PrepareInvoiceProjectRecords(ctx, period)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sat.API.Payments.StripeService.InvoiceApplyProjectRecords(ctx, period)
|
||||
require.NoError(t, err)
|
||||
|
||||
cusID, err := sat.DB.StripeCoinPayments().Customers().GetCustomerID(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
items := getCustomerInvoiceItems(sat.API.Payments.StripeClient, cusID)
|
||||
require.Len(t, items, 3)
|
||||
storage, _ := tt.expectedPrice.StorageMBMonthCents.Float64()
|
||||
require.Equal(t, storage, items[0].UnitAmountDecimal)
|
||||
egress, _ := tt.expectedPrice.EgressMBCents.Float64()
|
||||
require.Equal(t, egress, items[1].UnitAmountDecimal)
|
||||
segment, _ := tt.expectedPrice.SegmentMonthCents.Float64()
|
||||
require.Equal(t, segment, items[2].UnitAmountDecimal)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPayInvoicesSkipDue(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
|
||||
"storj.io/common/memory"
|
||||
"storj.io/common/pb"
|
||||
"storj.io/common/useragent"
|
||||
"storj.io/common/uuid"
|
||||
"storj.io/private/dbutil"
|
||||
"storj.io/private/dbutil/pgutil"
|
||||
@ -505,7 +506,21 @@ func (db *ProjectAccounting) GetProjectSegmentLimit(ctx context.Context, project
|
||||
}
|
||||
|
||||
// GetProjectTotal retrieves project usage for a given period.
|
||||
func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (usage *accounting.ProjectUsage, err error) {
|
||||
func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (_ *accounting.ProjectUsage, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
usages, err := db.GetProjectTotalByPartner(ctx, projectID, nil, since, before)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if usage, ok := usages[""]; ok {
|
||||
return &usage, nil
|
||||
}
|
||||
return &accounting.ProjectUsage{Since: since, Before: before}, nil
|
||||
}
|
||||
|
||||
// GetProjectTotalByPartner retrieves project usage for a given period categorized by partner name.
|
||||
// Unpartnered usage or usage for a partner not present in partnerNames is mapped to the empty string.
|
||||
func (db *ProjectAccounting) GetProjectTotalByPartner(ctx context.Context, projectID uuid.UUID, partnerNames []string, since, before time.Time) (usages map[string]accounting.ProjectUsage, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
since = timeTruncateDown(since)
|
||||
bucketNames, err := db.getBucketsSinceAndBefore(ctx, projectID, since, before)
|
||||
@ -531,16 +546,54 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid
|
||||
ORDER BY bucket_storage_tallies.interval_start DESC
|
||||
`)
|
||||
|
||||
bucketsTallies := make(map[string][]*accounting.BucketStorageTally)
|
||||
totalEgressQuery := db.db.Rebind(`
|
||||
SELECT
|
||||
COALESCE(SUM(settled) + SUM(inline), 0)
|
||||
FROM
|
||||
bucket_bandwidth_rollups
|
||||
WHERE
|
||||
bucket_name = ? AND
|
||||
project_id = ? AND
|
||||
interval_start >= ? AND
|
||||
interval_start < ? AND
|
||||
action = ?;
|
||||
`)
|
||||
|
||||
usages = make(map[string]accounting.ProjectUsage)
|
||||
|
||||
for _, bucket := range bucketNames {
|
||||
storageTallies := make([]*accounting.BucketStorageTally, 0)
|
||||
userAgentRow, err := db.db.Get_BucketMetainfo_UserAgent_By_ProjectId_And_Name(ctx,
|
||||
dbx.BucketMetainfo_ProjectId(projectID[:]),
|
||||
dbx.BucketMetainfo_Name([]byte(bucket)))
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var partner string
|
||||
if userAgentRow != nil && userAgentRow.UserAgent != nil {
|
||||
entries, err := useragent.ParseEntries(userAgentRow.UserAgent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, iterPartner := range partnerNames {
|
||||
if entries[0].Product == iterPartner {
|
||||
partner = iterPartner
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, ok := usages[partner]; !ok {
|
||||
usages[partner] = accounting.ProjectUsage{Since: since, Before: before}
|
||||
}
|
||||
usage := usages[partner]
|
||||
|
||||
storageTalliesRows, err := db.db.QueryContext(ctx, storageQuery, projectID[:], []byte(bucket), since, before)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// generating tallies for each bucket name.
|
||||
|
||||
var prevTally *accounting.BucketStorageTally
|
||||
for storageTalliesRows.Next() {
|
||||
tally := accounting.BucketStorageTally{}
|
||||
|
||||
@ -553,8 +606,17 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid
|
||||
tally.TotalBytes = inline + remote
|
||||
}
|
||||
|
||||
tally.BucketName = bucket
|
||||
storageTallies = append(storageTallies, &tally)
|
||||
if prevTally == nil {
|
||||
prevTally = &tally
|
||||
continue
|
||||
}
|
||||
|
||||
hours := prevTally.IntervalStart.Sub(tally.IntervalStart).Hours()
|
||||
usage.Storage += memory.Size(tally.TotalBytes).Float64() * hours
|
||||
usage.SegmentCount += float64(tally.TotalSegmentCount) * hours
|
||||
usage.ObjectCount += float64(tally.ObjectCount) * hours
|
||||
|
||||
prevTally = &tally
|
||||
}
|
||||
|
||||
err = errs.Combine(storageTalliesRows.Err(), storageTalliesRows.Close())
|
||||
@ -562,53 +624,21 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bucketsTallies[bucket] = storageTallies
|
||||
}
|
||||
|
||||
totalEgress, err := db.getTotalEgress(ctx, projectID, since, before)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usage = new(accounting.ProjectUsage)
|
||||
usage.Egress = memory.Size(totalEgress).Int64()
|
||||
// sum up storage, objects, and segments
|
||||
for _, tallies := range bucketsTallies {
|
||||
for i := len(tallies) - 1; i > 0; i-- {
|
||||
current := (tallies)[i]
|
||||
hours := (tallies)[i-1].IntervalStart.Sub(current.IntervalStart).Hours()
|
||||
usage.Storage += memory.Size(current.Bytes()).Float64() * hours
|
||||
usage.SegmentCount += float64(current.TotalSegmentCount) * hours
|
||||
usage.ObjectCount += float64(current.ObjectCount) * hours
|
||||
totalEgressRow := db.db.QueryRowContext(ctx, totalEgressQuery, []byte(bucket), projectID[:], since, before, pb.PieceAction_GET)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var egress int64
|
||||
if err = totalEgressRow.Scan(&egress); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage.Egress += egress
|
||||
|
||||
usages[partner] = usage
|
||||
}
|
||||
|
||||
usage.Since = since
|
||||
usage.Before = before
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// getTotalEgress returns total egress (settled + inline) of each bucket_bandwidth_rollup
|
||||
// in selected time period, project id.
|
||||
// only process PieceAction_GET.
|
||||
func (db *ProjectAccounting) getTotalEgress(ctx context.Context, projectID uuid.UUID, since, before time.Time) (totalEgress int64, err error) {
|
||||
totalEgressQuery := db.db.Rebind(`
|
||||
SELECT
|
||||
COALESCE(SUM(settled) + SUM(inline), 0)
|
||||
FROM
|
||||
bucket_bandwidth_rollups
|
||||
WHERE
|
||||
project_id = ? AND
|
||||
interval_start >= ? AND
|
||||
interval_start < ? AND
|
||||
action = ?;
|
||||
`)
|
||||
|
||||
totalEgressRow := db.db.QueryRowContext(ctx, totalEgressQuery, projectID[:], since, before, pb.PieceAction_GET)
|
||||
|
||||
err = totalEgressRow.Scan(&totalEgress)
|
||||
|
||||
return totalEgress, err
|
||||
return usages, nil
|
||||
}
|
||||
|
||||
// GetBucketUsageRollups retrieves summed usage rollups for every bucket of particular project for a given period.
|
||||
|
@ -11,8 +11,10 @@ import (
|
||||
|
||||
"storj.io/common/memory"
|
||||
"storj.io/common/pb"
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/common/testrand"
|
||||
"storj.io/common/uuid"
|
||||
"storj.io/storj/private/testplanet"
|
||||
"storj.io/storj/satellite/accounting"
|
||||
"storj.io/storj/satellite/console"
|
||||
@ -185,14 +187,7 @@ func Test_GetProjectTotal(t *testing.T) {
|
||||
// The 3rd tally is only present to prevent CreateStorageTally from skipping the 2nd.
|
||||
var tallies []accounting.BucketStorageTally
|
||||
for i := 0; i < 3; i++ {
|
||||
tally := accounting.BucketStorageTally{
|
||||
BucketName: bucketName,
|
||||
ProjectID: projectID,
|
||||
IntervalStart: time.Time{}.Add(time.Duration(i) * time.Hour),
|
||||
TotalBytes: int64(testrand.Intn(1000)),
|
||||
ObjectCount: int64(testrand.Intn(1000)),
|
||||
TotalSegmentCount: int64(testrand.Intn(1000)),
|
||||
}
|
||||
tally := randTally(bucketName, projectID, time.Time{}.Add(time.Duration(i)*time.Hour))
|
||||
tallies = append(tallies, tally)
|
||||
require.NoError(t, db.ProjectAccounting().CreateStorageTally(ctx, tally))
|
||||
}
|
||||
@ -200,14 +195,7 @@ func Test_GetProjectTotal(t *testing.T) {
|
||||
var rollups []orders.BucketBandwidthRollup
|
||||
var expectedEgress int64
|
||||
for i := 0; i < 2; i++ {
|
||||
rollup := orders.BucketBandwidthRollup{
|
||||
ProjectID: projectID,
|
||||
BucketName: bucketName,
|
||||
Action: pb.PieceAction_GET,
|
||||
IntervalStart: tallies[i].IntervalStart,
|
||||
Inline: int64(testrand.Intn(1000)),
|
||||
Settled: int64(testrand.Intn(1000)),
|
||||
}
|
||||
rollup := randRollup(bucketName, projectID, tallies[i].IntervalStart)
|
||||
rollups = append(rollups, rollup)
|
||||
expectedEgress += rollup.Inline + rollup.Settled
|
||||
}
|
||||
@ -245,3 +233,182 @@ func Test_GetProjectTotal(t *testing.T) {
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func Test_GetProjectTotalByPartner(t *testing.T) {
|
||||
const (
|
||||
epsilon = 1e-8
|
||||
usagePeriod = time.Hour
|
||||
tallyRollupCount = 2
|
||||
)
|
||||
since := time.Time{}
|
||||
before := since.Add(2 * usagePeriod)
|
||||
|
||||
testplanet.Run(t, testplanet.Config{SatelliteCount: 1, StorageNodeCount: 1},
|
||||
func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
sat := planet.Satellites[0]
|
||||
|
||||
user, err := sat.AddUser(ctx, console.CreateUser{
|
||||
FullName: "Test User",
|
||||
Email: "user@mail.test",
|
||||
}, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
project, err := sat.AddProject(ctx, user.ID, "testproject")
|
||||
require.NoError(t, err)
|
||||
|
||||
type expectedTotal struct {
|
||||
storage float64
|
||||
segments float64
|
||||
objects float64
|
||||
egress int64
|
||||
}
|
||||
expectedTotals := make(map[string]expectedTotal)
|
||||
var beforeTotal expectedTotal
|
||||
|
||||
requireTotal := func(t *testing.T, expected expectedTotal, expectedSince, expectedBefore time.Time, actual accounting.ProjectUsage) {
|
||||
require.InDelta(t, expected.storage, actual.Storage, epsilon)
|
||||
require.InDelta(t, expected.segments, actual.SegmentCount, epsilon)
|
||||
require.InDelta(t, expected.objects, actual.ObjectCount, epsilon)
|
||||
require.Equal(t, expected.egress, actual.Egress)
|
||||
require.Equal(t, expectedSince, actual.Since)
|
||||
require.Equal(t, expectedBefore, actual.Before)
|
||||
}
|
||||
|
||||
partnerNames := []string{"", "partner1", "partner2"}
|
||||
for _, name := range partnerNames {
|
||||
total := expectedTotal{}
|
||||
|
||||
bucket := storj.Bucket{
|
||||
ID: testrand.UUID(),
|
||||
Name: testrand.BucketName(),
|
||||
ProjectID: project.ID,
|
||||
}
|
||||
if name != "" {
|
||||
bucket.UserAgent = []byte(name)
|
||||
}
|
||||
_, err := sat.DB.Buckets().CreateBucket(ctx, bucket)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We use multiple tallies and rollups to ensure that
|
||||
// GetProjectTotalByPartner is capable of summing them.
|
||||
for i := 0; i <= tallyRollupCount; i++ {
|
||||
tally := randTally(bucket.Name, project.ID, since.Add(time.Duration(i)*usagePeriod/tallyRollupCount))
|
||||
require.NoError(t, sat.DB.ProjectAccounting().CreateStorageTally(ctx, tally))
|
||||
|
||||
// The last tally's usage data is unused.
|
||||
usageHours := (usagePeriod / tallyRollupCount).Hours()
|
||||
if i < tallyRollupCount {
|
||||
total.storage += float64(tally.Bytes()) * usageHours
|
||||
total.objects += float64(tally.ObjectCount) * usageHours
|
||||
total.segments += float64(tally.TotalSegmentCount) * usageHours
|
||||
}
|
||||
|
||||
if i < tallyRollupCount-1 {
|
||||
beforeTotal.storage += float64(tally.Bytes()) * usageHours
|
||||
beforeTotal.objects += float64(tally.ObjectCount) * usageHours
|
||||
beforeTotal.segments += float64(tally.TotalSegmentCount) * usageHours
|
||||
}
|
||||
}
|
||||
|
||||
var rollups []orders.BucketBandwidthRollup
|
||||
for i := 0; i < tallyRollupCount; i++ {
|
||||
rollup := randRollup(bucket.Name, project.ID, since.Add(time.Duration(i)*usagePeriod/tallyRollupCount))
|
||||
rollups = append(rollups, rollup)
|
||||
total.egress += rollup.Inline + rollup.Settled
|
||||
|
||||
if i < tallyRollupCount {
|
||||
beforeTotal.egress += rollup.Inline + rollup.Settled
|
||||
}
|
||||
}
|
||||
require.NoError(t, sat.DB.Orders().UpdateBandwidthBatch(ctx, rollups))
|
||||
|
||||
expectedTotals[name] = total
|
||||
}
|
||||
|
||||
t.Run("sum all partner usages", func(t *testing.T) {
|
||||
ctx := testcontext.New(t)
|
||||
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, nil, since, before)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, usages, 1)
|
||||
require.Contains(t, usages, "")
|
||||
|
||||
var summedTotal expectedTotal
|
||||
for _, total := range expectedTotals {
|
||||
summedTotal.storage += total.storage
|
||||
summedTotal.segments += total.segments
|
||||
summedTotal.objects += total.objects
|
||||
summedTotal.egress += total.egress
|
||||
}
|
||||
|
||||
requireTotal(t, summedTotal, since, before, usages[""])
|
||||
})
|
||||
|
||||
t.Run("individual partner usages", func(t *testing.T) {
|
||||
ctx := testcontext.New(t)
|
||||
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, partnerNames, since, before)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, usages, len(expectedTotals))
|
||||
for _, name := range partnerNames {
|
||||
require.Contains(t, usages, name)
|
||||
}
|
||||
|
||||
for partner, usage := range usages {
|
||||
requireTotal(t, expectedTotals[partner], since, before, usage)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("select one partner usage and sum remaining usages", func(t *testing.T) {
|
||||
ctx := testcontext.New(t)
|
||||
partner := partnerNames[len(partnerNames)-1]
|
||||
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, []string{partner}, since, before)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, usages, 2)
|
||||
require.Contains(t, usages, "")
|
||||
require.Contains(t, usages, partner)
|
||||
|
||||
var summedTotal expectedTotal
|
||||
for _, partner := range partnerNames[:len(partnerNames)-1] {
|
||||
summedTotal.storage += expectedTotals[partner].storage
|
||||
summedTotal.segments += expectedTotals[partner].segments
|
||||
summedTotal.objects += expectedTotals[partner].objects
|
||||
summedTotal.egress += expectedTotals[partner].egress
|
||||
}
|
||||
|
||||
requireTotal(t, expectedTotals[partner], since, before, usages[partner])
|
||||
requireTotal(t, summedTotal, since, before, usages[""])
|
||||
})
|
||||
|
||||
t.Run("ensure the 'before' arg is exclusive", func(t *testing.T) {
|
||||
ctx := testcontext.New(t)
|
||||
before := since.Add(usagePeriod)
|
||||
usages, err := sat.DB.ProjectAccounting().GetProjectTotalByPartner(ctx, project.ID, nil, since, before)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, usages, 1)
|
||||
require.Contains(t, usages, "")
|
||||
requireTotal(t, beforeTotal, since, before, usages[""])
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func randTally(bucketName string, projectID uuid.UUID, intervalStart time.Time) accounting.BucketStorageTally {
|
||||
return accounting.BucketStorageTally{
|
||||
BucketName: bucketName,
|
||||
ProjectID: projectID,
|
||||
IntervalStart: intervalStart,
|
||||
TotalBytes: int64(testrand.Intn(1000)),
|
||||
ObjectCount: int64(testrand.Intn(1000)),
|
||||
TotalSegmentCount: int64(testrand.Intn(1000)),
|
||||
}
|
||||
}
|
||||
|
||||
func randRollup(bucketName string, projectID uuid.UUID, intervalStart time.Time) orders.BucketBandwidthRollup {
|
||||
return orders.BucketBandwidthRollup{
|
||||
ProjectID: projectID,
|
||||
BucketName: bucketName,
|
||||
IntervalStart: intervalStart,
|
||||
Action: pb.PieceAction_GET,
|
||||
Inline: int64(testrand.Intn(1000)),
|
||||
Settled: int64(testrand.Intn(1000)),
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user