satellite/payments: fix account freeze chore race condition
This change fixes an issue where a formerly warned/frozen user will be warned again even though they have made payment for the invoice that got them frozen in the first place. A payment status check is now made right before a warn/freeze event to make sure the invoice hasn't been paid already. Issue: https://github.com/storj/storj/issues/5931 Change-Id: I3f6ac1e224f40107d58dc8f7bdbce58bbbea0196
This commit is contained in:
parent
6b65b7e7d0
commit
4a49bc4b65
@ -78,6 +78,14 @@ func (chore *Chore) Run(ctx context.Context) (err error) {
|
||||
warnedMap := make(map[uuid.UUID]struct{})
|
||||
bypassedMap := make(map[uuid.UUID]struct{})
|
||||
|
||||
checkInvPaid := func(invID string) (bool, error) {
|
||||
inv, err := chore.payments.Invoices().Get(ctx, invID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return inv.Status == payments.InvoiceStatusPaid, nil
|
||||
}
|
||||
|
||||
for _, invoice := range invoices {
|
||||
userID, err := chore.accounts.Customers().GetUserID(ctx, invoice.CustomerID)
|
||||
if err != nil {
|
||||
@ -88,85 +96,87 @@ func (chore *Chore) Run(ctx context.Context) (err error) {
|
||||
)
|
||||
continue
|
||||
}
|
||||
userMap[userID] = struct{}{}
|
||||
|
||||
user, err := chore.usersDB.Get(ctx, userID)
|
||||
if err != nil {
|
||||
chore.log.Error("Could not get user",
|
||||
debugLog := func(message string) {
|
||||
chore.log.Debug(message,
|
||||
zap.String("invoiceID", invoice.ID),
|
||||
zap.String("customerID", invoice.CustomerID),
|
||||
zap.Any("userID", userID),
|
||||
)
|
||||
}
|
||||
|
||||
errorLog := func(message string, err error) {
|
||||
chore.log.Error(message,
|
||||
zap.String("invoiceID", invoice.ID),
|
||||
zap.String("customerID", invoice.CustomerID),
|
||||
zap.Any("userID", userID),
|
||||
zap.Error(Error.Wrap(err)),
|
||||
)
|
||||
}
|
||||
|
||||
userMap[userID] = struct{}{}
|
||||
|
||||
user, err := chore.usersDB.Get(ctx, userID)
|
||||
if err != nil {
|
||||
errorLog("Could not get user", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if invoice.Amount > chore.config.PriceThreshold {
|
||||
bypassedMap[userID] = struct{}{}
|
||||
chore.log.Debug("amount due over threshold",
|
||||
zap.String("invoiceID", invoice.ID),
|
||||
zap.String("customerID", invoice.CustomerID),
|
||||
zap.Any("userID", userID),
|
||||
)
|
||||
debugLog("Ignoring invoice; amount exceeds threshold")
|
||||
chore.analytics.TrackLargeUnpaidInvoice(invoice.ID, userID, user.Email)
|
||||
continue
|
||||
}
|
||||
|
||||
freeze, warning, err := chore.freezeService.GetAll(ctx, userID)
|
||||
if err != nil {
|
||||
chore.log.Error("Could not check freeze status",
|
||||
zap.String("invoiceID", invoice.ID),
|
||||
zap.String("customerID", invoice.CustomerID),
|
||||
zap.Any("userID", userID),
|
||||
zap.Error(Error.Wrap(err)),
|
||||
)
|
||||
errorLog("Could not get freeze status", err)
|
||||
continue
|
||||
}
|
||||
if freeze != nil {
|
||||
chore.log.Debug("Ignoring invoice; account already frozen",
|
||||
zap.String("invoiceID", invoice.ID),
|
||||
zap.String("customerID", invoice.CustomerID),
|
||||
zap.Any("userID", userID),
|
||||
)
|
||||
debugLog("Ignoring invoice; account already frozen")
|
||||
continue
|
||||
}
|
||||
|
||||
if warning == nil {
|
||||
err = chore.freezeService.WarnUser(ctx, userID)
|
||||
// check if the invoice has been paid by the time the chore gets here.
|
||||
isPaid, err := checkInvPaid(invoice.ID)
|
||||
if err != nil {
|
||||
chore.log.Error("Could not add warning event",
|
||||
zap.String("invoiceID", invoice.ID),
|
||||
zap.String("customerID", invoice.CustomerID),
|
||||
zap.Any("userID", userID),
|
||||
zap.Error(Error.Wrap(err)),
|
||||
)
|
||||
errorLog("Could not verify invoice status", err)
|
||||
continue
|
||||
}
|
||||
chore.log.Debug("user warned",
|
||||
zap.String("invoiceID", invoice.ID),
|
||||
zap.String("customerID", invoice.CustomerID),
|
||||
zap.Any("userID", userID),
|
||||
)
|
||||
if isPaid {
|
||||
debugLog("Ignoring invoice; payment already made")
|
||||
continue
|
||||
}
|
||||
err = chore.freezeService.WarnUser(ctx, userID)
|
||||
if err != nil {
|
||||
errorLog("Could not add warning event", err)
|
||||
continue
|
||||
}
|
||||
debugLog("user warned")
|
||||
warnedMap[userID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
if chore.nowFn().Sub(warning.CreatedAt) > chore.config.GracePeriod {
|
||||
err = chore.freezeService.FreezeUser(ctx, userID)
|
||||
// check if the invoice has been paid by the time the chore gets here.
|
||||
isPaid, err := checkInvPaid(invoice.ID)
|
||||
if err != nil {
|
||||
chore.log.Error("Could not freeze account",
|
||||
zap.String("invoiceID", invoice.ID),
|
||||
zap.String("customerID", invoice.CustomerID),
|
||||
zap.Any("userID", userID),
|
||||
zap.Error(Error.Wrap(err)),
|
||||
)
|
||||
errorLog("Could not verify invoice status", err)
|
||||
continue
|
||||
}
|
||||
chore.log.Debug("user frozen",
|
||||
zap.String("invoiceID", invoice.ID),
|
||||
zap.String("customerID", invoice.CustomerID),
|
||||
zap.Any("userID", userID),
|
||||
)
|
||||
if isPaid {
|
||||
debugLog("Ignoring invoice; payment already made")
|
||||
continue
|
||||
}
|
||||
err = chore.freezeService.FreezeUser(ctx, userID)
|
||||
if err != nil {
|
||||
errorLog("Could not freeze account", err)
|
||||
continue
|
||||
}
|
||||
debugLog("user frozen")
|
||||
frozenMap[userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,8 @@ const (
|
||||
type Invoices interface {
|
||||
// Create creates an invoice with price and description.
|
||||
Create(ctx context.Context, userID uuid.UUID, price int64, desc string) (*Invoice, error)
|
||||
// Get returns an invoice by invoiceID.
|
||||
Get(ctx context.Context, invoiceID string) (*Invoice, error)
|
||||
// Pay pays an invoice.
|
||||
Pay(ctx context.Context, invoiceID, paymentMethodID string) (*Invoice, error)
|
||||
// List returns a list of invoices for a given payment account.
|
||||
|
@ -63,6 +63,7 @@ type Invoices interface {
|
||||
FinalizeInvoice(id string, params *stripe.InvoiceFinalizeParams) (*stripe.Invoice, error)
|
||||
Pay(id string, params *stripe.InvoicePayParams) (*stripe.Invoice, error)
|
||||
Del(id string, params *stripe.InvoiceParams) (*stripe.Invoice, error)
|
||||
Get(id string, params *stripe.InvoiceParams) (*stripe.Invoice, error)
|
||||
}
|
||||
|
||||
// InvoiceItems Stripe InvoiceItems interface.
|
||||
|
@ -75,6 +75,39 @@ func (invoices *invoices) Pay(ctx context.Context, invoiceID, paymentMethodID st
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (invoices *invoices) Get(ctx context.Context, invoiceID string) (*payments.Invoice, error) {
|
||||
params := &stripe.InvoiceParams{
|
||||
Params: stripe.Params{
|
||||
Context: ctx,
|
||||
},
|
||||
}
|
||||
inv, err := invoices.service.stripeClient.Invoices().Get(invoiceID, params)
|
||||
if err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
|
||||
total := inv.Total
|
||||
if inv.Lines != nil {
|
||||
for _, line := range inv.Lines.Data {
|
||||
// If amount is negative, this is a coupon or a credit line item.
|
||||
// Add them to the total.
|
||||
if line.Amount < 0 {
|
||||
total -= line.Amount
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &payments.Invoice{
|
||||
ID: inv.ID,
|
||||
CustomerID: inv.Customer.ID,
|
||||
Description: inv.Description,
|
||||
Amount: total,
|
||||
Status: convertStatus(inv.Status),
|
||||
Link: inv.InvoicePDF,
|
||||
Start: time.Unix(inv.PeriodStart, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AttemptPayOverdueInvoices attempts to pay a user's open, overdue invoices.
|
||||
func (invoices *invoices) AttemptPayOverdueInvoices(ctx context.Context, userID uuid.UUID) (err error) {
|
||||
customerID, err := invoices.service.db.Customers().GetCustomerID(ctx, userID)
|
||||
|
@ -58,5 +58,16 @@ func TestInvoices(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, confirmedPI)
|
||||
})
|
||||
t.Run("Create and Get success", func(t *testing.T) {
|
||||
pi, err := satellite.API.Payments.Accounts.Invoices().Create(ctx, userID, price, desc)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pi)
|
||||
|
||||
pi2, err := satellite.API.Payments.Accounts.Invoices().Get(ctx, pi.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pi.ID, pi2.ID)
|
||||
require.Equal(t, pi.Status, pi2.Status)
|
||||
require.Equal(t, pi.Amount, pi2.Amount)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -661,6 +661,34 @@ func (m *mockInvoices) Del(id string, params *stripe.InvoiceParams) (*stripe.Inv
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockInvoices) Get(id string, params *stripe.InvoiceParams) (*stripe.Invoice, error) {
|
||||
for _, invoices := range m.invoices {
|
||||
for _, inv := range invoices {
|
||||
if inv.ID == id {
|
||||
items, ok := m.invoiceItems.items[inv.Customer.ID]
|
||||
if ok {
|
||||
amountDue := int64(0)
|
||||
lineData := make([]*stripe.InvoiceLine, 0, len(params.InvoiceItems))
|
||||
for _, item := range items {
|
||||
if item.Invoice != inv {
|
||||
continue
|
||||
}
|
||||
lineData = append(lineData, &stripe.InvoiceLine{
|
||||
InvoiceItem: item.ID,
|
||||
Amount: item.Amount,
|
||||
})
|
||||
amountDue += item.Amount
|
||||
}
|
||||
inv.Lines.Data = lineData
|
||||
inv.Total = amountDue
|
||||
}
|
||||
return inv, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockInvoiceItems struct {
|
||||
root *mockStripeState
|
||||
items map[string][]*stripe.InvoiceItem
|
||||
|
Loading…
Reference in New Issue
Block a user