satellite/{payments, db}: aggregate invoice items if many projects

Implemented invoice items aggregation if projects count is more than 83 for a single invoice.

Change-Id: I6bce81e537eaaddd9297a85718b594047436964a
This commit is contained in:
Vitalii 2023-11-21 15:56:22 +02:00 committed by Storj Robot
parent f09f352628
commit 6b1c62d7b2
4 changed files with 274 additions and 14 deletions

View File

@ -19,6 +19,8 @@ var ErrProjectRecordExists = Error.New("invoice project record already exists")
type ProjectRecordsDB interface {
// Create creates new invoice project record with credits spendings in the DB.
Create(ctx context.Context, records []CreateProjectRecord, start, end time.Time) error
// CreateToBeAggregated creates new to be aggregated invoice project record with credits spendings in the DB.
CreateToBeAggregated(ctx context.Context, records []CreateProjectRecord, start, end time.Time) error
// Check checks if invoice project record for specified project and billing period exists.
Check(ctx context.Context, projectID uuid.UUID, start, end time.Time) error
// Get returns record for specified project and billing period.
@ -28,6 +30,9 @@ type ProjectRecordsDB interface {
// ListUnapplied returns project records page with unapplied project records.
// Cursor is not included into listing results.
ListUnapplied(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (ProjectRecordsPage, error)
// ListToBeAggregated returns to be aggregated project records page with unapplied project records.
// Cursor is not included into listing results.
ListToBeAggregated(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (ProjectRecordsPage, error)
}
// CreateProjectRecord holds info needed for creation new invoice

View File

@ -39,8 +39,14 @@ var (
mon = monkit.Package()
)
// hoursPerMonth is the number of months in a billing month. For the purpose of billing, the billing month is always 30 days.
const hoursPerMonth = 24 * 30
const (
// hoursPerMonth is the number of months in a billing month. For the purpose of billing, the billing month is always 30 days.
hoursPerMonth = 24 * 30
storageInvoiceItemDesc = " - Segment Storage (MB-Month)"
egressInvoiceItemDesc = " - Egress Bandwidth (MB)"
segmentInvoiceItemDesc = " - Segment Fee (Segment-Month)"
)
// Config stores needed information for payment service initialization.
type Config struct {
@ -170,7 +176,8 @@ func (service *Service) PrepareInvoiceProjectRecords(ctx context.Context, period
}
func (service *Service) processCustomers(ctx context.Context, customers []Customer, start, end time.Time) (int, error) {
var allRecords []CreateProjectRecord
var regularRecords []CreateProjectRecord
var recordsToAggregate []CreateProjectRecord
for _, customer := range customers {
if inactive, err := service.isUserInactive(ctx, customer.UserID); err != nil {
return 0, err
@ -188,10 +195,26 @@ func (service *Service) processCustomers(ctx context.Context, customers []Custom
return 0, err
}
allRecords = append(allRecords, records...)
// We generate 3 invoice items for each user project which means,
// we can support only 83 projects in a single invoice (249 invoice items).
if len(projects) > 83 {
recordsToAggregate = append(recordsToAggregate, records...)
} else {
regularRecords = append(regularRecords, records...)
}
}
return len(allRecords), service.db.ProjectRecords().Create(ctx, allRecords, start, end)
err := service.db.ProjectRecords().Create(ctx, regularRecords, start, end)
if err != nil {
return 0, err
}
err = service.db.ProjectRecords().CreateToBeAggregated(ctx, recordsToAggregate, start, end)
if err != nil {
return 0, err
}
return len(recordsToAggregate) + len(regularRecords), nil
}
// createProjectRecords creates invoice project record if none exists.
@ -273,7 +296,54 @@ func (service *Service) InvoiceApplyProjectRecords(ctx context.Context, period t
}
}
service.log.Info("Processed project records.",
service.log.Info("Processed regular project records.",
zap.Int("Total", totalRecords),
zap.Int("Skipped", totalSkipped))
return nil
}
// InvoiceApplyToBeAggregatedProjectRecords iterates through to be aggregated invoice project records and creates invoice line items
// for stripe customer.
func (service *Service) InvoiceApplyToBeAggregatedProjectRecords(ctx context.Context, period time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
now := service.nowFn().UTC()
utc := period.UTC()
start := time.Date(utc.Year(), utc.Month(), 1, 0, 0, 0, 0, time.UTC)
end := time.Date(utc.Year(), utc.Month()+1, 1, 0, 0, 0, 0, time.UTC)
if end.After(now) {
return Error.New("allowed for past periods only")
}
var totalRecords int
var totalSkipped int
for {
if err = ctx.Err(); err != nil {
return Error.Wrap(err)
}
// we are always starting from offset 0 because applyProjectRecords is changing project record state to applied
recordsPage, err := service.db.ProjectRecords().ListToBeAggregated(ctx, uuid.UUID{}, service.listingLimit, start, end)
if err != nil {
return Error.Wrap(err)
}
totalRecords += len(recordsPage.Records)
skipped, err := service.applyToBeAggregatedProjectRecords(ctx, recordsPage.Records)
if err != nil {
return Error.Wrap(err)
}
totalSkipped += skipped
if !recordsPage.Next {
break
}
}
service.log.Info("Processed aggregated project records.",
zap.Int("Total", totalRecords),
zap.Int("Skipped", totalSkipped))
return nil
@ -478,6 +548,51 @@ func (service *Service) applyProjectRecords(ctx context.Context, records []Proje
return skipCount, errGrp.Err()
}
// applyToBeAggregatedProjectRecords applies to be aggregated invoice intents as invoice line items to stripe customer.
func (service *Service) applyToBeAggregatedProjectRecords(ctx context.Context, records []ProjectRecord) (skipCount int, err error) {
defer mon.Task()(&ctx)(&err)
for _, record := range records {
if err = ctx.Err(); err != nil {
return 0, errs.Wrap(err)
}
proj, err := service.projectsDB.Get(ctx, record.ProjectID)
if err != nil {
service.log.Error("project ID for corresponding project record not found", zap.Stringer("Record ID", record.ID), zap.Stringer("Project ID", record.ProjectID))
return 0, errs.Wrap(err)
}
if inactive, err := service.isUserInactive(ctx, proj.OwnerID); err != nil {
return 0, errs.Wrap(err)
} else if inactive {
skipCount++
continue
}
cusID, err := service.db.Customers().GetCustomerID(ctx, proj.OwnerID)
if err != nil {
if errors.Is(err, ErrNoCustomer) {
service.log.Warn("Stripe customer does not exist for project owner.", zap.Stringer("Owner ID", proj.OwnerID), zap.Stringer("Project ID", proj.ID))
continue
}
return 0, errs.Wrap(err)
}
record := record
skipped, err := service.processProjectRecord(ctx, cusID, proj.Name, record)
if err != nil {
return 0, errs.Wrap(err)
}
if skipped {
skipCount++
}
}
return skipCount, nil
}
// createInvoiceItems creates invoice line items for stripe customer.
func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName string, record ProjectRecord) (skipped bool, err error) {
defer mon.Task()(&ctx)(&err)
@ -495,7 +610,7 @@ func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName
return false, err
}
items := service.InvoiceItemsFromProjectUsage(projName, usages)
items := service.InvoiceItemsFromProjectUsage(projName, usages, false)
for _, item := range items {
item.Params = stripe.Params{Context: ctx}
item.Currency = stripe.String(string(stripe.CurrencyUSD))
@ -511,8 +626,118 @@ func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName
return false, nil
}
type usage int
const (
storage usage = 0
egress usage = 1
segment usage = 2
)
// processProjectRecord creates or updates invoice line items for stripe customer.
func (service *Service) processProjectRecord(ctx context.Context, cusID, projName string, record ProjectRecord) (skipped bool, err error) {
defer mon.Task()(&ctx)(&err)
if err = service.db.ProjectRecords().Consume(ctx, record.ID); err != nil {
return false, err
}
if service.skipEmptyInvoices && doesProjectRecordHaveNoUsage(record) {
return true, nil
}
usages, err := service.usageDB.GetProjectTotalByPartner(ctx, record.ProjectID, service.partnerNames, record.PeriodStart, record.PeriodEnd)
if err != nil {
return false, err
}
newItems := service.InvoiceItemsFromProjectUsage(projName, usages, true)
existingItems, err := service.getExistingInvoiceItems(ctx, cusID)
if err != nil {
return false, err
}
if existingItems[segment] == nil || existingItems[storage] == nil || existingItems[egress] == nil {
for _, item := range newItems {
item.Params = stripe.Params{Context: ctx}
item.Currency = stripe.String(string(stripe.CurrencyUSD))
item.Customer = stripe.String(cusID)
item.AddMetadata("projectID", record.ProjectID.String())
_, err = service.stripeClient.InvoiceItems().New(item)
if err != nil {
return false, err
}
}
} else {
err = service.updateExistingInvoiceItems(ctx, existingItems, newItems)
if err != nil {
return false, err
}
}
return false, nil
}
// getExistingInvoiceItems lists 3 existing pending invoice line items for stripe customer.
func (service *Service) getExistingInvoiceItems(ctx context.Context, cusID string) (map[usage]*stripe.InvoiceItem, error) {
existingItemsIter := service.stripeClient.InvoiceItems().List(&stripe.InvoiceItemListParams{
Customer: &cusID,
Pending: stripe.Bool(true),
ListParams: stripe.ListParams{
Context: ctx,
Limit: stripe.Int64(3),
},
})
items := map[usage]*stripe.InvoiceItem{
storage: nil,
egress: nil,
segment: nil,
}
for existingItemsIter.Next() {
item := existingItemsIter.InvoiceItem()
if strings.Contains(item.Description, storageInvoiceItemDesc) {
items[storage] = item
} else if strings.Contains(item.Description, egressInvoiceItemDesc) {
items[egress] = item
} else if strings.Contains(item.Description, segmentInvoiceItemDesc) {
items[segment] = item
}
}
return items, existingItemsIter.Err()
}
// updateExistingInvoiceItems updates 3 existing pending invoice line items for stripe customer.
func (service *Service) updateExistingInvoiceItems(ctx context.Context, existingItems map[usage]*stripe.InvoiceItem, newItems []*stripe.InvoiceItemParams) (err error) {
for _, item := range newItems {
if strings.Contains(*item.Description, storageInvoiceItemDesc) {
existingItems[storage].Quantity += *item.Quantity
} else if strings.Contains(*item.Description, egressInvoiceItemDesc) {
existingItems[egress].Quantity += *item.Quantity
} else if strings.Contains(*item.Description, segmentInvoiceItemDesc) {
existingItems[segment].Quantity += *item.Quantity
}
}
for _, item := range existingItems {
_, err = service.stripeClient.InvoiceItems().Update(item.ID, &stripe.InvoiceItemParams{
Params: stripe.Params{Context: ctx},
Quantity: stripe.Int64(item.Quantity),
})
if err != nil {
return err
}
}
return nil
}
// InvoiceItemsFromProjectUsage calculates Stripe invoice item from project usage.
func (service *Service) InvoiceItemsFromProjectUsage(projName string, partnerUsages map[string]accounting.ProjectUsage) (result []*stripe.InvoiceItemParams) {
func (service *Service) InvoiceItemsFromProjectUsage(projName string, partnerUsages map[string]accounting.ProjectUsage, aggregated bool) (result []*stripe.InvoiceItemParams) {
var partners []string
if len(partnerUsages) == 0 {
partners = []string{""}
@ -535,22 +760,26 @@ func (service *Service) InvoiceItemsFromProjectUsage(projName string, partnerUsa
prefix += " (" + partner + ")"
}
if aggregated {
prefix = "All projects"
}
projectItem := &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(prefix + " - Segment Storage (MB-Month)")
projectItem.Description = stripe.String(prefix + storageInvoiceItemDesc)
projectItem.Quantity = stripe.Int64(storageMBMonthDecimal(usage.Storage).IntPart())
storagePrice, _ := priceModel.StorageMBMonthCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(storagePrice)
result = append(result, projectItem)
projectItem = &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(prefix + " - Egress Bandwidth (MB)")
projectItem.Description = stripe.String(prefix + egressInvoiceItemDesc)
projectItem.Quantity = stripe.Int64(egressMBDecimal(usage.Egress).IntPart())
egressPrice, _ := priceModel.EgressMBCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(egressPrice)
result = append(result, projectItem)
projectItem = &stripe.InvoiceItemParams{}
projectItem.Description = stripe.String(prefix + " - Segment Fee (Segment-Month)")
projectItem.Description = stripe.String(prefix + segmentInvoiceItemDesc)
projectItem.Quantity = stripe.Int64(segmentMonthDecimal(usage.SegmentCount).IntPart())
segmentPrice, _ := priceModel.SegmentMonthCents.Float64()
projectItem.UnitAmountDecimal = stripe.Float64(segmentPrice)
@ -1033,6 +1262,7 @@ func (service *Service) GenerateInvoices(ctx context.Context, period time.Time)
}{
{"Preparing invoice project records", service.PrepareInvoiceProjectRecords},
{"Applying invoice project records", service.InvoiceApplyProjectRecords},
{"Applying to be aggregated invoice project records", service.InvoiceApplyToBeAggregatedProjectRecords},
{"Creating invoices", service.CreateInvoices},
} {
service.log.Info(subFn.Description)

View File

@ -773,7 +773,7 @@ func TestService_InvoiceItemsFromProjectUsage(t *testing.T) {
},
}
items := planet.Satellites[0].API.Payments.StripeService.InvoiceItemsFromProjectUsage(projectName, usage)
items := planet.Satellites[0].API.Payments.StripeService.InvoiceItemsFromProjectUsage(projectName, usage, false)
require.Len(t, items, len(usage)*3)
for i, tt := range []struct {

View File

@ -28,6 +28,8 @@ const (
invoiceProjectRecordStateUnapplied invoiceProjectRecordState = 0
// invoice project record has been used during creating customer invoice.
invoiceProjectRecordStateConsumed invoiceProjectRecordState = 1
// invoice project record is not yet applied to customer invoice and has to be aggregated with other items.
invoiceProjectRecordStateToBeAggregated invoiceProjectRecordState = 2
)
// Int returns intent state as int.
@ -46,6 +48,17 @@ type invoiceProjectRecords struct {
func (db *invoiceProjectRecords) Create(ctx context.Context, records []stripe.CreateProjectRecord, start, end time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
return db.createWithState(ctx, records, invoiceProjectRecordStateUnapplied, start, end)
}
// CreateToBeAggregated creates new to be aggregated invoice project record in the DB.
func (db *invoiceProjectRecords) CreateToBeAggregated(ctx context.Context, records []stripe.CreateProjectRecord, start, end time.Time) (err error) {
defer mon.Task()(&ctx)(&err)
return db.createWithState(ctx, records, invoiceProjectRecordStateToBeAggregated, start, end)
}
func (db *invoiceProjectRecords) createWithState(ctx context.Context, records []stripe.CreateProjectRecord, state invoiceProjectRecordState, start, end time.Time) error {
return db.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
for _, record := range records {
id, err := uuid.New()
@ -60,7 +73,7 @@ func (db *invoiceProjectRecords) Create(ctx context.Context, records []stripe.Cr
dbx.StripecoinpaymentsInvoiceProjectRecord_Egress(record.Egress),
dbx.StripecoinpaymentsInvoiceProjectRecord_PeriodStart(start),
dbx.StripecoinpaymentsInvoiceProjectRecord_PeriodEnd(end),
dbx.StripecoinpaymentsInvoiceProjectRecord_State(invoiceProjectRecordStateUnapplied.Int()),
dbx.StripecoinpaymentsInvoiceProjectRecord_State(state.Int()),
dbx.StripecoinpaymentsInvoiceProjectRecord_Create_Fields{
Segments: dbx.StripecoinpaymentsInvoiceProjectRecord_Segments(int64(record.Segments)),
},
@ -129,11 +142,23 @@ func (db *invoiceProjectRecords) Consume(ctx context.Context, id uuid.UUID) (err
return err
}
// ListToBeAggregated returns to be aggregated project records page with unapplied project records.
// Cursor is not included into listing results.
func (db *invoiceProjectRecords) ListToBeAggregated(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (page stripe.ProjectRecordsPage, err error) {
defer mon.Task()(&ctx)(&err)
return db.list(ctx, cursor, limit, invoiceProjectRecordStateToBeAggregated.Int(), start, end)
}
// ListUnapplied returns project records page with unapplied project records.
// Cursor is not included into listing results.
func (db *invoiceProjectRecords) ListUnapplied(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (page stripe.ProjectRecordsPage, err error) {
defer mon.Task()(&ctx)(&err)
return db.list(ctx, cursor, limit, invoiceProjectRecordStateUnapplied.Int(), start, end)
}
func (db *invoiceProjectRecords) list(ctx context.Context, cursor uuid.UUID, limit, state int, start, end time.Time) (page stripe.ProjectRecordsPage, err error) {
err = withRows(db.db.QueryContext(ctx, db.db.Rebind(`
SELECT
id, project_id, storage, egress, segments, period_start, period_end, state
@ -143,7 +168,7 @@ func (db *invoiceProjectRecords) ListUnapplied(ctx context.Context, cursor uuid.
id > ? AND period_start = ? AND period_end = ? AND state = ?
ORDER BY id
LIMIT ?
`), cursor, start, end, invoiceProjectRecordStateUnapplied.Int(), limit+1))(func(rows tagsql.Rows) error {
`), cursor, start, end, state, limit+1))(func(rows tagsql.Rows) error {
for rows.Next() {
var record stripe.ProjectRecord
err := rows.Scan(&record.ID, &record.ProjectID, &record.Storage, &record.Egress, &record.Segments, &record.PeriodStart, &record.PeriodEnd, &record.State)