satellite/payments/{billing,stripe}: handle pending invoice payments

Currently, pending invoice payments that are made using a users token
balance can get stuck in a pending state if the invoice is not able
to be paid appropriately in stripe. This change addresses these stuck
token invoice payments by attempting to transition them to failed
if the invoice cannot be paid.

Change-Id: I2b70a11c97ae5c733d05c918a1082e85bb7f73f3
This commit is contained in:
dlamarmorgan 2023-09-25 14:06:01 -07:00 committed by Damein Morgan
parent 2e87df380d
commit 8a1bedd367
8 changed files with 426 additions and 66 deletions

View File

@ -29,10 +29,10 @@ var ErrNoTransactions = errs.New("no transactions in the database")
const ( const (
// TransactionStatusPending indicates that status of this transaction is pending. // TransactionStatusPending indicates that status of this transaction is pending.
TransactionStatusPending = "pending" TransactionStatusPending = "pending"
// TransactionStatusCancelled indicates that status of this transaction is cancelled.
TransactionStatusCancelled = "cancelled"
// TransactionStatusCompleted indicates that status of this transaction is complete. // TransactionStatusCompleted indicates that status of this transaction is complete.
TransactionStatusCompleted = "complete" TransactionStatusCompleted = "complete"
// TransactionStatusFailed indicates that status of this transaction is failed.
TransactionStatusFailed = "failed"
) )
// TransactionType indicates transaction type. // TransactionType indicates transaction type.
@ -57,8 +57,10 @@ type TransactionsDB interface {
// but rather to provide an atomic commit of one or more _related_ // but rather to provide an atomic commit of one or more _related_
// transactions. // transactions.
Insert(ctx context.Context, primaryTx Transaction, supplementalTx ...Transaction) (txIDs []int64, err error) Insert(ctx context.Context, primaryTx Transaction, supplementalTx ...Transaction) (txIDs []int64, err error)
// UpdateStatus updates the status of the transaction. // FailPendingInvoiceTokenPayments marks all specified pending invoice token payments as failed, and refunds the pending charges.
UpdateStatus(ctx context.Context, txID int64, status TransactionStatus) error FailPendingInvoiceTokenPayments(ctx context.Context, txIDs ...int64) error
// CompletePendingInvoiceTokenPayments updates the status of the pending invoice token payment to complete.
CompletePendingInvoiceTokenPayments(ctx context.Context, txIDs ...int64) error
// UpdateMetadata updates the metadata of the transaction. // UpdateMetadata updates the metadata of the transaction.
UpdateMetadata(ctx context.Context, txID int64, metadata []byte) error UpdateMetadata(ctx context.Context, txID int64, metadata []byte) error
// LastTransaction returns the timestamp and metadata of the last known transaction for given source and type. // LastTransaction returns the timestamp and metadata of the last known transaction for given source and type.

View File

@ -16,6 +16,7 @@ import (
"storj.io/common/testrand" "storj.io/common/testrand"
"storj.io/storj/private/blockchain" "storj.io/storj/private/blockchain"
"storj.io/storj/satellite" "storj.io/storj/satellite"
"storj.io/storj/satellite/payments"
"storj.io/storj/satellite/payments/billing" "storj.io/storj/satellite/payments/billing"
"storj.io/storj/satellite/satellitedb/satellitedbtest" "storj.io/storj/satellite/satellitedb/satellitedbtest"
) )
@ -102,9 +103,12 @@ func TestTransactionsDBBalance(t *testing.T) {
address, err := blockchain.BytesToAddress(testrand.Bytes(20)) address, err := blockchain.BytesToAddress(testrand.Bytes(20))
require.NoError(t, err) require.NoError(t, err)
metadata, err := json.Marshal(map[string]interface{}{ creditMetadata, err := json.Marshal(map[string]interface{}{
"Wallet": address.Hex(),
})
require.NoError(t, err)
debitMetadata, err := json.Marshal(map[string]interface{}{
"ReferenceID": "some stripe invoice ID", "ReferenceID": "some stripe invoice ID",
"Wallet": address.Hex(),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -115,7 +119,7 @@ func TestTransactionsDBBalance(t *testing.T) {
Source: "storjscan", Source: "storjscan",
Status: billing.TransactionStatusCompleted, Status: billing.TransactionStatusCompleted,
Type: billing.TransactionTypeCredit, Type: billing.TransactionTypeCredit,
Metadata: metadata, Metadata: creditMetadata,
Timestamp: makeTimestamp().Add(time.Second), Timestamp: makeTimestamp().Add(time.Second),
} }
@ -126,7 +130,7 @@ func TestTransactionsDBBalance(t *testing.T) {
Source: "storjscan", Source: "storjscan",
Status: billing.TransactionStatusCompleted, Status: billing.TransactionStatusCompleted,
Type: billing.TransactionTypeCredit, Type: billing.TransactionTypeCredit,
Metadata: metadata, Metadata: creditMetadata,
Timestamp: makeTimestamp().Add(time.Second * 2), Timestamp: makeTimestamp().Add(time.Second * 2),
} }
@ -137,7 +141,7 @@ func TestTransactionsDBBalance(t *testing.T) {
Source: "storjscan", Source: "storjscan",
Status: billing.TransactionStatusCompleted, Status: billing.TransactionStatusCompleted,
Type: billing.TransactionTypeDebit, Type: billing.TransactionTypeDebit,
Metadata: metadata, Metadata: debitMetadata,
Timestamp: makeTimestamp().Add(time.Second * 3), Timestamp: makeTimestamp().Add(time.Second * 3),
} }
@ -195,13 +199,17 @@ func TestTransactionsDBBalance(t *testing.T) {
func TestUpdateTransactions(t *testing.T) { func TestUpdateTransactions(t *testing.T) {
tenUSD := currency.AmountFromBaseUnits(1000, currency.USDollars) tenUSD := currency.AmountFromBaseUnits(1000, currency.USDollars)
minusTenUSD := currency.AmountFromBaseUnits(-1000, currency.USDollars)
userID := testrand.UUID() userID := testrand.UUID()
address, err := blockchain.BytesToAddress(testrand.Bytes(20)) address, err := blockchain.BytesToAddress(testrand.Bytes(20))
require.NoError(t, err) require.NoError(t, err)
metadata, err := json.Marshal(map[string]interface{}{ creditMetadata, err := json.Marshal(map[string]interface{}{
"Wallet": address.Hex(),
})
require.NoError(t, err)
debitMetadata, err := json.Marshal(map[string]interface{}{
"ReferenceID": "some stripe invoice ID", "ReferenceID": "some stripe invoice ID",
"Wallet": address.Hex(),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -209,44 +217,54 @@ func TestUpdateTransactions(t *testing.T) {
UserID: userID, UserID: userID,
Amount: tenUSD, Amount: tenUSD,
Description: "credit from storjscan payment", Description: "credit from storjscan payment",
Source: "storjscan", Source: billing.StorjScanSource,
Status: billing.TransactionStatusCompleted, Status: payments.PaymentStatusConfirmed,
Type: billing.TransactionTypeCredit, Type: billing.TransactionTypeCredit,
Metadata: metadata, Metadata: creditMetadata,
Timestamp: makeTimestamp().Add(time.Second),
}
debit10TX := billing.Transaction{
UserID: userID,
Amount: minusTenUSD,
Description: "Paid Stripe Invoice",
Source: billing.StripeSource,
Status: billing.TransactionStatusPending,
Type: billing.TransactionTypeDebit,
Metadata: debitMetadata,
Timestamp: makeTimestamp().Add(time.Second), Timestamp: makeTimestamp().Add(time.Second),
} }
t.Run("update metadata", func(t *testing.T) { t.Run("update metadata", func(t *testing.T) {
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
txIDs, err := db.Billing().Insert(ctx, credit10TX) _, err := db.Billing().Insert(ctx, credit10TX)
require.NoError(t, err) require.NoError(t, err)
newAddress, err := blockchain.BytesToAddress(testrand.Bytes(20)) txIDs, err := db.Billing().Insert(ctx, debit10TX)
require.NoError(t, err) require.NoError(t, err)
metadata, err := json.Marshal(map[string]interface{}{ metadata, err := json.Marshal(map[string]interface{}{
"Wallet": newAddress.Hex(), "ReferenceID": "some other stripe invoice ID",
}) })
require.NoError(t, err) require.NoError(t, err)
err = db.Billing().UpdateMetadata(ctx, txIDs[0], metadata) err = db.Billing().UpdateMetadata(ctx, txIDs[0], metadata)
require.NoError(t, err) require.NoError(t, err)
expMetadata, err := json.Marshal(map[string]interface{}{ expMetadata, err := json.Marshal(map[string]interface{}{
"ReferenceID": "some stripe invoice ID", "ReferenceID": "some other stripe invoice ID",
"Wallet": newAddress.Hex(),
}) })
require.NoError(t, err) require.NoError(t, err)
credit10TX.Metadata = expMetadata debit10TX.Metadata = expMetadata
tx, err := db.Billing().List(ctx, userID) tx, err := db.Billing().List(ctx, userID)
require.NoError(t, err) require.NoError(t, err)
compareTransactions(t, credit10TX, tx[0]) assert.Equal(t, 2, compareMultipleTransactions(t,
[]billing.Transaction{credit10TX, debit10TX},
tx))
}) })
}) })
t.Run("update status", func(t *testing.T) { t.Run("confirm new token deposit", func(t *testing.T) {
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
txIDs, err := db.Billing().Insert(ctx, credit10TX) _, err := db.Billing().Insert(ctx, credit10TX)
require.NoError(t, err) require.NoError(t, err)
err = db.Billing().UpdateStatus(ctx, txIDs[0], billing.TransactionStatusCancelled) credit10TX.Status = payments.PaymentStatusConfirmed
require.NoError(t, err)
credit10TX.Status = billing.TransactionStatusCancelled
tx, err := db.Billing().List(ctx, userID) tx, err := db.Billing().List(ctx, userID)
require.NoError(t, err) require.NoError(t, err)
compareTransactions(t, credit10TX, tx[0]) compareTransactions(t, credit10TX, tx[0])
@ -254,6 +272,133 @@ func TestUpdateTransactions(t *testing.T) {
}) })
} }
func TestCompletePendingPayment(t *testing.T) {
tenUSD := currency.AmountFromBaseUnits(1000, currency.USDollars)
minusTenUSD := currency.AmountFromBaseUnits(-1000, currency.USDollars)
userID := testrand.UUID()
address, err := blockchain.BytesToAddress(testrand.Bytes(20))
require.NoError(t, err)
creditMetadata, err := json.Marshal(map[string]interface{}{
"Wallet": address.Hex(),
})
require.NoError(t, err)
debitMetadata, err := json.Marshal(map[string]interface{}{
"ReferenceID": "some stripe invoice ID",
})
require.NoError(t, err)
credit10TX := billing.Transaction{
UserID: userID,
Amount: tenUSD,
Description: "credit from storjscan payment",
Source: billing.StorjScanSource,
Status: payments.PaymentStatusConfirmed,
Type: billing.TransactionTypeCredit,
Metadata: creditMetadata,
Timestamp: makeTimestamp().Add(time.Second),
}
debit10TX := billing.Transaction{
UserID: userID,
Amount: minusTenUSD,
Description: "Paid Stripe Invoice",
Source: billing.StripeSource,
Status: billing.TransactionStatusPending,
Type: billing.TransactionTypeDebit,
Metadata: debitMetadata,
Timestamp: makeTimestamp().Add(time.Second),
}
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
_, err := db.Billing().Insert(ctx, credit10TX)
require.NoError(t, err)
credit10TX.Status = payments.PaymentStatusConfirmed
tx, err := db.Billing().List(ctx, userID)
require.NoError(t, err)
compareTransactions(t, credit10TX, tx[0])
txIDs, err := db.Billing().Insert(ctx, debit10TX)
require.NoError(t, err)
err = db.Billing().CompletePendingInvoiceTokenPayments(ctx, txIDs[0])
require.NoError(t, err)
debit10TX.Status = billing.TransactionStatusCompleted
tx, err = db.Billing().List(ctx, userID)
require.NoError(t, err)
assert.Equal(t, 2, compareMultipleTransactions(t,
[]billing.Transaction{credit10TX, debit10TX}, tx))
})
}
func TestFailPendingPayment(t *testing.T) {
tenUSD := currency.AmountFromBaseUnits(1000, currency.USDollars)
minusTenUSD := currency.AmountFromBaseUnits(-1000, currency.USDollars)
userID := testrand.UUID()
address, err := blockchain.BytesToAddress(testrand.Bytes(20))
require.NoError(t, err)
creditMetadata, err := json.Marshal(map[string]interface{}{
"Wallet": address.Hex(),
})
require.NoError(t, err)
debitMetadata, err := json.Marshal(map[string]interface{}{
"ReferenceID": "some stripe invoice ID",
})
require.NoError(t, err)
credit10TX := billing.Transaction{
UserID: userID,
Amount: tenUSD,
Description: "credit from storjscan payment",
Source: billing.StorjScanSource,
Status: payments.PaymentStatusConfirmed,
Type: billing.TransactionTypeCredit,
Metadata: creditMetadata,
Timestamp: makeTimestamp().Add(time.Second),
}
debit10TX := billing.Transaction{
UserID: userID,
Amount: minusTenUSD,
Description: "Paid Stripe Invoice",
Source: billing.StripeSource,
Status: billing.TransactionStatusPending,
Type: billing.TransactionTypeDebit,
Metadata: debitMetadata,
Timestamp: makeTimestamp().Add(time.Second),
}
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
_, err := db.Billing().Insert(ctx, credit10TX)
require.NoError(t, err)
credit10TX.Status = payments.PaymentStatusConfirmed
tx, err := db.Billing().List(ctx, userID)
require.NoError(t, err)
compareTransactions(t, credit10TX, tx[0])
txIDs, err := db.Billing().Insert(ctx, debit10TX)
require.NoError(t, err)
err = db.Billing().FailPendingInvoiceTokenPayments(ctx, txIDs[0])
require.NoError(t, err)
debit10TX.Status = billing.TransactionStatusFailed
tx, err = db.Billing().List(ctx, userID)
require.NoError(t, err)
assert.Equal(t, 2, compareMultipleTransactions(t,
[]billing.Transaction{credit10TX, debit10TX}, tx))
})
}
func compareMultipleTransactions(t *testing.T, exp, act []billing.Transaction) int {
var matches = 0
for _, expectedTx := range exp {
for _, actualTX := range act {
if expectedTx.Description == actualTX.Description {
matches++
compareTransactions(t, expectedTx, actualTX)
}
}
}
return matches
}
// compareTransactions is a helper method to compare tx used to create db entry, // compareTransactions is a helper method to compare tx used to create db entry,
// with the tx returned from the db. Method doesn't compare created at field, but // with the tx returned from the db. Method doesn't compare created at field, but
// ensures that is not empty. // ensures that is not empty.
@ -272,7 +417,6 @@ func compareTransactions(t *testing.T, exp, act billing.Transaction) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expUpdatedMetadata["ReferenceID"], actUpdatedMetadata["ReferenceID"]) assert.Equal(t, expUpdatedMetadata["ReferenceID"], actUpdatedMetadata["ReferenceID"])
assert.Equal(t, expUpdatedMetadata["Wallet"], actUpdatedMetadata["Wallet"]) assert.Equal(t, expUpdatedMetadata["Wallet"], actUpdatedMetadata["Wallet"])
assert.Equal(t, exp.Timestamp, act.Timestamp)
assert.NotEqual(t, time.Time{}, act.CreatedAt) assert.NotEqual(t, time.Time{}, act.CreatedAt)
} }

View File

@ -1220,6 +1220,11 @@ func (service *Service) payInvoicesWithTokenBalance(ctx context.Context, cusID s
creditNoteID, err := service.addCreditNoteToInvoice(ctx, invoice.ID, cusID, wallet.Address.Hex(), tokenCreditAmount, txID) creditNoteID, err := service.addCreditNoteToInvoice(ctx, invoice.ID, cusID, wallet.Address.Hex(), tokenCreditAmount, txID)
if err != nil { if err != nil {
// attempt to fail any pending transactions
err := service.billingDB.FailPendingInvoiceTokenPayments(ctx, txID)
if err != nil {
errGrp.Add(Error.New("unable to fail the pending transactions for user %s", wallet.UserID.String()))
}
errGrp.Add(Error.New("unable to create token payment credit note for user %s", wallet.UserID.String())) errGrp.Add(Error.New("unable to create token payment credit note for user %s", wallet.UserID.String()))
continue continue
} }
@ -1229,18 +1234,33 @@ func (service *Service) payInvoicesWithTokenBalance(ctx context.Context, cusID s
}) })
if err != nil { if err != nil {
// attempt to fail any pending transactions
err := service.billingDB.FailPendingInvoiceTokenPayments(ctx, txID)
if err != nil {
errGrp.Add(Error.New("unable to fail the pending transactions for user %s", wallet.UserID.String()))
}
errGrp.Add(Error.New("unable to marshall credit note ID %s", creditNoteID)) errGrp.Add(Error.New("unable to marshall credit note ID %s", creditNoteID))
continue continue
} }
err = service.billingDB.UpdateMetadata(ctx, txID, metadata) err = service.billingDB.UpdateMetadata(ctx, txID, metadata)
if err != nil { if err != nil {
// attempt to fail any pending transactions
err := service.billingDB.FailPendingInvoiceTokenPayments(ctx, txID)
if err != nil {
errGrp.Add(Error.New("unable to fail the pending transactions for user %s", wallet.UserID.String()))
}
errGrp.Add(Error.New("unable to add credit note ID to billing transaction for user %s", wallet.UserID.String())) errGrp.Add(Error.New("unable to add credit note ID to billing transaction for user %s", wallet.UserID.String()))
continue continue
} }
err = service.billingDB.UpdateStatus(ctx, txID, billing.TransactionStatusCompleted) err = service.billingDB.CompletePendingInvoiceTokenPayments(ctx, txID)
if err != nil { if err != nil {
// attempt to fail any pending transactions
err := service.billingDB.FailPendingInvoiceTokenPayments(ctx, txID)
if err != nil {
errGrp.Add(Error.New("unable to fail the pending transactions for user %s", wallet.UserID.String()))
}
errGrp.Add(Error.New("unable to update status for billing transaction for user %s", wallet.UserID.String())) errGrp.Add(Error.New("unable to update status for billing transaction for user %s", wallet.UserID.String()))
continue continue
} }

View File

@ -1147,6 +1147,92 @@ func TestService_PayMultipleInvoiceForCustomer(t *testing.T) {
}) })
} }
func TestFailPendingInvoicePayment(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
Reconfigure: testplanet.Reconfigure{
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
config.Payments.StripeCoinPayments.ListingLimit = 4
},
},
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
satellite := planet.Satellites[0]
payments := satellite.API.Payments
tokenBalance := currency.AmountFromBaseUnits(1000, currency.USDollars)
invoiceBalance := currency.AmountFromBaseUnits(800, currency.USDollars)
usdCurrency := string(stripe.CurrencyUSD)
user, err := satellite.AddUser(ctx, console.CreateUser{
FullName: "testuser",
Email: "user@test",
}, 1)
require.NoError(t, err)
customer, err := satellite.DB.StripeCoinPayments().Customers().GetCustomerID(ctx, user.ID)
require.NoError(t, err)
// create invoice
inv, err := satellite.API.Payments.StripeClient.Invoices().New(&stripe.InvoiceParams{
Params: stripe.Params{Context: ctx},
Customer: &customer,
DefaultPaymentMethod: stripe.String(stripe1.MockInvoicesPaySuccess),
Metadata: map[string]string{"mock": stripe1.MockInvoicesPayFailure},
})
require.NoError(t, err)
_, err = satellite.API.Payments.StripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{
Params: stripe.Params{Context: ctx},
Amount: stripe.Int64(invoiceBalance.BaseUnits()),
Currency: stripe.String(usdCurrency),
Customer: &customer,
Invoice: stripe.String(inv.ID),
})
require.NoError(t, err)
// finalize invoice
err = satellite.API.Payments.StripeService.FinalizeInvoices(ctx)
require.NoError(t, err)
require.Equal(t, stripe.InvoiceStatusOpen, inv.Status)
// setup storjscan wallet
address, err := blockchain.BytesToAddress(testrand.Bytes(20))
require.NoError(t, err)
userID := user.ID
err = satellite.DB.Wallets().Add(ctx, userID, address)
require.NoError(t, err)
_, err = satellite.DB.Billing().Insert(ctx, billing.Transaction{
UserID: userID,
Amount: tokenBalance,
Description: "token payment credit",
Source: billing.StorjScanSource,
Status: billing.TransactionStatusCompleted,
Type: billing.TransactionTypeCredit,
Metadata: nil,
Timestamp: time.Now(),
CreatedAt: time.Now(),
})
require.NoError(t, err)
// run apply token balance to see if there are no unexpected errors
err = payments.StripeService.InvoiceApplyTokenBalance(ctx, time.Time{})
require.Error(t, err)
iter := satellite.API.Payments.StripeClient.Invoices().List(&stripe.InvoiceListParams{
ListParams: stripe.ListParams{Context: ctx},
})
iter.Next()
require.Equal(t, stripe.InvoiceStatusOpen, iter.Invoice().Status)
// balance is in USDollars Micro, so it needs to be converted before comparison
balance, err := satellite.DB.Billing().GetBalance(ctx, userID)
balance = currency.AmountFromDecimal(balance.AsDecimal().Truncate(2), currency.USDollars)
require.NoError(t, err)
// verify user balance wasn't changed
require.Equal(t, tokenBalance.BaseUnits(), balance.BaseUnits())
})
}
func TestService_GenerateInvoice(t *testing.T) { func TestService_GenerateInvoice(t *testing.T) {
for _, testCase := range []struct { for _, testCase := range []struct {
desc string desc string

View File

@ -590,6 +590,7 @@ func (m *mockInvoices) New(params *stripe.InvoiceParams) (*stripe.Invoice, error
DueDate: due, DueDate: due,
Status: stripe.InvoiceStatusDraft, Status: stripe.InvoiceStatusDraft,
Description: desc, Description: desc,
Metadata: params.Metadata,
Lines: &stripe.InvoiceLineItemList{ Lines: &stripe.InvoiceLineItemList{
Data: lineData, Data: lineData,
}, },
@ -993,6 +994,9 @@ func (m mockCreditNotes) New(params *stripe.CreditNoteParams) (*stripe.CreditNot
// but we don't need to support that in the mock right now // but we don't need to support that in the mock right now
return nil, &stripe.Error{} return nil, &stripe.Error{}
} }
if inv.Metadata["mock"] == MockInvoicesPayFailure {
return nil, errors.New("mock - failed to pay invoice")
}
invoice = inv invoice = inv
break break
} }

View File

@ -29,6 +29,30 @@ type billingDB struct {
db *satelliteDB db *satelliteDB
} }
func updateBalance(ctx context.Context, tx *dbx.Tx, userID uuid.UUID, oldBalance, newBalance currency.Amount) error {
updatedRow, err := tx.Update_BillingBalance_By_UserId_And_Balance(ctx,
dbx.BillingBalance_UserId(userID[:]),
dbx.BillingBalance_Balance(oldBalance.BaseUnits()),
dbx.BillingBalance_Update_Fields{
Balance: dbx.BillingBalance_Balance(newBalance.BaseUnits()),
})
if err != nil {
return Error.Wrap(err)
}
if updatedRow == nil {
// Try an insert here, in case the user never had a record in the table.
// If the user already had a record, and the oldBalance was not as expected,
// the insert will fail anyways.
err = tx.CreateNoReturn_BillingBalance(ctx,
dbx.BillingBalance_UserId(userID[:]),
dbx.BillingBalance_Balance(newBalance.BaseUnits()))
if err != nil {
return Error.Wrap(err)
}
}
return nil
}
func (db billingDB) Insert(ctx context.Context, primaryTx billing.Transaction, supplementalTxs ...billing.Transaction) (_ []int64, err error) { func (db billingDB) Insert(ctx context.Context, primaryTx billing.Transaction, supplementalTxs ...billing.Transaction) (_ []int64, err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
@ -73,30 +97,6 @@ func (db billingDB) tryInsert(ctx context.Context, primaryTx billing.Transaction
NewBalance currency.Amount NewBalance currency.Amount
} }
updateBalance := func(ctx context.Context, tx *dbx.Tx, userID uuid.UUID, oldBalance, newBalance currency.Amount) error {
updatedRow, err := tx.Update_BillingBalance_By_UserId_And_Balance(ctx,
dbx.BillingBalance_UserId(userID[:]),
dbx.BillingBalance_Balance(oldBalance.BaseUnits()),
dbx.BillingBalance_Update_Fields{
Balance: dbx.BillingBalance_Balance(newBalance.BaseUnits()),
})
if err != nil {
return Error.Wrap(err)
}
if updatedRow == nil {
// Try an insert here, in case the user never had a record in the table.
// If the user already had a record, and the oldBalance was not as expected,
// the insert will fail anyways.
err = tx.CreateNoReturn_BillingBalance(ctx,
dbx.BillingBalance_UserId(userID[:]),
dbx.BillingBalance_Balance(newBalance.BaseUnits()))
if err != nil {
return Error.Wrap(err)
}
}
return nil
}
createTransaction := func(ctx context.Context, tx *dbx.Tx, billingTX *billing.Transaction) (int64, error) { createTransaction := func(ctx context.Context, tx *dbx.Tx, billingTX *billing.Transaction) (int64, error) {
amount := convertToUSDMicro(billingTX.Amount) amount := convertToUSDMicro(billingTX.Amount)
dbxTX, err := tx.Create_BillingTransaction(ctx, dbxTX, err := tx.Create_BillingTransaction(ctx,
@ -173,11 +173,56 @@ func (db billingDB) tryInsert(ctx context.Context, primaryTx billing.Transaction
return txIDs, err return txIDs, err
} }
func (db billingDB) UpdateStatus(ctx context.Context, txID int64, status billing.TransactionStatus) (err error) { func (db billingDB) FailPendingInvoiceTokenPayments(ctx context.Context, txIDs ...int64) (err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
return db.db.UpdateNoReturn_BillingTransaction_By_Id(ctx, dbx.BillingTransaction_Id(txID), dbx.BillingTransaction_Update_Fields{
Status: dbx.BillingTransaction_Status(string(status)), for _, txID := range txIDs {
}) dbxTX, err := db.db.Get_BillingTransaction_By_Id(ctx, dbx.BillingTransaction_Id(txID))
if err != nil {
return Error.Wrap(err)
}
userID, err := uuid.FromBytes(dbxTX.UserId)
if err != nil {
return Error.New("Unable to get user ID for transaction: %v %v", txID, err)
}
oldBalance, err := db.GetBalance(ctx, userID)
if err != nil {
return Error.New("Unable to get user balance for ID: %v %v", userID, err)
}
err = db.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
err = db.db.UpdateNoReturn_BillingTransaction_By_Id_And_Status(ctx, dbx.BillingTransaction_Id(txID),
dbx.BillingTransaction_Status(billing.TransactionStatusPending),
dbx.BillingTransaction_Update_Fields{
Status: dbx.BillingTransaction_Status(billing.TransactionStatusFailed),
})
if err != nil {
return Error.Wrap(err)
}
// refund the pending charge. dbx amount is negative.
return updateBalance(ctx, tx, userID, oldBalance, currency.AmountFromBaseUnits(oldBalance.BaseUnits()-dbxTX.Amount, currency.USDollarsMicro))
})
if err != nil {
return Error.New("Unable to transition token invoice payment to failed state for transaction: %v %v", txID, err)
}
}
return nil
}
func (db billingDB) CompletePendingInvoiceTokenPayments(ctx context.Context, txIDs ...int64) (err error) {
defer mon.Task()(&ctx)(&err)
for _, txID := range txIDs {
err = db.db.UpdateNoReturn_BillingTransaction_By_Id_And_Status(ctx, dbx.BillingTransaction_Id(txID),
dbx.BillingTransaction_Status(billing.TransactionStatusPending),
dbx.BillingTransaction_Update_Fields{
Status: dbx.BillingTransaction_Status(billing.TransactionStatusCompleted),
})
if err != nil {
return Error.Wrap(err)
}
}
return nil
} }
func (db billingDB) UpdateMetadata(ctx context.Context, txID int64, newMetadata []byte) (err error) { func (db billingDB) UpdateMetadata(ctx context.Context, txID int64, newMetadata []byte) (err error) {
@ -192,9 +237,11 @@ func (db billingDB) UpdateMetadata(ctx context.Context, txID int64, newMetadata
return Error.Wrap(err) return Error.Wrap(err)
} }
return db.db.UpdateNoReturn_BillingTransaction_By_Id(ctx, dbx.BillingTransaction_Id(txID), dbx.BillingTransaction_Update_Fields{ return db.db.UpdateNoReturn_BillingTransaction_By_Id_And_Status(ctx, dbx.BillingTransaction_Id(txID),
Metadata: dbx.BillingTransaction_Metadata(updatedMetadata), dbx.BillingTransaction_Status(billing.TransactionStatusPending),
}) dbx.BillingTransaction_Update_Fields{
Metadata: dbx.BillingTransaction_Metadata(updatedMetadata),
})
} }
func (db billingDB) LastTransaction(ctx context.Context, txSource string, txType billing.TransactionType) (_ time.Time, metadata []byte, err error) { func (db billingDB) LastTransaction(ctx context.Context, txSource string, txType billing.TransactionType) (_ time.Time, metadata []byte, err error) {

View File

@ -96,9 +96,15 @@ create billing_transaction ( )
update billing_transaction ( update billing_transaction (
where billing_transaction.id = ? where billing_transaction.id = ?
where billing_transaction.status = ?
noreturn noreturn
) )
read one (
select billing_transaction
where billing_transaction.id = ?
)
read one ( read one (
select billing_transaction.metadata select billing_transaction.metadata
where billing_transaction.id = ? where billing_transaction.id = ?

View File

@ -14288,6 +14288,28 @@ func (obj *pgxImpl) Get_BillingBalance_Balance_By_UserId(ctx context.Context,
} }
func (obj *pgxImpl) Get_BillingTransaction_By_Id(ctx context.Context,
billing_transaction_id BillingTransaction_Id_Field) (
billing_transaction *BillingTransaction, err error) {
defer mon.Task()(&ctx)(&err)
var __embed_stmt = __sqlbundle_Literal("SELECT billing_transactions.id, billing_transactions.user_id, billing_transactions.amount, billing_transactions.currency, billing_transactions.description, billing_transactions.source, billing_transactions.status, billing_transactions.type, billing_transactions.metadata, billing_transactions.timestamp, billing_transactions.created_at FROM billing_transactions WHERE billing_transactions.id = ?")
var __values []interface{}
__values = append(__values, billing_transaction_id.value())
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
billing_transaction = &BillingTransaction{}
err = obj.queryRowContext(ctx, __stmt, __values...).Scan(&billing_transaction.Id, &billing_transaction.UserId, &billing_transaction.Amount, &billing_transaction.Currency, &billing_transaction.Description, &billing_transaction.Source, &billing_transaction.Status, &billing_transaction.Type, &billing_transaction.Metadata, &billing_transaction.Timestamp, &billing_transaction.CreatedAt)
if err != nil {
return (*BillingTransaction)(nil), obj.makeErr(err)
}
return billing_transaction, nil
}
func (obj *pgxImpl) Get_BillingTransaction_Metadata_By_Id(ctx context.Context, func (obj *pgxImpl) Get_BillingTransaction_Metadata_By_Id(ctx context.Context,
billing_transaction_id BillingTransaction_Id_Field) ( billing_transaction_id BillingTransaction_Id_Field) (
row *Metadata_Row, err error) { row *Metadata_Row, err error) {
@ -17136,14 +17158,15 @@ func (obj *pgxImpl) Update_BillingBalance_By_UserId_And_Balance(ctx context.Cont
return billing_balance, nil return billing_balance, nil
} }
func (obj *pgxImpl) UpdateNoReturn_BillingTransaction_By_Id(ctx context.Context, func (obj *pgxImpl) UpdateNoReturn_BillingTransaction_By_Id_And_Status(ctx context.Context,
billing_transaction_id BillingTransaction_Id_Field, billing_transaction_id BillingTransaction_Id_Field,
billing_transaction_status BillingTransaction_Status_Field,
update BillingTransaction_Update_Fields) ( update BillingTransaction_Update_Fields) (
err error) { err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
var __sets = &__sqlbundle_Hole{} var __sets = &__sqlbundle_Hole{}
var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("UPDATE billing_transactions SET "), __sets, __sqlbundle_Literal(" WHERE billing_transactions.id = ?")}} var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("UPDATE billing_transactions SET "), __sets, __sqlbundle_Literal(" WHERE billing_transactions.id = ? AND billing_transactions.status = ?")}}
__sets_sql := __sqlbundle_Literals{Join: ", "} __sets_sql := __sqlbundle_Literals{Join: ", "}
var __values []interface{} var __values []interface{}
@ -17163,7 +17186,7 @@ func (obj *pgxImpl) UpdateNoReturn_BillingTransaction_By_Id(ctx context.Context,
return emptyUpdate() return emptyUpdate()
} }
__args = append(__args, billing_transaction_id.value()) __args = append(__args, billing_transaction_id.value(), billing_transaction_status.value())
__values = append(__values, __args...) __values = append(__values, __args...)
__sets.SQL = __sets_sql __sets.SQL = __sets_sql
@ -22423,6 +22446,28 @@ func (obj *pgxcockroachImpl) Get_BillingBalance_Balance_By_UserId(ctx context.Co
} }
func (obj *pgxcockroachImpl) Get_BillingTransaction_By_Id(ctx context.Context,
billing_transaction_id BillingTransaction_Id_Field) (
billing_transaction *BillingTransaction, err error) {
defer mon.Task()(&ctx)(&err)
var __embed_stmt = __sqlbundle_Literal("SELECT billing_transactions.id, billing_transactions.user_id, billing_transactions.amount, billing_transactions.currency, billing_transactions.description, billing_transactions.source, billing_transactions.status, billing_transactions.type, billing_transactions.metadata, billing_transactions.timestamp, billing_transactions.created_at FROM billing_transactions WHERE billing_transactions.id = ?")
var __values []interface{}
__values = append(__values, billing_transaction_id.value())
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
billing_transaction = &BillingTransaction{}
err = obj.queryRowContext(ctx, __stmt, __values...).Scan(&billing_transaction.Id, &billing_transaction.UserId, &billing_transaction.Amount, &billing_transaction.Currency, &billing_transaction.Description, &billing_transaction.Source, &billing_transaction.Status, &billing_transaction.Type, &billing_transaction.Metadata, &billing_transaction.Timestamp, &billing_transaction.CreatedAt)
if err != nil {
return (*BillingTransaction)(nil), obj.makeErr(err)
}
return billing_transaction, nil
}
func (obj *pgxcockroachImpl) Get_BillingTransaction_Metadata_By_Id(ctx context.Context, func (obj *pgxcockroachImpl) Get_BillingTransaction_Metadata_By_Id(ctx context.Context,
billing_transaction_id BillingTransaction_Id_Field) ( billing_transaction_id BillingTransaction_Id_Field) (
row *Metadata_Row, err error) { row *Metadata_Row, err error) {
@ -25271,14 +25316,15 @@ func (obj *pgxcockroachImpl) Update_BillingBalance_By_UserId_And_Balance(ctx con
return billing_balance, nil return billing_balance, nil
} }
func (obj *pgxcockroachImpl) UpdateNoReturn_BillingTransaction_By_Id(ctx context.Context, func (obj *pgxcockroachImpl) UpdateNoReturn_BillingTransaction_By_Id_And_Status(ctx context.Context,
billing_transaction_id BillingTransaction_Id_Field, billing_transaction_id BillingTransaction_Id_Field,
billing_transaction_status BillingTransaction_Status_Field,
update BillingTransaction_Update_Fields) ( update BillingTransaction_Update_Fields) (
err error) { err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
var __sets = &__sqlbundle_Hole{} var __sets = &__sqlbundle_Hole{}
var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("UPDATE billing_transactions SET "), __sets, __sqlbundle_Literal(" WHERE billing_transactions.id = ?")}} var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("UPDATE billing_transactions SET "), __sets, __sqlbundle_Literal(" WHERE billing_transactions.id = ? AND billing_transactions.status = ?")}}
__sets_sql := __sqlbundle_Literals{Join: ", "} __sets_sql := __sqlbundle_Literals{Join: ", "}
var __values []interface{} var __values []interface{}
@ -25298,7 +25344,7 @@ func (obj *pgxcockroachImpl) UpdateNoReturn_BillingTransaction_By_Id(ctx context
return emptyUpdate() return emptyUpdate()
} }
__args = append(__args, billing_transaction_id.value()) __args = append(__args, billing_transaction_id.value(), billing_transaction_status.value())
__values = append(__values, __args...) __values = append(__values, __args...)
__sets.SQL = __sets_sql __sets.SQL = __sets_sql
@ -28829,6 +28875,10 @@ type Methods interface {
billing_balance_user_id BillingBalance_UserId_Field) ( billing_balance_user_id BillingBalance_UserId_Field) (
row *Balance_Row, err error) row *Balance_Row, err error)
Get_BillingTransaction_By_Id(ctx context.Context,
billing_transaction_id BillingTransaction_Id_Field) (
billing_transaction *BillingTransaction, err error)
Get_BillingTransaction_Metadata_By_Id(ctx context.Context, Get_BillingTransaction_Metadata_By_Id(ctx context.Context,
billing_transaction_id BillingTransaction_Id_Field) ( billing_transaction_id BillingTransaction_Id_Field) (
row *Metadata_Row, err error) row *Metadata_Row, err error)
@ -29169,8 +29219,9 @@ type Methods interface {
update ApiKey_Update_Fields) ( update ApiKey_Update_Fields) (
err error) err error)
UpdateNoReturn_BillingTransaction_By_Id(ctx context.Context, UpdateNoReturn_BillingTransaction_By_Id_And_Status(ctx context.Context,
billing_transaction_id BillingTransaction_Id_Field, billing_transaction_id BillingTransaction_Id_Field,
billing_transaction_status BillingTransaction_Status_Field,
update BillingTransaction_Update_Fields) ( update BillingTransaction_Update_Fields) (
err error) err error)