diff --git a/cmd/satellite/main.go b/cmd/satellite/main.go index 5e9ed6cab..ce6f5a90a 100644 --- a/cmd/satellite/main.go +++ b/cmd/satellite/main.go @@ -208,7 +208,14 @@ var ( Long: "Applies free tier coupon to Stripe customers without a coupon", RunE: cmdApplyFreeTierCoupons, } - createCustomerBalanceInvoiceItems = &cobra.Command{ + setInvoiceStatusCmd = &cobra.Command{ + Use: "set-invoice-status [start-period] [end-period] [status]", + Short: "set all open invoices status", + Long: "set all open invoices in the specified date ranges to the provided status. Period is a UTC date formatted like YYYY-MM.", + Args: cobra.ExactArgs(3), + RunE: cmdSetInvoiceStatus, + } + createCustomerBalanceInvoiceItemsCmd = &cobra.Command{ Use: "create-balance-invoice-items", Short: "Creates stripe invoice line items for stripe customer balance", Long: "Creates stripe invoice line items for stripe customer balances obtained from past invoices and other miscellaneous charges.", @@ -342,6 +349,9 @@ var ( Database string `help:"satellite database connection string" releaseDefault:"postgres://" devDefault:"postgres://"` Before string `help:"select only exited nodes before this UTC date formatted like YYYY-MM. Date cannot be newer than the current time (required)"` } + setInvoiceStatusCfg struct { + DryRun bool `help:"do not update stripe" default:"false"` + } confDir string identityDir string @@ -381,7 +391,8 @@ func init() { compensationCmd.AddCommand(recordPeriodCmd) compensationCmd.AddCommand(recordOneOffPaymentsCmd) billingCmd.AddCommand(applyFreeTierCouponsCmd) - billingCmd.AddCommand(createCustomerBalanceInvoiceItems) + billingCmd.AddCommand(setInvoiceStatusCmd) + billingCmd.AddCommand(createCustomerBalanceInvoiceItemsCmd) billingCmd.AddCommand(prepareCustomerInvoiceRecordsCmd) billingCmd.AddCommand(createCustomerProjectInvoiceItemsCmd) billingCmd.AddCommand(createCustomerInvoicesCmd) @@ -413,7 +424,9 @@ func init() { process.Bind(reportsVerifyGEReceiptCmd, &reportsVerifyGracefulExitReceiptCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir)) process.Bind(partnerAttributionCmd, &partnerAttribtionCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir)) process.Bind(applyFreeTierCouponsCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir)) - process.Bind(createCustomerBalanceInvoiceItems, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir)) + process.Bind(setInvoiceStatusCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir)) + process.Bind(setInvoiceStatusCmd, &setInvoiceStatusCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir)) + process.Bind(createCustomerBalanceInvoiceItemsCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir)) process.Bind(prepareCustomerInvoiceRecordsCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir)) process.Bind(createCustomerProjectInvoiceItemsCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir)) process.Bind(createCustomerInvoicesCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir)) @@ -754,6 +767,30 @@ func cmdValueAttribution(cmd *cobra.Command, args []string) (err error) { return reports.GenerateAttributionCSV(ctx, partnerAttribtionCfg.Database, start, end, userAgents, file) } +// cmdSetInvoiceStatus sets the status of all open invoices within the provided period to the provided status. +// args[0] is the start of the period in YYYY-MM format. +// args[1] is the end of the period in YYYY-MM format. +// args[2] is the status to set the invoices to. +func cmdSetInvoiceStatus(cmd *cobra.Command, args []string) (err error) { + ctx, _ := process.Ctx(cmd) + + periodStart, err := parseYearMonth(args[0]) + if err != nil { + return err + } + + periodEnd, err := parseYearMonth(args[1]) + if err != nil { + return err + } + // parseYearMonth returns the first day of the month, but we want the period end to be the last day of the month + periodEnd = periodEnd.AddDate(0, 1, -1) + + return runBillingCmd(ctx, func(ctx context.Context, payments *stripe.Service, _ satellite.DB) error { + return payments.SetInvoiceStatus(ctx, periodStart, periodEnd, args[2], setInvoiceStatusCfg.DryRun) + }) +} + func cmdCreateCustomerBalanceInvoiceItems(cmd *cobra.Command, _ []string) (err error) { ctx, _ := process.Ctx(cmd) diff --git a/satellite/payments/stripe/client.go b/satellite/payments/stripe/client.go index 2db72d3cc..3c461f930 100644 --- a/satellite/payments/stripe/client.go +++ b/satellite/payments/stripe/client.go @@ -64,6 +64,8 @@ type Invoices interface { Pay(id string, params *stripe.InvoicePayParams) (*stripe.Invoice, error) Del(id string, params *stripe.InvoiceParams) (*stripe.Invoice, error) Get(id string, params *stripe.InvoiceParams) (*stripe.Invoice, error) + MarkUncollectible(id string, params *stripe.InvoiceMarkUncollectibleParams) (*stripe.Invoice, error) + VoidInvoice(id string, params *stripe.InvoiceVoidParams) (*stripe.Invoice, error) } // InvoiceItems Stripe InvoiceItems interface. diff --git a/satellite/payments/stripe/service.go b/satellite/payments/stripe/service.go index 8fc9617f0..31742e33a 100644 --- a/satellite/payments/stripe/service.go +++ b/satellite/payments/stripe/service.go @@ -860,6 +860,86 @@ func (service *Service) createInvoices(ctx context.Context, customers []Customer return scheduled, draft, errGrp.Err() } +// SetInvoiceStatus will set all open invoices within the specified date range to the requested status. +func (service *Service) SetInvoiceStatus(ctx context.Context, startPeriod, endPeriod time.Time, status string, dryRun bool) (err error) { + defer mon.Task()(&ctx)(&err) + + switch stripe.InvoiceStatus(strings.ToLower(status)) { + case stripe.InvoiceStatusUncollectible: + err = service.iterateInvoicesInTimeRange(ctx, startPeriod, endPeriod, func(invoiceId string) error { + service.log.Info("updating invoice status to uncollectible", zap.String("invoiceId", invoiceId)) + if !dryRun { + _, err := service.stripeClient.Invoices().MarkUncollectible(invoiceId, &stripe.InvoiceMarkUncollectibleParams{}) + if err != nil { + return Error.Wrap(err) + } + } + return nil + }) + case stripe.InvoiceStatusVoid: + err = service.iterateInvoicesInTimeRange(ctx, startPeriod, endPeriod, func(invoiceId string) error { + service.log.Info("updating invoice status to void", zap.String("invoiceId", invoiceId)) + if !dryRun { + _, err = service.stripeClient.Invoices().VoidInvoice(invoiceId, &stripe.InvoiceVoidParams{}) + if err != nil { + return Error.Wrap(err) + } + } + return nil + }) + case stripe.InvoiceStatusPaid: + err = service.iterateInvoicesInTimeRange(ctx, startPeriod, endPeriod, func(invoiceId string) error { + service.log.Info("updating invoice status to paid", zap.String("invoiceId", invoiceId)) + if !dryRun { + payParams := &stripe.InvoicePayParams{ + Params: stripe.Params{Context: ctx}, + PaidOutOfBand: stripe.Bool(true), + } + _, err = service.stripeClient.Invoices().Pay(invoiceId, payParams) + if err != nil { + return Error.Wrap(err) + } + } + return nil + }) + default: + // unknown + service.log.Error("Unknown status provided. Valid options are uncollectible, void, or paid.", zap.String("status", status)) + return Error.New("unknown status provided") + } + return err +} + +func (service *Service) iterateInvoicesInTimeRange(ctx context.Context, startPeriod, endPeriod time.Time, updateStatus func(string) error) (err error) { + defer mon.Task()(&ctx)(&err) + + params := &stripe.InvoiceListParams{ + ListParams: stripe.ListParams{ + Context: ctx, + Limit: stripe.Int64(100), + }, + Status: stripe.String("open"), + CreatedRange: &stripe.RangeQueryParams{ + GreaterThanOrEqual: startPeriod.Unix(), + LesserThanOrEqual: endPeriod.Unix(), + }, + } + + numInvoices := 0 + invoicesIterator := service.stripeClient.Invoices().List(params) + for invoicesIterator.Next() { + numInvoices++ + stripeInvoice := invoicesIterator.Invoice() + + err := updateStatus(stripeInvoice.ID) + if err != nil { + return Error.Wrap(err) + } + } + service.log.Info("found " + strconv.Itoa(numInvoices) + " total invoices") + return Error.Wrap(invoicesIterator.Err()) +} + // CreateBalanceInvoiceItems will find users with a stripe balance, create an invoice // item with the charges due, and zero out the stripe balance. func (service *Service) CreateBalanceInvoiceItems(ctx context.Context) (err error) { diff --git a/satellite/payments/stripe/service_test.go b/satellite/payments/stripe/service_test.go index f22a325d4..96f53edbb 100644 --- a/satellite/payments/stripe/service_test.go +++ b/satellite/payments/stripe/service_test.go @@ -36,6 +36,292 @@ import ( stripe1 "storj.io/storj/satellite/payments/stripe" ) +func TestService_SetInvoiceStatusUncollectible(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, + Reconfigure: testplanet.Reconfigure{ + Satellite: func(log *zap.Logger, index int, config *satellite.Config) { + config.Payments.StripeCoinPayments.ListingLimit = 4 + }, + }, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + satellite := planet.Satellites[0] + payments := satellite.API.Payments + + invoiceBalance := currency.AmountFromBaseUnits(800, currency.USDollars) + usdCurrency := string(stripe.CurrencyUSD) + + user, err := satellite.AddUser(ctx, console.CreateUser{ + FullName: "testuser", + Email: "user@test", + }, 1) + require.NoError(t, err) + customer, err := satellite.DB.StripeCoinPayments().Customers().GetCustomerID(ctx, user.ID) + require.NoError(t, err) + + // create invoice item + invItem, err := satellite.API.Payments.StripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{ + Params: stripe.Params{Context: ctx}, + Amount: stripe.Int64(invoiceBalance.BaseUnits()), + Currency: stripe.String(usdCurrency), + Customer: &customer, + }) + require.NoError(t, err) + + InvItems := make([]*stripe.InvoiceUpcomingInvoiceItemParams, 0, 1) + InvItems = append(InvItems, &stripe.InvoiceUpcomingInvoiceItemParams{ + InvoiceItem: &invItem.ID, + Amount: &invItem.Amount, + Currency: stripe.String(usdCurrency), + }) + + // create invoice + inv, err := satellite.API.Payments.StripeClient.Invoices().New(&stripe.InvoiceParams{ + Params: stripe.Params{Context: ctx}, + Customer: &customer, + InvoiceItems: InvItems, + }) + require.NoError(t, err) + + finalizeParams := &stripe.InvoiceFinalizeParams{Params: stripe.Params{Context: ctx}} + + // finalize invoice + inv, err = satellite.API.Payments.StripeClient.Invoices().FinalizeInvoice(inv.ID, finalizeParams) + require.NoError(t, err) + require.Equal(t, stripe.InvoiceStatusOpen, inv.Status) + + // run update invoice status to uncollectible + // beginning of last month + startPeriod := time.Date(time.Now().Year(), time.Now().Month(), 1, 0, 0, 0, 0, time.UTC).AddDate(0, -1, 0) + // end of current month + endPeriod := time.Date(time.Now().Year(), time.Now().Month(), 1, 0, 0, 0, 0, time.UTC).AddDate(0, 1, -1) + + t.Run("update invoice status to uncollectible", func(t *testing.T) { + err = payments.StripeService.SetInvoiceStatus(ctx, startPeriod, endPeriod, "uncollectible", false) + require.NoError(t, err) + + iter := satellite.API.Payments.StripeClient.Invoices().List(&stripe.InvoiceListParams{ + ListParams: stripe.ListParams{Context: ctx}, + }) + iter.Next() + require.Equal(t, stripe.InvoiceStatusUncollectible, iter.Invoice().Status) + }) + }) +} + +func TestService_SetInvoiceStatusVoid(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, + Reconfigure: testplanet.Reconfigure{ + Satellite: func(log *zap.Logger, index int, config *satellite.Config) { + config.Payments.StripeCoinPayments.ListingLimit = 4 + }, + }, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + satellite := planet.Satellites[0] + payments := satellite.API.Payments + + invoiceBalance := currency.AmountFromBaseUnits(800, currency.USDollars) + usdCurrency := string(stripe.CurrencyUSD) + + user, err := satellite.AddUser(ctx, console.CreateUser{ + FullName: "testuser", + Email: "user@test", + }, 1) + require.NoError(t, err) + customer, err := satellite.DB.StripeCoinPayments().Customers().GetCustomerID(ctx, user.ID) + require.NoError(t, err) + + // create invoice item + invItem, err := satellite.API.Payments.StripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{ + Params: stripe.Params{Context: ctx}, + Amount: stripe.Int64(invoiceBalance.BaseUnits()), + Currency: stripe.String(usdCurrency), + Customer: &customer, + }) + require.NoError(t, err) + + InvItems := make([]*stripe.InvoiceUpcomingInvoiceItemParams, 0, 1) + InvItems = append(InvItems, &stripe.InvoiceUpcomingInvoiceItemParams{ + InvoiceItem: &invItem.ID, + Amount: &invItem.Amount, + Currency: stripe.String(usdCurrency), + }) + + // create invoice + inv, err := satellite.API.Payments.StripeClient.Invoices().New(&stripe.InvoiceParams{ + Params: stripe.Params{Context: ctx}, + Customer: &customer, + InvoiceItems: InvItems, + }) + require.NoError(t, err) + + finalizeParams := &stripe.InvoiceFinalizeParams{Params: stripe.Params{Context: ctx}} + + // finalize invoice + inv, err = satellite.API.Payments.StripeClient.Invoices().FinalizeInvoice(inv.ID, finalizeParams) + require.NoError(t, err) + require.Equal(t, stripe.InvoiceStatusOpen, inv.Status) + + // run update invoice status to uncollectible + // beginning of last month + startPeriod := time.Date(time.Now().Year(), time.Now().Month(), 1, 0, 0, 0, 0, time.UTC).AddDate(0, -1, 0) + // end of current month + endPeriod := time.Date(time.Now().Year(), time.Now().Month(), 1, 0, 0, 0, 0, time.UTC).AddDate(0, 1, -1) + + t.Run("update invoice status to void", func(t *testing.T) { + err = payments.StripeService.SetInvoiceStatus(ctx, startPeriod, endPeriod, "void", false) + require.NoError(t, err) + + iter := satellite.API.Payments.StripeClient.Invoices().List(&stripe.InvoiceListParams{ + ListParams: stripe.ListParams{Context: ctx}, + }) + iter.Next() + require.Equal(t, stripe.InvoiceStatusVoid, iter.Invoice().Status) + }) + }) +} + +func TestService_SetInvoiceStatusPaid(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, + Reconfigure: testplanet.Reconfigure{ + Satellite: func(log *zap.Logger, index int, config *satellite.Config) { + config.Payments.StripeCoinPayments.ListingLimit = 4 + }, + }, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + satellite := planet.Satellites[0] + payments := satellite.API.Payments + + invoiceBalance := currency.AmountFromBaseUnits(800, currency.USDollars) + usdCurrency := string(stripe.CurrencyUSD) + + user, err := satellite.AddUser(ctx, console.CreateUser{ + FullName: "testuser", + Email: "user@test", + }, 1) + require.NoError(t, err) + customer, err := satellite.DB.StripeCoinPayments().Customers().GetCustomerID(ctx, user.ID) + require.NoError(t, err) + + // create invoice item + invItem, err := satellite.API.Payments.StripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{ + Params: stripe.Params{Context: ctx}, + Amount: stripe.Int64(invoiceBalance.BaseUnits()), + Currency: stripe.String(usdCurrency), + Customer: &customer, + }) + require.NoError(t, err) + + InvItems := make([]*stripe.InvoiceUpcomingInvoiceItemParams, 0, 1) + InvItems = append(InvItems, &stripe.InvoiceUpcomingInvoiceItemParams{ + InvoiceItem: &invItem.ID, + Amount: &invItem.Amount, + Currency: stripe.String(usdCurrency), + }) + + // create invoice + inv, err := satellite.API.Payments.StripeClient.Invoices().New(&stripe.InvoiceParams{ + Params: stripe.Params{Context: ctx}, + Customer: &customer, + InvoiceItems: InvItems, + }) + require.NoError(t, err) + + finalizeParams := &stripe.InvoiceFinalizeParams{Params: stripe.Params{Context: ctx}} + + // finalize invoice + inv, err = satellite.API.Payments.StripeClient.Invoices().FinalizeInvoice(inv.ID, finalizeParams) + require.NoError(t, err) + require.Equal(t, stripe.InvoiceStatusOpen, inv.Status) + + // run update invoice status to uncollectible + // beginning of last month + startPeriod := time.Date(time.Now().Year(), time.Now().Month(), 1, 0, 0, 0, 0, time.UTC).AddDate(0, -1, 0) + // end of current month + endPeriod := time.Date(time.Now().Year(), time.Now().Month(), 1, 0, 0, 0, 0, time.UTC).AddDate(0, 1, -1) + + t.Run("update invoice status to paid", func(t *testing.T) { + err = payments.StripeService.SetInvoiceStatus(ctx, startPeriod, endPeriod, "paid", false) + require.NoError(t, err) + + iter := satellite.API.Payments.StripeClient.Invoices().List(&stripe.InvoiceListParams{ + ListParams: stripe.ListParams{Context: ctx}, + }) + iter.Next() + require.Equal(t, stripe.InvoiceStatusPaid, iter.Invoice().Status) + }) + }) +} + +func TestService_SetInvoiceStatusInvalid(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, + Reconfigure: testplanet.Reconfigure{ + Satellite: func(log *zap.Logger, index int, config *satellite.Config) { + config.Payments.StripeCoinPayments.ListingLimit = 4 + }, + }, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + satellite := planet.Satellites[0] + payments := satellite.API.Payments + + invoiceBalance := currency.AmountFromBaseUnits(800, currency.USDollars) + usdCurrency := string(stripe.CurrencyUSD) + + user, err := satellite.AddUser(ctx, console.CreateUser{ + FullName: "testuser", + Email: "user@test", + }, 1) + require.NoError(t, err) + customer, err := satellite.DB.StripeCoinPayments().Customers().GetCustomerID(ctx, user.ID) + require.NoError(t, err) + + // create invoice item + invItem, err := satellite.API.Payments.StripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{ + Params: stripe.Params{Context: ctx}, + Amount: stripe.Int64(invoiceBalance.BaseUnits()), + Currency: stripe.String(usdCurrency), + Customer: &customer, + }) + require.NoError(t, err) + + InvItems := make([]*stripe.InvoiceUpcomingInvoiceItemParams, 0, 1) + InvItems = append(InvItems, &stripe.InvoiceUpcomingInvoiceItemParams{ + InvoiceItem: &invItem.ID, + Amount: &invItem.Amount, + Currency: stripe.String(usdCurrency), + }) + + // create invoice + inv, err := satellite.API.Payments.StripeClient.Invoices().New(&stripe.InvoiceParams{ + Params: stripe.Params{Context: ctx}, + Customer: &customer, + InvoiceItems: InvItems, + }) + require.NoError(t, err) + + finalizeParams := &stripe.InvoiceFinalizeParams{Params: stripe.Params{Context: ctx}} + + // finalize invoice + inv, err = satellite.API.Payments.StripeClient.Invoices().FinalizeInvoice(inv.ID, finalizeParams) + require.NoError(t, err) + require.Equal(t, stripe.InvoiceStatusOpen, inv.Status) + + // run update invoice status to uncollectible + // beginning of last month + startPeriod := time.Date(time.Now().Year(), time.Now().Month(), 1, 0, 0, 0, 0, time.UTC).AddDate(0, -1, 0) + // end of current month + endPeriod := time.Date(time.Now().Year(), time.Now().Month(), 1, 0, 0, 0, 0, time.UTC).AddDate(0, 1, -1) + + t.Run("update invoice status to invalid", func(t *testing.T) { + err = payments.StripeService.SetInvoiceStatus(ctx, startPeriod, endPeriod, "not a real status", false) + require.Error(t, err) + }) + }) +} + func TestService_BalanceInvoiceItems(t *testing.T) { testplanet.Run(t, testplanet.Config{ SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, diff --git a/satellite/payments/stripe/stripemock.go b/satellite/payments/stripe/stripemock.go index d0a961495..21b43e62c 100644 --- a/satellite/payments/stripe/stripemock.go +++ b/satellite/payments/stripe/stripemock.go @@ -497,6 +497,32 @@ type mockInvoices struct { invoiceItems *mockInvoiceItems } +func (m *mockInvoices) MarkUncollectible(id string, params *stripe.InvoiceMarkUncollectibleParams) (*stripe.Invoice, error) { + for _, invoices := range m.invoices { + for _, invoice := range invoices { + if invoice.ID == id { + invoice.Status = stripe.InvoiceStatusUncollectible + return invoice, nil + } + } + } + + return nil, errors.New("invoice not found") +} + +func (m *mockInvoices) VoidInvoice(id string, params *stripe.InvoiceVoidParams) (*stripe.Invoice, error) { + for _, invoices := range m.invoices { + for _, invoice := range invoices { + if invoice.ID == id { + invoice.Status = stripe.InvoiceStatusVoid + return invoice, nil + } + } + } + + return nil, errors.New("invoice not found") +} + func newMockInvoices(root *mockStripeState, invoiceItems *mockInvoiceItems) *mockInvoices { return &mockInvoices{ root: root, @@ -639,8 +665,9 @@ func (m *mockInvoices) Pay(id string, params *stripe.InvoicePayParams) (*stripe. invoice.AmountRemaining = 0 return invoice, nil } - } else if invoice.AmountRemaining == 0 { + } else if invoice.AmountRemaining == 0 || (params.PaidOutOfBand != nil && *params.PaidOutOfBand) { invoice.Status = stripe.InvoiceStatusPaid + invoice.AmountRemaining = 0 } return invoice, nil }