billing: add USDollarsMicro to billing DB
Adds USDollarsMicro currency to the billing DB which support fraction of a cent with decimal places for better billing amounts accuracy. Change-Id: Id07dfae104d94e27c7b22ab8f5781010e16c4c8e
This commit is contained in:
parent
afe58323f9
commit
47301e5718
@ -56,7 +56,7 @@ type TransactionsDB interface {
|
||||
// List returns all transactions for the specified user.
|
||||
List(ctx context.Context, userID uuid.UUID) ([]Transaction, error)
|
||||
// GetBalance returns the current usable balance for the specified user.
|
||||
GetBalance(ctx context.Context, userID uuid.UUID) (int64, error)
|
||||
GetBalance(ctx context.Context, userID uuid.UUID) (monetary.Amount, error)
|
||||
}
|
||||
|
||||
// PaymentType is an interface which defines functionality required for all billing payment types. Payment types can
|
||||
|
@ -90,9 +90,10 @@ func TestTransactionsDBList(t *testing.T) {
|
||||
|
||||
func TestTransactionsDBBalance(t *testing.T) {
|
||||
tenUSD := monetary.AmountFromBaseUnits(1000, monetary.USDollars)
|
||||
twentyUSD := monetary.AmountFromBaseUnits(2000, monetary.USDollars)
|
||||
tenMicroUSD := monetary.AmountFromBaseUnits(10000000, monetary.USDollarsMicro)
|
||||
twentyMicroUSD := monetary.AmountFromBaseUnits(20000000, monetary.USDollarsMicro)
|
||||
thirtyUSD := monetary.AmountFromBaseUnits(3000, monetary.USDollars)
|
||||
fortyUSD := monetary.AmountFromBaseUnits(4000, monetary.USDollars)
|
||||
fortyMicroUSD := monetary.AmountFromBaseUnits(40000000, monetary.USDollarsMicro)
|
||||
negativeTwentyUSD := monetary.AmountFromBaseUnits(-2000, monetary.USDollars)
|
||||
|
||||
userID := testrand.UUID()
|
||||
@ -152,7 +153,7 @@ func TestTransactionsDBBalance(t *testing.T) {
|
||||
compareTransactions(t, credit10TX, txs[0])
|
||||
balance, err := db.Billing().GetBalance(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tenUSD.BaseUnits(), balance)
|
||||
require.Equal(t, tenMicroUSD.BaseUnits(), balance.BaseUnits())
|
||||
})
|
||||
})
|
||||
|
||||
@ -169,7 +170,7 @@ func TestTransactionsDBBalance(t *testing.T) {
|
||||
compareTransactions(t, credit10TX, txs[1])
|
||||
balance, err := db.Billing().GetBalance(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fortyUSD.BaseUnits(), balance)
|
||||
require.Equal(t, fortyMicroUSD.BaseUnits(), balance.BaseUnits())
|
||||
})
|
||||
})
|
||||
|
||||
@ -189,7 +190,7 @@ func TestTransactionsDBBalance(t *testing.T) {
|
||||
compareTransactions(t, credit10TX, txs[2])
|
||||
balance, err := db.Billing().GetBalance(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, twentyUSD.BaseUnits(), balance)
|
||||
require.Equal(t, twentyMicroUSD.BaseUnits(), balance.BaseUnits())
|
||||
})
|
||||
})
|
||||
}
|
||||
@ -265,7 +266,7 @@ func TestUpdateMetadata(t *testing.T) {
|
||||
// ensures that is not empty.
|
||||
func compareTransactions(t *testing.T, exp, act billing.Transaction) {
|
||||
assert.Equal(t, exp.UserID, act.UserID)
|
||||
assert.Equal(t, exp.Amount, act.Amount)
|
||||
assert.Equal(t, monetary.AmountFromDecimal(exp.Amount.AsDecimal().Truncate(monetary.USDollarsMicro.DecimalPlaces()), monetary.USDollarsMicro), act.Amount)
|
||||
assert.Equal(t, exp.Description, act.Description)
|
||||
assert.Equal(t, exp.Status, act.Status)
|
||||
assert.Equal(t, exp.Source, act.Source)
|
||||
|
@ -37,6 +37,16 @@ func (c *Currency) Symbol() string {
|
||||
return c.symbol
|
||||
}
|
||||
|
||||
// DecimalPlaces returns the decimal places of the currency.
|
||||
func (c *Currency) DecimalPlaces() int32 {
|
||||
return c.decimalPlaces
|
||||
}
|
||||
|
||||
// Zero returns the zero value of the currency.
|
||||
func (c *Currency) Zero() Amount {
|
||||
return AmountFromBaseUnits(0, c)
|
||||
}
|
||||
|
||||
var (
|
||||
// StorjToken is the currency for the STORJ ERC20 token, which powers
|
||||
// most payments on the current Storj network.
|
||||
@ -169,6 +179,11 @@ func (a Amount) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(amountJSON)
|
||||
}
|
||||
|
||||
// IsNegative returns true if the base unit amount is negative.
|
||||
func (a Amount) IsNegative() bool {
|
||||
return a.baseUnits < 0
|
||||
}
|
||||
|
||||
// AmountFromBaseUnits creates a new Amount instance from the given count of
|
||||
// base units and in the given currency.
|
||||
func AmountFromBaseUnits(units int64, currency *Currency) Amount {
|
||||
@ -224,3 +239,25 @@ func DecimalFromBigFloat(f *big.Float) (decimal.Decimal, error) {
|
||||
dec, err := decimal.NewFromString(stringVal)
|
||||
return dec, Error.Wrap(err)
|
||||
}
|
||||
|
||||
// Add adds two monetary amounts and returns the result. If the currencies are different, an error is thrown.
|
||||
func Add(i, j Amount) (Amount, error) {
|
||||
if !sameCurrency(i.currency, j.currency) {
|
||||
return i.currency.Zero(), errs.New("Amounts to add must use the same currency")
|
||||
}
|
||||
return AmountFromBaseUnits(i.baseUnits+j.baseUnits, i.currency), nil
|
||||
}
|
||||
|
||||
// Greater returns true if the first monetary amount is greater than the second.
|
||||
// If the currencies are different, an error is thrown.
|
||||
func Greater(i, j Amount) (bool, error) {
|
||||
if !sameCurrency(i.currency, j.currency) {
|
||||
return false, errs.New("Amounts to compare must use the same currency")
|
||||
}
|
||||
return i.baseUnits > j.baseUnits, nil
|
||||
}
|
||||
|
||||
// returns true if the currencies are the same and not nil.
|
||||
func sameCurrency(i, j *Currency) bool {
|
||||
return i != nil && j != nil && i == j
|
||||
}
|
||||
|
@ -175,3 +175,33 @@ func TestAmountJSONUnmarshal(t *testing.T) {
|
||||
require.Equal(t, test.Currency, amount.Currency())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGreater(t *testing.T) {
|
||||
oneHundredUSDMicro := AmountFromBaseUnits(100000000, USDollarsMicro)
|
||||
twoHundredUSDMicro := AmountFromBaseUnits(200000000, USDollarsMicro)
|
||||
|
||||
checkGreater, err := Greater(twoHundredUSDMicro, oneHundredUSDMicro)
|
||||
require.NoError(t, err)
|
||||
require.True(t, checkGreater)
|
||||
}
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
oneHundredUSD := AmountFromBaseUnits(10000, USDollars)
|
||||
twoHundredUSD := AmountFromBaseUnits(20000, USDollars)
|
||||
|
||||
sum, err := Add(oneHundredUSD, oneHundredUSD)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, twoHundredUSD, sum)
|
||||
}
|
||||
|
||||
func TestZero(t *testing.T) {
|
||||
require.Equal(t, AmountFromBaseUnits(0, USDollarsMicro), USDollarsMicro.Zero())
|
||||
require.Equal(t, AmountFromBaseUnits(0, USDollars), USDollars.Zero())
|
||||
require.NotEqual(t, AmountFromBaseUnits(0, USDollars), USDollarsMicro.Zero())
|
||||
}
|
||||
|
||||
func TestNegative(t *testing.T) {
|
||||
require.False(t, USDollarsMicro.Zero().IsNegative())
|
||||
require.False(t, USDollars.Zero().IsNegative())
|
||||
require.True(t, AmountFromBaseUnits(-10000, USDollars).IsNegative())
|
||||
}
|
||||
|
@ -557,12 +557,15 @@ func (service *Service) InvoiceApplyTokenBalance(ctx context.Context) (err error
|
||||
|
||||
for _, wallet := range wallets {
|
||||
// get the user token balance, if it's not > 0, don't bother with the rest
|
||||
tokenBalance, err := service.billingDB.GetBalance(ctx, wallet.UserID)
|
||||
monetaryTokenBalance, err := service.billingDB.GetBalance(ctx, wallet.UserID)
|
||||
// truncate here since stripe only has cent level precision for invoices.
|
||||
// The users account balance will still maintain the full precision monetary value!
|
||||
tokenBalance := monetary.AmountFromDecimal(monetaryTokenBalance.AsDecimal().Truncate(2), monetary.USDollars)
|
||||
if err != nil {
|
||||
errGrp.Add(Error.New("unable to compute balance for user ID %s", wallet.UserID.String()))
|
||||
continue
|
||||
}
|
||||
if tokenBalance <= 0 {
|
||||
if tokenBalance.BaseUnits() <= 0 {
|
||||
continue
|
||||
}
|
||||
// get the stripe customer invoice balance
|
||||
@ -583,8 +586,8 @@ func (service *Service) InvoiceApplyTokenBalance(ctx context.Context) (err error
|
||||
}
|
||||
|
||||
var tokenCreditAmount int64
|
||||
if invoice.AmountDue >= tokenBalance {
|
||||
tokenCreditAmount = -tokenBalance
|
||||
if invoice.AmountDue >= tokenBalance.BaseUnits() {
|
||||
tokenCreditAmount = -tokenBalance.BaseUnits()
|
||||
} else {
|
||||
tokenCreditAmount = -invoice.AmountDue
|
||||
}
|
||||
|
@ -34,31 +34,36 @@ func (db billingDB) Insert(ctx context.Context, billingTX billing.Transaction) (
|
||||
var dbxTX *dbx.BillingTransaction
|
||||
var retryCount int
|
||||
for {
|
||||
balance, err := db.GetBalance(ctx, billingTX.UserID)
|
||||
oldBalance, err := db.GetBalance(ctx, billingTX.UserID)
|
||||
if err != nil {
|
||||
return 0, Error.Wrap(err)
|
||||
}
|
||||
if balance+billingTX.Amount.BaseUnits() < 0 {
|
||||
billingAmount := monetary.AmountFromDecimal(billingTX.Amount.AsDecimal().Truncate(monetary.USDollarsMicro.DecimalPlaces()), monetary.USDollarsMicro)
|
||||
newBalance, err := monetary.Add(oldBalance, billingAmount)
|
||||
if err != nil {
|
||||
return 0, Error.Wrap(err)
|
||||
}
|
||||
if newBalance.IsNegative() {
|
||||
return 0, billing.ErrInsufficientFunds
|
||||
}
|
||||
|
||||
err = db.db.WithTx(ctx, func(ctx context.Context, tx *dbx.Tx) error {
|
||||
updatedRow, err := tx.Update_BillingBalance_By_UserId_And_Balance(ctx,
|
||||
dbx.BillingBalance_UserId(billingTX.UserID[:]),
|
||||
dbx.BillingBalance_Balance(balance),
|
||||
dbx.BillingBalance_Balance(oldBalance.BaseUnits()),
|
||||
dbx.BillingBalance_Update_Fields{
|
||||
Balance: dbx.BillingBalance_Balance(balance + billingTX.Amount.BaseUnits()),
|
||||
Balance: dbx.BillingBalance_Balance(newBalance.BaseUnits()),
|
||||
})
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
if updatedRow == nil {
|
||||
// Try an insert here, in case the user never had a record in the table.
|
||||
// If the user already had a record, and the balance was not as expected,
|
||||
// If the user already had a record, and the oldBalance was not as expected,
|
||||
// the insert will fail anyways.
|
||||
err = tx.CreateNoReturn_BillingBalance(ctx,
|
||||
dbx.BillingBalance_UserId(billingTX.UserID[:]),
|
||||
dbx.BillingBalance_Balance(balance+billingTX.Amount.BaseUnits()))
|
||||
dbx.BillingBalance_Balance(newBalance.BaseUnits()))
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
@ -66,8 +71,8 @@ func (db billingDB) Insert(ctx context.Context, billingTX billing.Transaction) (
|
||||
|
||||
dbxTX, err = tx.Create_BillingTransaction(ctx,
|
||||
dbx.BillingTransaction_UserId(billingTX.UserID[:]),
|
||||
dbx.BillingTransaction_Amount(billingTX.Amount.BaseUnits()),
|
||||
dbx.BillingTransaction_Currency(monetary.USDollars.Symbol()),
|
||||
dbx.BillingTransaction_Amount(billingAmount.BaseUnits()),
|
||||
dbx.BillingTransaction_Currency(billingAmount.Currency().Symbol()),
|
||||
dbx.BillingTransaction_Description(billingTX.Description),
|
||||
dbx.BillingTransaction_Source(billingTX.Source),
|
||||
dbx.BillingTransaction_Status(string(billingTX.Status)),
|
||||
@ -155,18 +160,18 @@ func (db billingDB) List(ctx context.Context, userID uuid.UUID) (txs []billing.T
|
||||
return txs, nil
|
||||
}
|
||||
|
||||
func (db billingDB) GetBalance(ctx context.Context, userID uuid.UUID) (_ int64, err error) {
|
||||
func (db billingDB) GetBalance(ctx context.Context, userID uuid.UUID) (_ monetary.Amount, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
dbxBilling, err := db.db.Get_BillingBalance_Balance_By_UserId(ctx,
|
||||
dbx.BillingBalance_UserId(userID[:]))
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, nil
|
||||
return monetary.USDollarsMicro.Zero(), nil
|
||||
}
|
||||
return 0, Error.Wrap(err)
|
||||
return monetary.USDollarsMicro.Zero(), Error.Wrap(err)
|
||||
}
|
||||
|
||||
return dbxBilling.Balance, nil
|
||||
return monetary.AmountFromBaseUnits(dbxBilling.Balance, monetary.USDollarsMicro), nil
|
||||
}
|
||||
|
||||
// fromDBXBillingTransaction converts *dbx.BillingTransaction to *billing.Transaction.
|
||||
@ -178,7 +183,7 @@ func fromDBXBillingTransaction(dbxTX *dbx.BillingTransaction) (*billing.Transact
|
||||
return &billing.Transaction{
|
||||
ID: dbxTX.Id,
|
||||
UserID: userID,
|
||||
Amount: monetary.AmountFromBaseUnits(dbxTX.Amount, monetary.USDollars),
|
||||
Amount: monetary.AmountFromBaseUnits(dbxTX.Amount, monetary.USDollarsMicro),
|
||||
Description: dbxTX.Description,
|
||||
Source: dbxTX.Source,
|
||||
Status: billing.TransactionStatus(dbxTX.Status),
|
||||
|
Loading…
Reference in New Issue
Block a user