billing: add USDollarsMicro to billing DB

Adds USDollarsMicro currency to the billing DB which support fraction of a cent with decimal places for better billing amounts accuracy.

Change-Id: Id07dfae104d94e27c7b22ab8f5781010e16c4c8e
This commit is contained in:
dlamarmorgan 2022-08-26 16:08:03 -07:00
parent afe58323f9
commit 47301e5718
6 changed files with 100 additions and 24 deletions

View File

@ -56,7 +56,7 @@ type TransactionsDB interface {
// List returns all transactions for the specified user.
List(ctx context.Context, userID uuid.UUID) ([]Transaction, error)
// GetBalance returns the current usable balance for the specified user.
GetBalance(ctx context.Context, userID uuid.UUID) (int64, error)
GetBalance(ctx context.Context, userID uuid.UUID) (monetary.Amount, error)
}
// PaymentType is an interface which defines functionality required for all billing payment types. Payment types can

View File

@ -90,9 +90,10 @@ func TestTransactionsDBList(t *testing.T) {
func TestTransactionsDBBalance(t *testing.T) {
tenUSD := monetary.AmountFromBaseUnits(1000, monetary.USDollars)
twentyUSD := monetary.AmountFromBaseUnits(2000, monetary.USDollars)
tenMicroUSD := monetary.AmountFromBaseUnits(10000000, monetary.USDollarsMicro)
twentyMicroUSD := monetary.AmountFromBaseUnits(20000000, monetary.USDollarsMicro)
thirtyUSD := monetary.AmountFromBaseUnits(3000, monetary.USDollars)
fortyUSD := monetary.AmountFromBaseUnits(4000, monetary.USDollars)
fortyMicroUSD := monetary.AmountFromBaseUnits(40000000, monetary.USDollarsMicro)
negativeTwentyUSD := monetary.AmountFromBaseUnits(-2000, monetary.USDollars)
userID := testrand.UUID()
@ -152,7 +153,7 @@ func TestTransactionsDBBalance(t *testing.T) {
compareTransactions(t, credit10TX, txs[0])
balance, err := db.Billing().GetBalance(ctx, userID)
require.NoError(t, err)
require.Equal(t, tenUSD.BaseUnits(), balance)
require.Equal(t, tenMicroUSD.BaseUnits(), balance.BaseUnits())
})
})
@ -169,7 +170,7 @@ func TestTransactionsDBBalance(t *testing.T) {
compareTransactions(t, credit10TX, txs[1])
balance, err := db.Billing().GetBalance(ctx, userID)
require.NoError(t, err)
require.Equal(t, fortyUSD.BaseUnits(), balance)
require.Equal(t, fortyMicroUSD.BaseUnits(), balance.BaseUnits())
})
})
@ -189,7 +190,7 @@ func TestTransactionsDBBalance(t *testing.T) {
compareTransactions(t, credit10TX, txs[2])
balance, err := db.Billing().GetBalance(ctx, userID)
require.NoError(t, err)
require.Equal(t, twentyUSD.BaseUnits(), balance)
require.Equal(t, twentyMicroUSD.BaseUnits(), balance.BaseUnits())
})
})
}
@ -265,7 +266,7 @@ func TestUpdateMetadata(t *testing.T) {
// ensures that is not empty.
func compareTransactions(t *testing.T, exp, act billing.Transaction) {
assert.Equal(t, exp.UserID, act.UserID)
assert.Equal(t, exp.Amount, act.Amount)
assert.Equal(t, monetary.AmountFromDecimal(exp.Amount.AsDecimal().Truncate(monetary.USDollarsMicro.DecimalPlaces()), monetary.USDollarsMicro), act.Amount)
assert.Equal(t, exp.Description, act.Description)
assert.Equal(t, exp.Status, act.Status)
assert.Equal(t, exp.Source, act.Source)

View File

@ -37,6 +37,16 @@ func (c *Currency) Symbol() string {
return c.symbol
}
// DecimalPlaces returns the decimal places of the currency.
func (c *Currency) DecimalPlaces() int32 {
return c.decimalPlaces
}
// Zero returns the zero value of the currency.
func (c *Currency) Zero() Amount {
return AmountFromBaseUnits(0, c)
}
var (
// StorjToken is the currency for the STORJ ERC20 token, which powers
// most payments on the current Storj network.
@ -169,6 +179,11 @@ func (a Amount) MarshalJSON() ([]byte, error) {
return json.Marshal(amountJSON)
}
// IsNegative returns true if the base unit amount is negative.
func (a Amount) IsNegative() bool {
return a.baseUnits < 0
}
// AmountFromBaseUnits creates a new Amount instance from the given count of
// base units and in the given currency.
func AmountFromBaseUnits(units int64, currency *Currency) Amount {
@ -224,3 +239,25 @@ func DecimalFromBigFloat(f *big.Float) (decimal.Decimal, error) {
dec, err := decimal.NewFromString(stringVal)
return dec, Error.Wrap(err)
}
// Add adds two monetary amounts and returns the result. If the currencies are different, an error is thrown.
func Add(i, j Amount) (Amount, error) {
if !sameCurrency(i.currency, j.currency) {
return i.currency.Zero(), errs.New("Amounts to add must use the same currency")
}
return AmountFromBaseUnits(i.baseUnits+j.baseUnits, i.currency), nil
}
// Greater returns true if the first monetary amount is greater than the second.
// If the currencies are different, an error is thrown.
func Greater(i, j Amount) (bool, error) {
if !sameCurrency(i.currency, j.currency) {
return false, errs.New("Amounts to compare must use the same currency")
}
return i.baseUnits > j.baseUnits, nil
}
// returns true if the currencies are the same and not nil.
func sameCurrency(i, j *Currency) bool {
return i != nil && j != nil && i == j
}

View File

@ -175,3 +175,33 @@ func TestAmountJSONUnmarshal(t *testing.T) {
require.Equal(t, test.Currency, amount.Currency())
}
}
func TestGreater(t *testing.T) {
oneHundredUSDMicro := AmountFromBaseUnits(100000000, USDollarsMicro)
twoHundredUSDMicro := AmountFromBaseUnits(200000000, USDollarsMicro)
checkGreater, err := Greater(twoHundredUSDMicro, oneHundredUSDMicro)
require.NoError(t, err)
require.True(t, checkGreater)
}
func TestAdd(t *testing.T) {
oneHundredUSD := AmountFromBaseUnits(10000, USDollars)
twoHundredUSD := AmountFromBaseUnits(20000, USDollars)
sum, err := Add(oneHundredUSD, oneHundredUSD)
require.NoError(t, err)
require.Equal(t, twoHundredUSD, sum)
}
func TestZero(t *testing.T) {
require.Equal(t, AmountFromBaseUnits(0, USDollarsMicro), USDollarsMicro.Zero())
require.Equal(t, AmountFromBaseUnits(0, USDollars), USDollars.Zero())
require.NotEqual(t, AmountFromBaseUnits(0, USDollars), USDollarsMicro.Zero())
}
func TestNegative(t *testing.T) {
require.False(t, USDollarsMicro.Zero().IsNegative())
require.False(t, USDollars.Zero().IsNegative())
require.True(t, AmountFromBaseUnits(-10000, USDollars).IsNegative())
}

View File

@ -557,12 +557,15 @@ func (service *Service) InvoiceApplyTokenBalance(ctx context.Context) (err error
for _, wallet := range wallets {
// get the user token balance, if it's not > 0, don't bother with the rest
tokenBalance, err := service.billingDB.GetBalance(ctx, wallet.UserID)
monetaryTokenBalance, err := service.billingDB.GetBalance(ctx, wallet.UserID)
// truncate here since stripe only has cent level precision for invoices.
// The users account balance will still maintain the full precision monetary value!
tokenBalance := monetary.AmountFromDecimal(monetaryTokenBalance.AsDecimal().Truncate(2), monetary.USDollars)
if err != nil {
errGrp.Add(Error.New("unable to compute balance for user ID %s", wallet.UserID.String()))
continue
}
if tokenBalance <= 0 {
if tokenBalance.BaseUnits() <= 0 {
continue
}
// get the stripe customer invoice balance
@ -583,8 +586,8 @@ func (service *Service) InvoiceApplyTokenBalance(ctx context.Context) (err error
}
var tokenCreditAmount int64
if invoice.AmountDue >= tokenBalance {
tokenCreditAmount = -tokenBalance
if invoice.AmountDue >= tokenBalance.BaseUnits() {
tokenCreditAmount = -tokenBalance.BaseUnits()
} else {
tokenCreditAmount = -invoice.AmountDue
}

View File

@ -34,31 +34,36 @@ func (db billingDB) Insert(ctx context.Context, billingTX billing.Transaction) (
var dbxTX *dbx.BillingTransaction
var retryCount int
for {
balance, err := db.GetBalance(ctx, billingTX.UserID)
oldBalance, err := db.GetBalance(ctx, billingTX.UserID)
if err != nil {
return 0, Error.Wrap(err)
}
if balance+billingTX.Amount.BaseUnits() < 0 {
billingAmount := monetary.AmountFromDecimal(billingTX.Amount.AsDecimal().Truncate(monetary.USDollarsMicro.DecimalPlaces()), monetary.USDollarsMicro)
newBalance, err := monetary.Add(oldBalance, billingAmount)
if err != nil {
return 0, Error.Wrap(err)
}
if newBalance.IsNegative() {
return 0, billing.ErrInsufficientFunds
}
err = db.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
updatedRow, err := tx.Update_BillingBalance_By_UserId_And_Balance(ctx,
dbx.BillingBalance_UserId(billingTX.UserID[:]),
dbx.BillingBalance_Balance(balance),
dbx.BillingBalance_Balance(oldBalance.BaseUnits()),
dbx.BillingBalance_Update_Fields{
Balance: dbx.BillingBalance_Balance(balance + billingTX.Amount.BaseUnits()),
Balance: dbx.BillingBalance_Balance(newBalance.BaseUnits()),
})
if err != nil {
return Error.Wrap(err)
}
if updatedRow == nil {
// Try an insert here, in case the user never had a record in the table.
// If the user already had a record, and the balance was not as expected,
// If the user already had a record, and the oldBalance was not as expected,
// the insert will fail anyways.
err = tx.CreateNoReturn_BillingBalance(ctx,
dbx.BillingBalance_UserId(billingTX.UserID[:]),
dbx.BillingBalance_Balance(balance+billingTX.Amount.BaseUnits()))
dbx.BillingBalance_Balance(newBalance.BaseUnits()))
if err != nil {
return Error.Wrap(err)
}
@ -66,8 +71,8 @@ func (db billingDB) Insert(ctx context.Context, billingTX billing.Transaction) (
dbxTX, err = tx.Create_BillingTransaction(ctx,
dbx.BillingTransaction_UserId(billingTX.UserID[:]),
dbx.BillingTransaction_Amount(billingTX.Amount.BaseUnits()),
dbx.BillingTransaction_Currency(monetary.USDollars.Symbol()),
dbx.BillingTransaction_Amount(billingAmount.BaseUnits()),
dbx.BillingTransaction_Currency(billingAmount.Currency().Symbol()),
dbx.BillingTransaction_Description(billingTX.Description),
dbx.BillingTransaction_Source(billingTX.Source),
dbx.BillingTransaction_Status(string(billingTX.Status)),
@ -155,18 +160,18 @@ func (db billingDB) List(ctx context.Context, userID uuid.UUID) (txs []billing.T
return txs, nil
}
func (db billingDB) GetBalance(ctx context.Context, userID uuid.UUID) (_ int64, err error) {
func (db billingDB) GetBalance(ctx context.Context, userID uuid.UUID) (_ monetary.Amount, err error) {
defer mon.Task()(&ctx)(&err)
dbxBilling, err := db.db.Get_BillingBalance_Balance_By_UserId(ctx,
dbx.BillingBalance_UserId(userID[:]))
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, nil
return monetary.USDollarsMicro.Zero(), nil
}
return 0, Error.Wrap(err)
return monetary.USDollarsMicro.Zero(), Error.Wrap(err)
}
return dbxBilling.Balance, nil
return monetary.AmountFromBaseUnits(dbxBilling.Balance, monetary.USDollarsMicro), nil
}
// fromDBXBillingTransaction converts *dbx.BillingTransaction to *billing.Transaction.
@ -178,7 +183,7 @@ func fromDBXBillingTransaction(dbxTX *dbx.BillingTransaction) (*billing.Transact
return &billing.Transaction{
ID: dbxTX.Id,
UserID: userID,
Amount: monetary.AmountFromBaseUnits(dbxTX.Amount, monetary.USDollars),
Amount: monetary.AmountFromBaseUnits(dbxTX.Amount, monetary.USDollarsMicro),
Description: dbxTX.Description,
Source: dbxTX.Source,
Status: billing.TransactionStatus(dbxTX.Status),