satellite/payments: fix account freeze chore race condition

This change fixes an issue where a formerly warned/frozen user will be
warned again even though they have made payment for the invoice that got
them frozen in the first place. A payment status check is now made
right before a warn/freeze event to make sure the invoice hasn't been
paid already.

Issue: https://github.com/storj/storj/issues/5931

Change-Id: I3f6ac1e224f40107d58dc8f7bdbce58bbbea0196
This commit is contained in:
Wilfred Asomani 2023-06-08 11:00:29 +00:00 committed by Wilfred Asomani
parent 6b65b7e7d0
commit 4a49bc4b65
6 changed files with 129 additions and 44 deletions

View File

@ -78,6 +78,14 @@ func (chore *Chore) Run(ctx context.Context) (err error) {
warnedMap := make(map[uuid.UUID]struct{})
bypassedMap := make(map[uuid.UUID]struct{})
checkInvPaid := func(invID string) (bool, error) {
inv, err := chore.payments.Invoices().Get(ctx, invID)
if err != nil {
return false, err
}
return inv.Status == payments.InvoiceStatusPaid, nil
}
for _, invoice := range invoices {
userID, err := chore.accounts.Customers().GetUserID(ctx, invoice.CustomerID)
if err != nil {
@ -88,85 +96,87 @@ func (chore *Chore) Run(ctx context.Context) (err error) {
)
continue
}
userMap[userID] = struct{}{}
user, err := chore.usersDB.Get(ctx, userID)
if err != nil {
chore.log.Error("Could not get user",
debugLog := func(message string) {
chore.log.Debug(message,
zap.String("invoiceID", invoice.ID),
zap.String("customerID", invoice.CustomerID),
zap.Any("userID", userID),
)
}
errorLog := func(message string, err error) {
chore.log.Error(message,
zap.String("invoiceID", invoice.ID),
zap.String("customerID", invoice.CustomerID),
zap.Any("userID", userID),
zap.Error(Error.Wrap(err)),
)
}
userMap[userID] = struct{}{}
user, err := chore.usersDB.Get(ctx, userID)
if err != nil {
errorLog("Could not get user", err)
continue
}
if invoice.Amount > chore.config.PriceThreshold {
bypassedMap[userID] = struct{}{}
chore.log.Debug("amount due over threshold",
zap.String("invoiceID", invoice.ID),
zap.String("customerID", invoice.CustomerID),
zap.Any("userID", userID),
)
debugLog("Ignoring invoice; amount exceeds threshold")
chore.analytics.TrackLargeUnpaidInvoice(invoice.ID, userID, user.Email)
continue
}
freeze, warning, err := chore.freezeService.GetAll(ctx, userID)
if err != nil {
chore.log.Error("Could not check freeze status",
zap.String("invoiceID", invoice.ID),
zap.String("customerID", invoice.CustomerID),
zap.Any("userID", userID),
zap.Error(Error.Wrap(err)),
)
errorLog("Could not get freeze status", err)
continue
}
if freeze != nil {
chore.log.Debug("Ignoring invoice; account already frozen",
zap.String("invoiceID", invoice.ID),
zap.String("customerID", invoice.CustomerID),
zap.Any("userID", userID),
)
debugLog("Ignoring invoice; account already frozen")
continue
}
if warning == nil {
err = chore.freezeService.WarnUser(ctx, userID)
// check if the invoice has been paid by the time the chore gets here.
isPaid, err := checkInvPaid(invoice.ID)
if err != nil {
chore.log.Error("Could not add warning event",
zap.String("invoiceID", invoice.ID),
zap.String("customerID", invoice.CustomerID),
zap.Any("userID", userID),
zap.Error(Error.Wrap(err)),
)
errorLog("Could not verify invoice status", err)
continue
}
chore.log.Debug("user warned",
zap.String("invoiceID", invoice.ID),
zap.String("customerID", invoice.CustomerID),
zap.Any("userID", userID),
)
if isPaid {
debugLog("Ignoring invoice; payment already made")
continue
}
err = chore.freezeService.WarnUser(ctx, userID)
if err != nil {
errorLog("Could not add warning event", err)
continue
}
debugLog("user warned")
warnedMap[userID] = struct{}{}
continue
}
if chore.nowFn().Sub(warning.CreatedAt) > chore.config.GracePeriod {
err = chore.freezeService.FreezeUser(ctx, userID)
// check if the invoice has been paid by the time the chore gets here.
isPaid, err := checkInvPaid(invoice.ID)
if err != nil {
chore.log.Error("Could not freeze account",
zap.String("invoiceID", invoice.ID),
zap.String("customerID", invoice.CustomerID),
zap.Any("userID", userID),
zap.Error(Error.Wrap(err)),
)
errorLog("Could not verify invoice status", err)
continue
}
chore.log.Debug("user frozen",
zap.String("invoiceID", invoice.ID),
zap.String("customerID", invoice.CustomerID),
zap.Any("userID", userID),
)
if isPaid {
debugLog("Ignoring invoice; payment already made")
continue
}
err = chore.freezeService.FreezeUser(ctx, userID)
if err != nil {
errorLog("Could not freeze account", err)
continue
}
debugLog("user frozen")
frozenMap[userID] = struct{}{}
}
}

View File

@ -29,6 +29,8 @@ const (
type Invoices interface {
// Create creates an invoice with price and description.
Create(ctx context.Context, userID uuid.UUID, price int64, desc string) (*Invoice, error)
// Get returns an invoice by invoiceID.
Get(ctx context.Context, invoiceID string) (*Invoice, error)
// Pay pays an invoice.
Pay(ctx context.Context, invoiceID, paymentMethodID string) (*Invoice, error)
// List returns a list of invoices for a given payment account.

View File

@ -63,6 +63,7 @@ type Invoices interface {
FinalizeInvoice(id string, params *stripe.InvoiceFinalizeParams) (*stripe.Invoice, error)
Pay(id string, params *stripe.InvoicePayParams) (*stripe.Invoice, error)
Del(id string, params *stripe.InvoiceParams) (*stripe.Invoice, error)
Get(id string, params *stripe.InvoiceParams) (*stripe.Invoice, error)
}
// InvoiceItems Stripe InvoiceItems interface.

View File

@ -75,6 +75,39 @@ func (invoices *invoices) Pay(ctx context.Context, invoiceID, paymentMethodID st
}, nil
}
func (invoices *invoices) Get(ctx context.Context, invoiceID string) (*payments.Invoice, error) {
params := &stripe.InvoiceParams{
Params: stripe.Params{
Context: ctx,
},
}
inv, err := invoices.service.stripeClient.Invoices().Get(invoiceID, params)
if err != nil {
return nil, Error.Wrap(err)
}
total := inv.Total
if inv.Lines != nil {
for _, line := range inv.Lines.Data {
// If amount is negative, this is a coupon or a credit line item.
// Add them to the total.
if line.Amount < 0 {
total -= line.Amount
}
}
}
return &payments.Invoice{
ID: inv.ID,
CustomerID: inv.Customer.ID,
Description: inv.Description,
Amount: total,
Status: convertStatus(inv.Status),
Link: inv.InvoicePDF,
Start: time.Unix(inv.PeriodStart, 0),
}, nil
}
// AttemptPayOverdueInvoices attempts to pay a user's open, overdue invoices.
func (invoices *invoices) AttemptPayOverdueInvoices(ctx context.Context, userID uuid.UUID) (err error) {
customerID, err := invoices.service.db.Customers().GetCustomerID(ctx, userID)

View File

@ -58,5 +58,16 @@ func TestInvoices(t *testing.T) {
require.Error(t, err)
require.Nil(t, confirmedPI)
})
t.Run("Create and Get success", func(t *testing.T) {
pi, err := satellite.API.Payments.Accounts.Invoices().Create(ctx, userID, price, desc)
require.NoError(t, err)
require.NotNil(t, pi)
pi2, err := satellite.API.Payments.Accounts.Invoices().Get(ctx, pi.ID)
require.NoError(t, err)
require.Equal(t, pi.ID, pi2.ID)
require.Equal(t, pi.Status, pi2.Status)
require.Equal(t, pi.Amount, pi2.Amount)
})
})
}

View File

@ -661,6 +661,34 @@ func (m *mockInvoices) Del(id string, params *stripe.InvoiceParams) (*stripe.Inv
return nil, nil
}
func (m *mockInvoices) Get(id string, params *stripe.InvoiceParams) (*stripe.Invoice, error) {
for _, invoices := range m.invoices {
for _, inv := range invoices {
if inv.ID == id {
items, ok := m.invoiceItems.items[inv.Customer.ID]
if ok {
amountDue := int64(0)
lineData := make([]*stripe.InvoiceLine, 0, len(params.InvoiceItems))
for _, item := range items {
if item.Invoice != inv {
continue
}
lineData = append(lineData, &stripe.InvoiceLine{
InvoiceItem: item.ID,
Amount: item.Amount,
})
amountDue += item.Amount
}
inv.Lines.Data = lineData
inv.Total = amountDue
}
return inv, nil
}
}
}
return nil, nil
}
type mockInvoiceItems struct {
root *mockStripeState
items map[string][]*stripe.InvoiceItem