diff --git a/satellite/analytics/service.go b/satellite/analytics/service.go index 96d17cb91..58379c725 100644 --- a/satellite/analytics/service.go +++ b/satellite/analytics/service.go @@ -79,6 +79,9 @@ const ( eventProjectDescriptionUpdated = "Project Description Updated" eventProjectStorageLimitUpdated = "Project Storage Limit Updated" eventProjectBandwidthLimitUpdated = "Project Bandwidth Limit Updated" + eventAccountFrozen = "Account Frozen" + eventAccountFreezeWarning = "Account Freeze Warning" + eventUnpaidLargeInvoice = "Large Invoice Unpaid" ) var ( @@ -303,6 +306,55 @@ func (service *Service) TrackProjectCreated(userID uuid.UUID, email string, proj }) } +// TrackAccountFrozen sends an account frozen event to Segment. +func (service *Service) TrackAccountFrozen(userID uuid.UUID, email string) { + if !service.config.Enabled { + return + } + + props := segment.NewProperties() + props.Set("email", email) + + service.enqueueMessage(segment.Track{ + UserId: userID.String(), + Event: service.satelliteName + " " + eventAccountFrozen, + Properties: props, + }) +} + +// TrackAccountFreezeWarning sends an account freeze warning event to Segment. +func (service *Service) TrackAccountFreezeWarning(userID uuid.UUID, email string) { + if !service.config.Enabled { + return + } + + props := segment.NewProperties() + props.Set("email", email) + + service.enqueueMessage(segment.Track{ + UserId: userID.String(), + Event: service.satelliteName + " " + eventAccountFreezeWarning, + Properties: props, + }) +} + +// TrackLargeUnpaidInvoice sends an event to Segment indicating that a user has not paid a large invoice. +func (service *Service) TrackLargeUnpaidInvoice(invID string, userID uuid.UUID, email string) { + if !service.config.Enabled { + return + } + + props := segment.NewProperties() + props.Set("email", email) + props.Set("invoice", invID) + + service.enqueueMessage(segment.Track{ + UserId: userID.String(), + Event: service.satelliteName + " " + eventUnpaidLargeInvoice, + Properties: props, + }) +} + // TrackAccessGrantCreated sends an "Access Grant Created" event to Segment. func (service *Service) TrackAccessGrantCreated(userID uuid.UUID, email string) { if !service.config.Enabled { diff --git a/satellite/console/accountfreezes.go b/satellite/console/accountfreezes.go index 191a22ecd..7ce31fa8e 100644 --- a/satellite/console/accountfreezes.go +++ b/satellite/console/accountfreezes.go @@ -25,6 +25,8 @@ type AccountFreezeEvents interface { Upsert(ctx context.Context, event *AccountFreezeEvent) (*AccountFreezeEvent, error) // Get is a method for querying account freeze event from the database by user ID and event type. Get(ctx context.Context, userID uuid.UUID, eventType AccountFreezeEventType) (*AccountFreezeEvent, error) + // GetAll is a method for querying all account freeze events from the database by user ID. + GetAll(ctx context.Context, userID uuid.UUID) (*AccountFreezeEvent, *AccountFreezeEvent, error) // DeleteAllByUserID is a method for deleting all account freeze events from the database by user ID. DeleteAllByUserID(ctx context.Context, userID uuid.UUID) error } @@ -189,3 +191,27 @@ func (s *AccountFreezeService) UnfreezeUser(ctx context.Context, userID uuid.UUI return ErrAccountFreeze.Wrap(s.freezeEventsDB.DeleteAllByUserID(ctx, userID)) } + +// WarnUser adds a warning event to the freeze events table. +func (s *AccountFreezeService) WarnUser(ctx context.Context, userID uuid.UUID) (err error) { + defer mon.Task()(&ctx)(&err) + + _, err = s.freezeEventsDB.Upsert(ctx, &AccountFreezeEvent{ + UserID: userID, + Type: Warning, + }) + + return ErrAccountFreeze.Wrap(err) +} + +// GetAll returns all events for a user. +func (s *AccountFreezeService) GetAll(ctx context.Context, userID uuid.UUID) (freeze *AccountFreezeEvent, warning *AccountFreezeEvent, err error) { + defer mon.Task()(&ctx)(&err) + + freeze, warning, err = s.freezeEventsDB.GetAll(ctx, userID) + if err != nil { + return nil, nil, ErrAccountFreeze.Wrap(err) + } + + return freeze, warning, nil +} diff --git a/satellite/core.go b/satellite/core.go index ee1f56431..851b865db 100644 --- a/satellite/core.go +++ b/satellite/core.go @@ -30,7 +30,9 @@ import ( "storj.io/storj/satellite/accounting/rollup" "storj.io/storj/satellite/accounting/rolluparchive" "storj.io/storj/satellite/accounting/tally" + "storj.io/storj/satellite/analytics" "storj.io/storj/satellite/audit" + "storj.io/storj/satellite/console" "storj.io/storj/satellite/console/consoleauth" "storj.io/storj/satellite/console/emailreminders" "storj.io/storj/satellite/gracefulexit" @@ -46,6 +48,7 @@ import ( "storj.io/storj/satellite/overlay/offlinenodes" "storj.io/storj/satellite/overlay/straynodes" "storj.io/storj/satellite/payments" + "storj.io/storj/satellite/payments/accountfreeze" "storj.io/storj/satellite/payments/billing" "storj.io/storj/satellite/payments/storjscan" "storj.io/storj/satellite/payments/stripecoinpayments" @@ -143,6 +146,7 @@ type Core struct { } Payments struct { + AccountFreeze *accountfreeze.Chore Accounts payments.Accounts BillingChore *billing.Chore StorjscanClient *storjscan.Client @@ -612,6 +616,26 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, }) } + { // setup account freeze + if config.AccountFreeze.Enabled { + peer.Payments.AccountFreeze = accountfreeze.NewChore( + peer.Log.Named("payments.accountfreeze:chore"), + peer.DB.StripeCoinPayments(), + peer.Payments.Accounts, + peer.DB.Console().Users(), + console.NewAccountFreezeService(db.Console().AccountFreezeEvents(), db.Console().Users(), db.Console().Projects()), + analytics.NewService(peer.Log.Named("analytics:service"), config.Analytics, config.Console.SatelliteName), + config.AccountFreeze, + ) + + peer.Services.Add(lifecycle.Item{ + Name: "accountfreeze:chore", + Run: peer.Payments.AccountFreeze.Run, + Close: peer.Payments.AccountFreeze.Close, + }) + } + } + { // setup graceful exit log := peer.Log.Named("gracefulexit") switch { diff --git a/satellite/payments/accountfreeze/chore.go b/satellite/payments/accountfreeze/chore.go new file mode 100644 index 000000000..c1f301d10 --- /dev/null +++ b/satellite/payments/accountfreeze/chore.go @@ -0,0 +1,135 @@ +// Copyright (C) 2023 Storj Labs, Inc. +// See LICENSE for copying information. + +package accountfreeze + +import ( + "context" + "time" + + "github.com/spacemonkeygo/monkit/v3" + "github.com/zeebo/errs" + "go.uber.org/zap" + + "storj.io/common/sync2" + "storj.io/storj/satellite/analytics" + "storj.io/storj/satellite/console" + "storj.io/storj/satellite/payments" + "storj.io/storj/satellite/payments/stripecoinpayments" +) + +var ( + // Error is the standard error class for automatic freeze errors. + Error = errs.Class("account-freeze-chore") + mon = monkit.Package() +) + +// Config contains configurable values for account freeze chore. +type Config struct { + Enabled bool `help:"whether to run this chore." default:"false"` + Interval time.Duration `help:"How often to run this chore, which is how often unpaid invoices are checked." default:"24h"` + GracePeriod time.Duration `help:"How long to wait between a warning event and freezing an account." default:"720h"` + PriceThreshold int64 `help:"The failed invoice amount beyond which an account will not be frozen" default:"2000"` +} + +// Chore is a chore that checks for unpaid invoices and potentially freezes corresponding accounts. +type Chore struct { + log *zap.Logger + freezeService *console.AccountFreezeService + analytics *analytics.Service + usersDB console.Users + payments payments.Accounts + accounts stripecoinpayments.DB + config Config + nowFn func() time.Time + Loop *sync2.Cycle +} + +// NewChore is a constructor for Chore. +func NewChore(log *zap.Logger, accounts stripecoinpayments.DB, payments payments.Accounts, usersDB console.Users, freezeService *console.AccountFreezeService, analytics *analytics.Service, config Config) *Chore { + return &Chore{ + log: log, + freezeService: freezeService, + analytics: analytics, + usersDB: usersDB, + accounts: accounts, + config: config, + payments: payments, + nowFn: time.Now, + Loop: sync2.NewCycle(config.Interval), + } +} + +// Run runs the chore. +func (chore *Chore) Run(ctx context.Context) (err error) { + defer mon.Task()(&ctx)(&err) + return chore.Loop.Run(ctx, func(ctx context.Context) (err error) { + + invoices, err := chore.payments.Invoices().ListFailed(ctx) + if err != nil { + chore.log.Error("Could not list invoices", zap.Error(Error.Wrap(err))) + return nil + } + + for _, invoice := range invoices { + userID, err := chore.accounts.Customers().GetUserID(ctx, invoice.CustomerID) + if err != nil { + chore.log.Error("Could not get userID", zap.String("invoice", invoice.ID), zap.Error(Error.Wrap(err))) + continue + } + + user, err := chore.usersDB.Get(ctx, userID) + if err != nil { + chore.log.Error("Could not get user", zap.String("invoice", invoice.ID), zap.Error(Error.Wrap(err))) + continue + } + + if invoice.Amount > chore.config.PriceThreshold { + chore.analytics.TrackLargeUnpaidInvoice(invoice.ID, userID, user.Email) + continue + } + + freeze, warning, err := chore.freezeService.GetAll(ctx, userID) + if err != nil { + chore.log.Error("Could not check freeze status", zap.String("invoice", invoice.ID), zap.Error(Error.Wrap(err))) + continue + } + if freeze != nil { + // account already frozen + continue + } + + if warning == nil { + err = chore.freezeService.WarnUser(ctx, userID) + if err != nil { + chore.log.Error("Could not add warning event", zap.String("invoice", invoice.ID), zap.Error(Error.Wrap(err))) + continue + } + chore.analytics.TrackAccountFreezeWarning(userID, user.Email) + continue + } + + if chore.nowFn().Sub(warning.CreatedAt) > chore.config.GracePeriod { + err = chore.freezeService.FreezeUser(ctx, userID) + if err != nil { + chore.log.Error("Could not freeze account", zap.String("invoice", invoice.ID), zap.Error(Error.Wrap(err))) + continue + } + chore.analytics.TrackAccountFrozen(userID, user.Email) + } + } + + return nil + }) +} + +// TestSetNow sets nowFn on chore for testing. +func (chore *Chore) TestSetNow(f func() time.Time) { + chore.nowFn = f +} + +// Close closes the chore. +func (chore *Chore) Close() error { + chore.Loop.Close() + return nil +} diff --git a/satellite/payments/accountfreeze/chore_test.go b/satellite/payments/accountfreeze/chore_test.go new file mode 100644 index 000000000..0c75b8485 --- /dev/null +++ b/satellite/payments/accountfreeze/chore_test.go @@ -0,0 +1,155 @@ +// Copyright (C) 2023 Storj Labs, Inc. +// See LICENSE for copying information. + +package accountfreeze_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stripe/stripe-go/v72" + "go.uber.org/zap" + + "storj.io/common/testcontext" + "storj.io/storj/private/testplanet" + "storj.io/storj/satellite" + "storj.io/storj/satellite/console" + "storj.io/storj/satellite/payments/stripecoinpayments" +) + +func TestAutoFreezeChore(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, + Reconfigure: testplanet.Reconfigure{ + Satellite: func(log *zap.Logger, index int, config *satellite.Config) { + config.AccountFreeze.Enabled = true + }, + }, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + sat := planet.Satellites[0] + stripeClient := sat.API.Payments.StripeClient + invoicesDB := sat.Core.Payments.Accounts.Invoices() + customerDB := sat.Core.DB.StripeCoinPayments().Customers() + usersDB := sat.DB.Console().Users() + projectsDB := sat.DB.Console().Projects() + service := console.NewAccountFreezeService(sat.DB.Console().AccountFreezeEvents(), usersDB, projectsDB) + chore := sat.Core.Payments.AccountFreeze + + user, err := sat.AddUser(ctx, console.CreateUser{ + FullName: "Test User", + Email: "user@mail.test", + }, 1) + require.NoError(t, err) + + cus1, err := customerDB.GetCustomerID(ctx, user.ID) + require.NoError(t, err) + + amount := int64(100) + curr := string(stripe.CurrencyUSD) + + t.Run("No freeze event for paid invoice", func(t *testing.T) { + item, err := stripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{ + Amount: &amount, + Currency: &curr, + Customer: &cus1, + }) + require.NoError(t, err) + + items := make([]*stripe.InvoiceUpcomingInvoiceItemParams, 0, 1) + items = append(items, &stripe.InvoiceUpcomingInvoiceItemParams{ + InvoiceItem: &item.ID, + Amount: &amount, + Currency: &curr, + }) + inv, err := stripeClient.Invoices().New(&stripe.InvoiceParams{ + Customer: &cus1, + InvoiceItems: items, + }) + require.NoError(t, err) + + inv, err = stripeClient.Invoices().Pay(inv.ID, &stripe.InvoicePayParams{}) + require.NoError(t, err) + require.Equal(t, stripe.InvoiceStatusPaid, inv.Status) + + failed, err := invoicesDB.ListFailed(ctx) + require.NoError(t, err) + require.Equal(t, 0, len(failed)) + + chore.Loop.TriggerWait() + + // user should not be warned or frozen. + freeze, warning, err := service.GetAll(ctx, user.ID) + require.NoError(t, err) + require.Nil(t, warning) + require.Nil(t, freeze) + + // forward date to after the grace period + chore.TestSetNow(func() time.Time { + return time.Now().AddDate(0, 0, 50) + }) + chore.Loop.TriggerWait() + + // user should still not be warned or frozen. + freeze, warning, err = service.GetAll(ctx, user.ID) + require.NoError(t, err) + require.Nil(t, freeze) + require.Nil(t, warning) + }) + + t.Run("Freeze event for failed invoice", func(t *testing.T) { + // reset chore clock + chore.TestSetNow(time.Now) + + item, err := stripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{ + Amount: &amount, + Currency: &curr, + Customer: &cus1, + }) + require.NoError(t, err) + + items := make([]*stripe.InvoiceUpcomingInvoiceItemParams, 0, 1) + items = append(items, &stripe.InvoiceUpcomingInvoiceItemParams{ + InvoiceItem: &item.ID, + Amount: &amount, + Currency: &curr, + }) + inv, err := stripeClient.Invoices().New(&stripe.InvoiceParams{ + Customer: &cus1, + InvoiceItems: items, + }) + require.NoError(t, err) + + paymentMethod := stripecoinpayments.MockInvoicesPayFailure + inv, err = stripeClient.Invoices().Pay(inv.ID, &stripe.InvoicePayParams{ + PaymentMethod: &paymentMethod, + }) + require.Error(t, err) + require.Equal(t, stripe.InvoiceStatusOpen, inv.Status) + + failed, err := invoicesDB.ListFailed(ctx) + require.NoError(t, err) + require.Equal(t, 1, len(failed)) + require.Equal(t, inv.ID, failed[0].ID) + + chore.Loop.TriggerWait() + + // user should be warned the first time + freeze, warning, err := service.GetAll(ctx, user.ID) + require.NoError(t, err) + require.NotNil(t, warning) + require.Nil(t, freeze) + + chore.TestSetNow(func() time.Time { + // current date is now after grace period + return time.Now().AddDate(0, 0, 50) + }) + chore.Loop.TriggerWait() + + // user should be frozen this time around + freeze, _, err = service.GetAll(ctx, user.ID) + require.NoError(t, err) + require.NotNil(t, freeze) + }) + }) +} diff --git a/satellite/payments/invoices.go b/satellite/payments/invoices.go index 2e1b26772..709a95300 100644 --- a/satellite/payments/invoices.go +++ b/satellite/payments/invoices.go @@ -33,6 +33,8 @@ type Invoices interface { Pay(ctx context.Context, invoiceID, paymentMethodID string) (*Invoice, error) // List returns a list of invoices for a given payment account. List(ctx context.Context, userID uuid.UUID) ([]Invoice, error) + // ListFailed returns a list of failed invoices. + ListFailed(ctx context.Context) ([]Invoice, error) // ListWithDiscounts returns a list of invoices and coupon usages for a given payment account. ListWithDiscounts(ctx context.Context, userID uuid.UUID) ([]Invoice, []CouponUsage, error) // CheckPendingItems returns if pending invoice items for a given payment account exist. @@ -46,6 +48,7 @@ type Invoices interface { // Invoice holds all public information about invoice. type Invoice struct { ID string `json:"id"` + CustomerID string `json:"-"` Description string `json:"description"` Amount int64 `json:"amount"` Status string `json:"status"` diff --git a/satellite/payments/stripecoinpayments/customers.go b/satellite/payments/stripecoinpayments/customers.go index 9362f85f2..4193bba9e 100644 --- a/satellite/payments/stripecoinpayments/customers.go +++ b/satellite/payments/stripecoinpayments/customers.go @@ -22,6 +22,8 @@ type CustomersDB interface { Insert(ctx context.Context, userID uuid.UUID, customerID string) error // GetCustomerID return stripe customers id. GetCustomerID(ctx context.Context, userID uuid.UUID) (string, error) + // GetUserID return userID given stripe customer id. + GetUserID(ctx context.Context, customerID string) (uuid.UUID, error) // List returns page with customers ids created before specified date. List(ctx context.Context, offset int64, limit int, before time.Time) (CustomersPage, error) diff --git a/satellite/payments/stripecoinpayments/invoices.go b/satellite/payments/stripecoinpayments/invoices.go index ccd14aab6..5f82e1d82 100644 --- a/satellite/payments/stripecoinpayments/invoices.go +++ b/satellite/payments/stripecoinpayments/invoices.go @@ -9,6 +9,7 @@ import ( "github.com/stripe/stripe-go/v72" "github.com/zeebo/errs" + "go.uber.org/zap" "storj.io/common/uuid" "storj.io/storj/satellite/payments" @@ -140,6 +141,7 @@ func (invoices *invoices) List(ctx context.Context, userID uuid.UUID) (invoicesL invoicesList = append(invoicesList, payments.Invoice{ ID: stripeInvoice.ID, + CustomerID: customerID, Description: stripeInvoice.Description, Amount: total, Status: convertStatus(stripeInvoice.Status), @@ -155,6 +157,47 @@ func (invoices *invoices) List(ctx context.Context, userID uuid.UUID) (invoicesL return invoicesList, nil } +func (invoices *invoices) ListFailed(ctx context.Context) (invoicesList []payments.Invoice, err error) { + defer mon.Task()(&ctx)(&err) + + status := string(stripe.InvoiceStatusOpen) + params := &stripe.InvoiceListParams{ + Status: &status, + } + + invoicesIterator := invoices.service.stripeClient.Invoices().List(params) + for invoicesIterator.Next() { + stripeInvoice := invoicesIterator.Invoice() + + total := stripeInvoice.Total + for _, line := range stripeInvoice.Lines.Data { + // If amount is negative, this is a coupon or a credit line item. + // Add them to the total. + if line.Amount < 0 { + total -= line.Amount + } + } + + if invoices.isInvoiceFailed(stripeInvoice) { + invoicesList = append(invoicesList, payments.Invoice{ + ID: stripeInvoice.ID, + CustomerID: stripeInvoice.Customer.ID, + Description: stripeInvoice.Description, + Amount: total, + Status: string(stripeInvoice.Status), + Link: stripeInvoice.InvoicePDF, + Start: time.Unix(stripeInvoice.PeriodStart, 0), + }) + } + } + + if err = invoicesIterator.Err(); err != nil { + return nil, Error.Wrap(err) + } + + return invoicesList, nil +} + // ListWithDiscounts returns a list of invoices and coupon usages for a given payment account. func (invoices *invoices) ListWithDiscounts(ctx context.Context, userID uuid.UUID) (invoicesList []payments.Invoice, couponUsages []payments.CouponUsage, err error) { defer mon.Task()(&ctx, userID)(&err) @@ -184,6 +227,7 @@ func (invoices *invoices) ListWithDiscounts(ctx context.Context, userID uuid.UUI invoicesList = append(invoicesList, payments.Invoice{ ID: stripeInvoice.ID, + CustomerID: customerID, Description: stripeInvoice.Description, Amount: total, Status: convertStatus(stripeInvoice.Status), @@ -290,3 +334,22 @@ func convertStatus(stripestatus stripe.InvoiceStatus) string { } return status } + +// isInvoiceFailed returns whether an invoice has failed. +func (invoices *invoices) isInvoiceFailed(invoice *stripe.Invoice) bool { + if invoice.DueDate > 0 { + // https://github.com/storj/storj/blob/77bf88e916a10dc898ebb594eafac667ed4426cd/satellite/payments/stripecoinpayments/service.go#L781-L787 + invoices.service.log.Info("Skipping invoice marked for manual payment", + zap.String("id", invoice.ID), + zap.String("number", invoice.Number), + zap.String("customer", invoice.Customer.ID)) + return false + } + // https://stripe.com/docs/api/invoices/retrieve + if invoice.NextPaymentAttempt > 0 { + // stripe will automatically retry collecting payment. + return false + } + + return true +} diff --git a/satellite/payments/stripecoinpayments/service_test.go b/satellite/payments/stripecoinpayments/service_test.go index 2665e091f..c54aaf985 100644 --- a/satellite/payments/stripecoinpayments/service_test.go +++ b/satellite/payments/stripecoinpayments/service_test.go @@ -658,12 +658,21 @@ func TestPayInvoicesSkipDue(t *testing.T) { Customer: &cus1, }) require.NoError(t, err) + + inv, err = satellite.API.Payments.StripeClient.Invoices().FinalizeInvoice(inv.ID, &stripe.InvoiceFinalizeParams{}) + require.NoError(t, err) + require.Equal(t, stripe.InvoiceStatusOpen, inv.Status) + invWithDue, err := satellite.API.Payments.StripeClient.Invoices().New(&stripe.InvoiceParams{ Customer: &cus2, DueDate: &due, }) require.NoError(t, err) + invWithDue, err = satellite.API.Payments.StripeClient.Invoices().FinalizeInvoice(invWithDue.ID, &stripe.InvoiceFinalizeParams{}) + require.NoError(t, err) + require.Equal(t, stripe.InvoiceStatusOpen, invWithDue.Status) + err = satellite.API.Payments.StripeService.PayInvoices(ctx, time.Time{}) require.NoError(t, err) @@ -675,7 +684,7 @@ func TestPayInvoicesSkipDue(t *testing.T) { } // when due date is set invoice should not be paid if i.ID == invWithDue.ID { - require.Equal(t, stripe.InvoiceStatusDraft, i.Status) + require.Equal(t, stripe.InvoiceStatusOpen, i.Status) } } }) diff --git a/satellite/payments/stripecoinpayments/stripemock.go b/satellite/payments/stripecoinpayments/stripemock.go index 4d9774eaa..8d854ff35 100644 --- a/satellite/payments/stripecoinpayments/stripemock.go +++ b/satellite/payments/stripecoinpayments/stripemock.go @@ -511,6 +511,14 @@ func (m *mockInvoices) New(params *stripe.InvoiceParams) (*stripe.Invoice, error due = *params.DueDate } + lineData := make([]*stripe.InvoiceLine, 0, len(params.InvoiceItems)) + for _, item := range params.InvoiceItems { + lineData = append(lineData, &stripe.InvoiceLine{ + InvoiceItem: *item.InvoiceItem, + Amount: *item.Amount, + }) + } + var desc string if params.Description != nil { if *params.Description == MockInvoicesNewFailure { @@ -525,6 +533,9 @@ func (m *mockInvoices) New(params *stripe.InvoiceParams) (*stripe.Invoice, error DueDate: due, Status: stripe.InvoiceStatusDraft, Description: desc, + Lines: &stripe.InvoiceLineList{ + Data: lineData, + }, } m.invoices[*params.Customer] = append(m.invoices[*params.Customer], invoice) @@ -548,7 +559,16 @@ func (m *mockInvoices) List(listParams *stripe.InvoiceListParams) *invoice.Iter lc := newListContainer(listMeta) query := stripe.Query(func(*stripe.Params, *form.Values) (ret []interface{}, _ stripe.ListContainer, _ error) { - if listParams.Customer == nil { + if listParams.Customer == nil && listParams.Status != nil { + // filter by status + for _, invoices := range m.invoices { + for _, inv := range invoices { + if inv.Status == stripe.InvoiceStatus(*listParams.Status) { + ret = append(ret, inv) + } + } + } + } else if listParams.Customer == nil { for _, invoices := range m.invoices { for _, invoice := range invoices { ret = append(ret, invoice) @@ -577,8 +597,18 @@ func (m *mockInvoices) Update(id string, params *stripe.InvoiceParams) (invoice return nil, errors.New("invoice not found") } +// FinalizeInvoice forwards the invoice's status from draft to open. func (m *mockInvoices) FinalizeInvoice(id string, params *stripe.InvoiceFinalizeParams) (*stripe.Invoice, error) { - return nil, nil + for _, invoices := range m.invoices { + for i, invoice := range invoices { + if invoice.ID == id && invoice.Status == stripe.InvoiceStatusDraft { + invoice.Status = stripe.InvoiceStatusOpen + m.invoices[invoice.Customer.ID][i].Status = stripe.InvoiceStatusOpen + return invoice, nil + } + } + } + return nil, &stripe.Error{} } func (m *mockInvoices) Pay(id string, params *stripe.InvoicePayParams) (*stripe.Invoice, error) { diff --git a/satellite/peer.go b/satellite/peer.go index ef61a1b3e..e6dc51ecd 100644 --- a/satellite/peer.go +++ b/satellite/peer.go @@ -57,6 +57,7 @@ import ( "storj.io/storj/satellite/overlay" "storj.io/storj/satellite/overlay/offlinenodes" "storj.io/storj/satellite/overlay/straynodes" + "storj.io/storj/satellite/payments/accountfreeze" "storj.io/storj/satellite/payments/billing" "storj.io/storj/satellite/payments/paymentsconfig" "storj.io/storj/satellite/payments/storjscan" @@ -201,6 +202,8 @@ type Config struct { ConsoleAuth consoleauth.Config EmailReminders emailreminders.Config + AccountFreeze accountfreeze.Config + Version version_checker.Config GracefulExit gracefulexit.Config diff --git a/satellite/satellitedb/accountfreezeevents.go b/satellite/satellitedb/accountfreezeevents.go index 473c3c911..480033fbe 100644 --- a/satellite/satellitedb/accountfreezeevents.go +++ b/satellite/satellitedb/accountfreezeevents.go @@ -64,6 +64,36 @@ func (events *accountFreezeEvents) Get(ctx context.Context, userID uuid.UUID, ev return fromDBXAccountFreezeEvent(dbxEvent) } +// GetAll is a method for querying all account freeze events from the database by user ID. +func (events *accountFreezeEvents) GetAll(ctx context.Context, userID uuid.UUID) (freeze *console.AccountFreezeEvent, warning *console.AccountFreezeEvent, err error) { + defer mon.Task()(&ctx)(&err) + + // dbxEvents will have a max length of 2. + // because there's at most 1 instance each of 2 types of events for a user. + dbxEvents, err := events.db.All_AccountFreezeEvent_By_UserId(ctx, + dbx.AccountFreezeEvent_UserId(userID.Bytes()), + ) + if err != nil { + return nil, nil, err + } + + for _, event := range dbxEvents { + if console.AccountFreezeEventType(event.Event) == console.Freeze { + freeze, err = fromDBXAccountFreezeEvent(event) + if err != nil { + return nil, nil, err + } + continue + } + warning, err = fromDBXAccountFreezeEvent(event) + if err != nil { + return nil, nil, err + } + } + + return freeze, warning, nil +} + // DeleteAllByUserID is a method for deleting all account freeze events from the database by user ID. func (events *accountFreezeEvents) DeleteAllByUserID(ctx context.Context, userID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) diff --git a/satellite/satellitedb/customers.go b/satellite/satellitedb/customers.go index f79875b9c..776c0beb7 100644 --- a/satellite/satellitedb/customers.go +++ b/satellite/satellitedb/customers.go @@ -58,6 +58,22 @@ func (customers *customers) GetCustomerID(ctx context.Context, userID uuid.UUID) return idRow.CustomerId, nil } +// GetUserID return userID given stripe customer id. +func (customers *customers) GetUserID(ctx context.Context, customerID string) (_ uuid.UUID, err error) { + defer mon.Task()(&ctx)(&err) + + idRow, err := customers.db.Get_StripeCustomer_UserId_By_CustomerId(ctx, dbx.StripeCustomer_CustomerId(customerID)) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return uuid.UUID{}, stripecoinpayments.ErrNoCustomer + } + + return uuid.UUID{}, err + } + + return uuid.FromBytes(idRow.UserId) +} + // List returns paginated customers id list, with customers created before specified date. func (customers *customers) List(ctx context.Context, offset int64, limit int, before time.Time) (_ stripecoinpayments.CustomersPage, err error) { defer mon.Task()(&ctx)(&err) diff --git a/satellite/satellitedb/dbx/billing.dbx b/satellite/satellitedb/dbx/billing.dbx index e2b4beeb3..6b715f3f2 100644 --- a/satellite/satellitedb/dbx/billing.dbx +++ b/satellite/satellitedb/dbx/billing.dbx @@ -17,6 +17,10 @@ read one ( select stripe_customer.customer_id where stripe_customer.user_id = ? ) +read one ( + select stripe_customer.user_id + where stripe_customer.customer_id = ? +) read limitoffset ( select stripe_customer where stripe_customer.created_at <= ? diff --git a/satellite/satellitedb/dbx/satellitedb.dbx.go b/satellite/satellitedb/dbx/satellitedb.dbx.go index 1c0d7f5f5..3244b11d5 100644 --- a/satellite/satellitedb/dbx/satellitedb.dbx.go +++ b/satellite/satellitedb/dbx/satellitedb.dbx.go @@ -13756,6 +13756,28 @@ func (obj *pgxImpl) Get_StripeCustomer_CustomerId_By_UserId(ctx context.Context, } +func (obj *pgxImpl) Get_StripeCustomer_UserId_By_CustomerId(ctx context.Context, + stripe_customer_customer_id StripeCustomer_CustomerId_Field) ( + row *UserId_Row, err error) { + defer mon.Task()(&ctx)(&err) + + var __embed_stmt = __sqlbundle_Literal("SELECT stripe_customers.user_id FROM stripe_customers WHERE stripe_customers.customer_id = ?") + + var __values []interface{} + __values = append(__values, stripe_customer_customer_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + row = &UserId_Row{} + err = obj.queryRowContext(ctx, __stmt, __values...).Scan(&row.UserId) + if err != nil { + return (*UserId_Row)(nil), obj.makeErr(err) + } + return row, nil + +} + func (obj *pgxImpl) Limited_StripeCustomer_By_CreatedAt_LessOrEqual_OrderBy_Desc_CreatedAt(ctx context.Context, stripe_customer_created_at_less_or_equal StripeCustomer_CreatedAt_Field, limit int, offset int64) ( @@ -16286,6 +16308,51 @@ func (obj *pgxImpl) Get_AccountFreezeEvent_By_UserId_And_Event(ctx context.Conte } +func (obj *pgxImpl) All_AccountFreezeEvent_By_UserId(ctx context.Context, + account_freeze_event_user_id AccountFreezeEvent_UserId_Field) ( + rows []*AccountFreezeEvent, err error) { + defer mon.Task()(&ctx)(&err) + + var __embed_stmt = __sqlbundle_Literal("SELECT account_freeze_events.user_id, account_freeze_events.event, account_freeze_events.limits, account_freeze_events.created_at FROM account_freeze_events WHERE account_freeze_events.user_id = ?") + + var __values []interface{} + __values = append(__values, account_freeze_event_user_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + for { + rows, err = func() (rows []*AccountFreezeEvent, err error) { + __rows, err := obj.driver.QueryContext(ctx, __stmt, __values...) + if err != nil { + return nil, err + } + defer __rows.Close() + + for __rows.Next() { + account_freeze_event := &AccountFreezeEvent{} + err = __rows.Scan(&account_freeze_event.UserId, &account_freeze_event.Event, &account_freeze_event.Limits, &account_freeze_event.CreatedAt) + if err != nil { + return nil, err + } + rows = append(rows, account_freeze_event) + } + if err := __rows.Err(); err != nil { + return nil, err + } + return rows, nil + }() + if err != nil { + if obj.shouldRetry(err) { + continue + } + return nil, obj.makeErr(err) + } + return rows, nil + } + +} + func (obj *pgxImpl) Get_UserSettings_By_UserId(ctx context.Context, user_settings_user_id UserSettings_UserId_Field) ( user_settings *UserSettings, err error) { @@ -21269,6 +21336,28 @@ func (obj *pgxcockroachImpl) Get_StripeCustomer_CustomerId_By_UserId(ctx context } +func (obj *pgxcockroachImpl) Get_StripeCustomer_UserId_By_CustomerId(ctx context.Context, + stripe_customer_customer_id StripeCustomer_CustomerId_Field) ( + row *UserId_Row, err error) { + defer mon.Task()(&ctx)(&err) + + var __embed_stmt = __sqlbundle_Literal("SELECT stripe_customers.user_id FROM stripe_customers WHERE stripe_customers.customer_id = ?") + + var __values []interface{} + __values = append(__values, stripe_customer_customer_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + row = &UserId_Row{} + err = obj.queryRowContext(ctx, __stmt, __values...).Scan(&row.UserId) + if err != nil { + return (*UserId_Row)(nil), obj.makeErr(err) + } + return row, nil + +} + func (obj *pgxcockroachImpl) Limited_StripeCustomer_By_CreatedAt_LessOrEqual_OrderBy_Desc_CreatedAt(ctx context.Context, stripe_customer_created_at_less_or_equal StripeCustomer_CreatedAt_Field, limit int, offset int64) ( @@ -23799,6 +23888,51 @@ func (obj *pgxcockroachImpl) Get_AccountFreezeEvent_By_UserId_And_Event(ctx cont } +func (obj *pgxcockroachImpl) All_AccountFreezeEvent_By_UserId(ctx context.Context, + account_freeze_event_user_id AccountFreezeEvent_UserId_Field) ( + rows []*AccountFreezeEvent, err error) { + defer mon.Task()(&ctx)(&err) + + var __embed_stmt = __sqlbundle_Literal("SELECT account_freeze_events.user_id, account_freeze_events.event, account_freeze_events.limits, account_freeze_events.created_at FROM account_freeze_events WHERE account_freeze_events.user_id = ?") + + var __values []interface{} + __values = append(__values, account_freeze_event_user_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + for { + rows, err = func() (rows []*AccountFreezeEvent, err error) { + __rows, err := obj.driver.QueryContext(ctx, __stmt, __values...) + if err != nil { + return nil, err + } + defer __rows.Close() + + for __rows.Next() { + account_freeze_event := &AccountFreezeEvent{} + err = __rows.Scan(&account_freeze_event.UserId, &account_freeze_event.Event, &account_freeze_event.Limits, &account_freeze_event.CreatedAt) + if err != nil { + return nil, err + } + rows = append(rows, account_freeze_event) + } + if err := __rows.Err(); err != nil { + return nil, err + } + return rows, nil + }() + if err != nil { + if obj.shouldRetry(err) { + continue + } + return nil, obj.makeErr(err) + } + return rows, nil + } + +} + func (obj *pgxcockroachImpl) Get_UserSettings_By_UserId(ctx context.Context, user_settings_user_id UserSettings_UserId_Field) ( user_settings *UserSettings, err error) { @@ -26803,6 +26937,16 @@ func (rx *Rx) Rollback() (err error) { return err } +func (rx *Rx) All_AccountFreezeEvent_By_UserId(ctx context.Context, + account_freeze_event_user_id AccountFreezeEvent_UserId_Field) ( + rows []*AccountFreezeEvent, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.All_AccountFreezeEvent_By_UserId(ctx, account_freeze_event_user_id) +} + func (rx *Rx) All_BillingTransaction_By_UserId_OrderBy_Desc_Timestamp(ctx context.Context, billing_transaction_user_id BillingTransaction_UserId_Field) ( rows []*BillingTransaction, err error) { @@ -28082,6 +28226,16 @@ func (rx *Rx) Get_StripeCustomer_CustomerId_By_UserId(ctx context.Context, return tx.Get_StripeCustomer_CustomerId_By_UserId(ctx, stripe_customer_user_id) } +func (rx *Rx) Get_StripeCustomer_UserId_By_CustomerId(ctx context.Context, + stripe_customer_customer_id StripeCustomer_CustomerId_Field) ( + row *UserId_Row, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Get_StripeCustomer_UserId_By_CustomerId(ctx, stripe_customer_customer_id) +} + func (rx *Rx) Get_StripecoinpaymentsInvoiceProjectRecord_By_ProjectId_And_PeriodStart_And_PeriodEnd(ctx context.Context, stripecoinpayments_invoice_project_record_project_id StripecoinpaymentsInvoiceProjectRecord_ProjectId_Field, stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field, @@ -28720,6 +28874,10 @@ func (rx *Rx) Update_WebappSession_By_Id(ctx context.Context, } type Methods interface { + All_AccountFreezeEvent_By_UserId(ctx context.Context, + account_freeze_event_user_id AccountFreezeEvent_UserId_Field) ( + rows []*AccountFreezeEvent, err error) + All_BillingTransaction_By_UserId_OrderBy_Desc_Timestamp(ctx context.Context, billing_transaction_user_id BillingTransaction_UserId_Field) ( rows []*BillingTransaction, err error) @@ -29298,6 +29456,10 @@ type Methods interface { stripe_customer_user_id StripeCustomer_UserId_Field) ( row *CustomerId_Row, err error) + Get_StripeCustomer_UserId_By_CustomerId(ctx context.Context, + stripe_customer_customer_id StripeCustomer_CustomerId_Field) ( + row *UserId_Row, err error) + Get_StripecoinpaymentsInvoiceProjectRecord_By_ProjectId_And_PeriodStart_And_PeriodEnd(ctx context.Context, stripecoinpayments_invoice_project_record_project_id StripecoinpaymentsInvoiceProjectRecord_ProjectId_Field, stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field, diff --git a/satellite/satellitedb/dbx/user.dbx b/satellite/satellitedb/dbx/user.dbx index 24302b172..118b61ce3 100644 --- a/satellite/satellitedb/dbx/user.dbx +++ b/satellite/satellitedb/dbx/user.dbx @@ -208,6 +208,11 @@ read one ( where account_freeze_event.event = ? ) +read all ( + select account_freeze_event + where account_freeze_event.user_id = ? +) + update account_freeze_event ( where account_freeze_event.user_id = ? where account_freeze_event.event = ? diff --git a/scripts/testdata/satellite-config.yaml.lock b/scripts/testdata/satellite-config.yaml.lock index eea4cec14..d1d30745e 100755 --- a/scripts/testdata/satellite-config.yaml.lock +++ b/scripts/testdata/satellite-config.yaml.lock @@ -1,3 +1,15 @@ +# whether to run this chore. +# account-freeze.enabled: false + +# How long to wait between a warning event and freezing an account. +# account-freeze.grace-period: 720h0m0s + +# How often to run this chore, which is how often unpaid invoices are checked. +# account-freeze.interval: 24h0m0s + +# The failed invoice amount beyond which an account will not be frozen +# account-freeze.price-threshold: 2000 + # admin peer http listening address # admin.address: ""