satellite/payments: add Balances interface
Add and implement interface to manage customer balances. Adds ability to add credit to a user's balance, list balance transactions, and get the balance. Change-Id: I7fd65d07868bb2b7489d1141a5e9049514d6984e
This commit is contained in:
parent
709dc63d42
commit
a2e3247471
@ -330,7 +330,7 @@ func (payment Payments) AccountBalance(ctx context.Context) (balance payments.Ba
|
||||
return payments.Balance{}, Error.Wrap(err)
|
||||
}
|
||||
|
||||
return payment.service.accounts.Balance(ctx, user.ID)
|
||||
return payment.service.accounts.Balances().Get(ctx, user.ID)
|
||||
}
|
||||
|
||||
// AddCreditCard is used to save new credit card and attach it to payment account.
|
||||
@ -3179,6 +3179,38 @@ func (payment Payments) UpdatePackage(ctx context.Context, packagePlan string, p
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyCredit applies a credit of `amount` with description of `desc` to the user's balance. `amount` is in cents USD.
|
||||
// If a credit with `desc` already exists, another one will not be created.
|
||||
func (payment Payments) ApplyCredit(ctx context.Context, amount int64, desc string) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
if desc == "" {
|
||||
return ErrPurchaseDesc.New("description cannot be empty")
|
||||
}
|
||||
user, err := GetUser(ctx)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
||||
btxs, err := payment.service.accounts.Balances().ListTransactions(ctx, user.ID)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
||||
// check for any previously created transaction with the same description.
|
||||
for _, btx := range btxs {
|
||||
if btx.Description == desc {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
_, err = payment.service.accounts.Balances().ApplyCredit(ctx, user.ID, amount, desc)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProjectUsagePriceModel returns the project usage price model for the user.
|
||||
func (payment Payments) GetProjectUsagePriceModel(ctx context.Context) (_ *payments.ProjectUsagePriceModel, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
@ -530,6 +530,37 @@ func TestService(t *testing.T) {
|
||||
|
||||
check()
|
||||
})
|
||||
t.Run("ApplyCredit fails when payments.Balances.ApplyCredit returns an error", func(t *testing.T) {
|
||||
require.Error(t, service.Payments().ApplyCredit(userCtx1, 1000, stripecoinpayments.MockCBTXsNewFailure))
|
||||
btxs, err := sat.API.Payments.Accounts.Balances().ListTransactions(ctx, up1Pro1.OwnerID)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, len(btxs))
|
||||
})
|
||||
t.Run("ApplyCredit", func(t *testing.T) {
|
||||
amount := int64(1000)
|
||||
desc := "test"
|
||||
require.NoError(t, service.Payments().ApplyCredit(userCtx1, 1000, desc))
|
||||
btxs, err := sat.API.Payments.Accounts.Balances().ListTransactions(ctx, up1Pro1.OwnerID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, btxs, 1)
|
||||
require.Equal(t, amount, btxs[0].Amount)
|
||||
require.Equal(t, desc, btxs[0].Description)
|
||||
|
||||
// test same description results in no new credit
|
||||
require.NoError(t, service.Payments().ApplyCredit(userCtx1, 1000, desc))
|
||||
btxs, err = sat.API.Payments.Accounts.Balances().ListTransactions(ctx, up1Pro1.OwnerID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, btxs, 1)
|
||||
|
||||
// test different description results in new credit
|
||||
require.NoError(t, service.Payments().ApplyCredit(userCtx1, 1000, "new desc"))
|
||||
btxs, err = sat.API.Payments.Accounts.Balances().ListTransactions(ctx, up1Pro1.OwnerID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, btxs, 2)
|
||||
})
|
||||
t.Run("ApplyCredit fails with unknown user", func(t *testing.T) {
|
||||
require.Error(t, service.Payments().ApplyCredit(ctx, 1000, "test"))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -29,8 +29,8 @@ type Accounts interface {
|
||||
// GetPackageInfo returns the package plan and time of purchase for a user.
|
||||
GetPackageInfo(ctx context.Context, userID uuid.UUID) (packagePlan *string, purchaseTime *time.Time, err error)
|
||||
|
||||
// Balance returns an object that represents current free credits and coins balance in cents.
|
||||
Balance(ctx context.Context, userID uuid.UUID) (Balance, error)
|
||||
// Balances exposes functionality to manage account balances.
|
||||
Balances() Balances
|
||||
|
||||
// ProjectCharges returns how much money current user will be charged for each project.
|
||||
ProjectCharges(ctx context.Context, userID uuid.UUID, since, before time.Time) ([]ProjectCharge, error)
|
||||
|
@ -4,9 +4,23 @@
|
||||
package payments
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"storj.io/common/uuid"
|
||||
)
|
||||
|
||||
// Balances exposes needed functionality for managing customer balances.
|
||||
type Balances interface {
|
||||
// ApplyCredit applies a credit of `amount` to the user's stripe balance with a description of `desc`.
|
||||
ApplyCredit(ctx context.Context, userID uuid.UUID, amount int64, desc string) (*Balance, error)
|
||||
// Get returns the customer balance.
|
||||
Get(ctx context.Context, userID uuid.UUID) (Balance, error)
|
||||
// ListTransactions returns a list of transactions on the customer's balance.
|
||||
ListTransactions(ctx context.Context, userID uuid.UUID) ([]BalanceTransaction, error)
|
||||
}
|
||||
|
||||
// Balance is an entity that holds free credits and coins balance of user.
|
||||
// Earned by applying of promotional coupon and coins depositing, respectively.
|
||||
type Balance struct {
|
||||
@ -19,3 +33,10 @@ type Balance struct {
|
||||
// 4. bonus manually credited for a storjscan payment once a month before invoicing.
|
||||
// 5. any other adjustment we may have to make from time to time manually to the customer´s STORJ balance.
|
||||
}
|
||||
|
||||
// BalanceTransaction represents a single transaction affecting a customer balance.
|
||||
type BalanceTransaction struct {
|
||||
ID string
|
||||
Amount int64
|
||||
Description string
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stripe/stripe-go/v72"
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
@ -31,6 +30,11 @@ func (accounts *accounts) CreditCards() payments.CreditCards {
|
||||
return &creditCards{service: accounts.service}
|
||||
}
|
||||
|
||||
// Balances exposes all needed functionality to manage account balances.
|
||||
func (accounts *accounts) Balances() payments.Balances {
|
||||
return &balances{service: accounts.service}
|
||||
}
|
||||
|
||||
// Invoices exposes all needed functionality to manage account invoices.
|
||||
func (accounts *accounts) Invoices() payments.Invoices {
|
||||
return &invoices{service: accounts.service}
|
||||
@ -121,34 +125,6 @@ func (accounts *accounts) GetPackageInfo(ctx context.Context, userID uuid.UUID)
|
||||
return
|
||||
}
|
||||
|
||||
// Balance returns an integer amount in cents that represents the current balance of payment account.
|
||||
func (accounts *accounts) Balance(ctx context.Context, userID uuid.UUID) (_ payments.Balance, err error) {
|
||||
defer mon.Task()(&ctx, userID)(&err)
|
||||
|
||||
balance, err := accounts.service.billingDB.GetBalance(ctx, userID)
|
||||
if err != nil {
|
||||
return payments.Balance{}, Error.Wrap(err)
|
||||
}
|
||||
|
||||
customerID, err := accounts.service.db.Customers().GetCustomerID(ctx, userID)
|
||||
if err != nil {
|
||||
return payments.Balance{}, Error.Wrap(err)
|
||||
}
|
||||
|
||||
params := &stripe.CustomerParams{Params: stripe.Params{Context: ctx}}
|
||||
customer, err := accounts.service.stripeClient.Customers().Get(customerID, params)
|
||||
if err != nil {
|
||||
return payments.Balance{}, Error.Wrap(err)
|
||||
}
|
||||
|
||||
// customer.Balance is negative if the user has a balance with us.
|
||||
// https://stripe.com/docs/api/customers/object#customer_object-balance
|
||||
return payments.Balance{
|
||||
Coins: balance.AsDecimal(),
|
||||
Credits: decimal.NewFromInt(-customer.Balance),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProjectCharges returns how much money current user will be charged for each project.
|
||||
func (accounts *accounts) ProjectCharges(ctx context.Context, userID uuid.UUID, since, before time.Time) (charges []payments.ProjectCharge, err error) {
|
||||
defer mon.Task()(&ctx, userID, since, before)(&err)
|
||||
|
99
satellite/payments/stripecoinpayments/balances.go
Normal file
99
satellite/payments/stripecoinpayments/balances.go
Normal file
@ -0,0 +1,99 @@
|
||||
// Copyright (C) 2023 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package stripecoinpayments
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stripe/stripe-go/v72"
|
||||
|
||||
"storj.io/common/uuid"
|
||||
"storj.io/storj/satellite/payments"
|
||||
)
|
||||
|
||||
type balances struct {
|
||||
service *Service
|
||||
}
|
||||
|
||||
// ApplyCredit applies a credit of `amount` to the user's stripe balance with a description of `desc`.
|
||||
func (balances *balances) ApplyCredit(ctx context.Context, userID uuid.UUID, amount int64, desc string) (b *payments.Balance, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
customerID, err := balances.service.db.Customers().GetCustomerID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
|
||||
// NB: In stripe a negative amount means the customer is owed money.
|
||||
cbtx, err := balances.service.stripeClient.CustomerBalanceTransactions().New(&stripe.CustomerBalanceTransactionParams{
|
||||
Customer: stripe.String(customerID),
|
||||
Description: stripe.String(desc),
|
||||
Amount: stripe.Int64(-amount),
|
||||
Currency: stripe.String(string(stripe.CurrencyUSD)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
|
||||
return &payments.Balance{
|
||||
Credits: decimal.NewFromInt(-cbtx.EndingBalance),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (balances *balances) ListTransactions(ctx context.Context, userID uuid.UUID) (_ []payments.BalanceTransaction, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
customerID, err := balances.service.db.Customers().GetCustomerID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
|
||||
var list []payments.BalanceTransaction
|
||||
iter := balances.service.stripeClient.CustomerBalanceTransactions().List(&stripe.CustomerBalanceTransactionListParams{
|
||||
Customer: stripe.String(customerID),
|
||||
})
|
||||
for iter.Next() {
|
||||
stripeCBTX := iter.CustomerBalanceTransaction()
|
||||
if stripeCBTX != nil {
|
||||
list = append(list, payments.BalanceTransaction{
|
||||
ID: stripeCBTX.ID,
|
||||
Amount: -stripeCBTX.Amount,
|
||||
Description: stripeCBTX.Description,
|
||||
})
|
||||
}
|
||||
}
|
||||
if err = iter.Err(); err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// Get returns an integer amount in cents that represents the current balance of payment account.
|
||||
func (balances *balances) Get(ctx context.Context, userID uuid.UUID) (_ payments.Balance, err error) {
|
||||
defer mon.Task()(&ctx, userID)(&err)
|
||||
|
||||
b, err := balances.service.billingDB.GetBalance(ctx, userID)
|
||||
if err != nil {
|
||||
return payments.Balance{}, Error.Wrap(err)
|
||||
}
|
||||
|
||||
customerID, err := balances.service.db.Customers().GetCustomerID(ctx, userID)
|
||||
if err != nil {
|
||||
return payments.Balance{}, Error.Wrap(err)
|
||||
}
|
||||
|
||||
params := &stripe.CustomerParams{Params: stripe.Params{Context: ctx}}
|
||||
customer, err := balances.service.stripeClient.Customers().Get(customerID, params)
|
||||
if err != nil {
|
||||
return payments.Balance{}, Error.Wrap(err)
|
||||
}
|
||||
|
||||
// customer.Balance is negative if the user has a balance with us.
|
||||
// https://stripe.com/docs/api/customers/object#customer_object-balance
|
||||
return payments.Balance{
|
||||
Coins: b.AsDecimal(),
|
||||
Credits: decimal.NewFromInt(-customer.Balance),
|
||||
}, nil
|
||||
}
|
76
satellite/payments/stripecoinpayments/balances_test.go
Normal file
76
satellite/payments/stripecoinpayments/balances_test.go
Normal file
@ -0,0 +1,76 @@
|
||||
// Copyright (C) 2023 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package stripecoinpayments_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/storj/private/testplanet"
|
||||
"storj.io/storj/satellite/payments/stripecoinpayments"
|
||||
)
|
||||
|
||||
func TestBalances(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, UplinkCount: 1,
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
sat := planet.Satellites[0]
|
||||
userID := planet.Uplinks[0].Projects[0].Owner.ID
|
||||
balances := sat.API.Payments.Accounts.Balances()
|
||||
|
||||
tx1Amount := int64(1000)
|
||||
tx1Desc := "test description 1"
|
||||
|
||||
b, err := balances.ApplyCredit(ctx, userID, tx1Amount, tx1Desc)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decimal.NewFromInt(tx1Amount), b.Credits)
|
||||
|
||||
bal, err := balances.Get(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decimal.NewFromInt(tx1Amount), bal.Credits)
|
||||
|
||||
tx2Amount := int64(-1000)
|
||||
tx2Desc := "test description 2"
|
||||
endingBalance := tx1Amount + tx2Amount
|
||||
b, err = balances.ApplyCredit(ctx, userID, tx2Amount, tx2Desc)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decimal.NewFromInt(endingBalance), b.Credits)
|
||||
|
||||
bal, err = balances.Get(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decimal.NewFromInt(endingBalance), bal.Credits)
|
||||
|
||||
tx3Amount := int64(-1000)
|
||||
tx3Desc := "test description 3"
|
||||
endingBalance += tx3Amount
|
||||
b, err = balances.ApplyCredit(ctx, userID, tx3Amount, tx3Desc)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decimal.NewFromInt(endingBalance), b.Credits)
|
||||
|
||||
bal, err = balances.Get(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decimal.NewFromInt(endingBalance), bal.Credits)
|
||||
|
||||
list, err := balances.ListTransactions(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, list, 3)
|
||||
require.Equal(t, tx1Amount, list[0].Amount)
|
||||
require.Equal(t, tx1Desc, list[0].Description)
|
||||
require.Equal(t, tx2Amount, list[1].Amount)
|
||||
require.Equal(t, tx2Desc, list[1].Description)
|
||||
require.Equal(t, tx3Amount, list[2].Amount)
|
||||
require.Equal(t, tx3Desc, list[2].Description)
|
||||
|
||||
b, err = balances.ApplyCredit(ctx, userID, tx2Amount, stripecoinpayments.MockCBTXsNewFailure)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, b)
|
||||
|
||||
list, err = balances.ListTransactions(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, len(list))
|
||||
})
|
||||
}
|
@ -47,6 +47,10 @@ const (
|
||||
// TestPaymentMethodsAttachFailure can be passed to creditCards.Add as the cardToken arg to cause
|
||||
// mockPaymentMethods.Attach to return an error.
|
||||
TestPaymentMethodsAttachFailure = "test_payment_methods_attach_failure"
|
||||
|
||||
// MockCBTXsNewFailure can be passed to mockCustomerBalanceTransactions.New as the `desc` argument to cause it
|
||||
// to return an error.
|
||||
MockCBTXsNewFailure = "mock_cbtxs_new_failure"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -704,6 +708,14 @@ func newMockCustomerBalanceTransactions(root *mockStripeState) *mockCustomerBala
|
||||
}
|
||||
|
||||
func (m *mockCustomerBalanceTransactions) New(params *stripe.CustomerBalanceTransactionParams) (*stripe.CustomerBalanceTransaction, error) {
|
||||
m.root.mu.Lock()
|
||||
defer m.root.mu.Unlock()
|
||||
|
||||
if params.Description != nil {
|
||||
if *params.Description == MockCBTXsNewFailure {
|
||||
return nil, &stripe.Error{}
|
||||
}
|
||||
}
|
||||
tx := &stripe.CustomerBalanceTransaction{
|
||||
Type: stripe.CustomerBalanceTransactionTypeAdjustment,
|
||||
Amount: *params.Amount,
|
||||
@ -712,11 +724,14 @@ func (m *mockCustomerBalanceTransactions) New(params *stripe.CustomerBalanceTran
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
|
||||
m.root.mu.Lock()
|
||||
defer m.root.mu.Unlock()
|
||||
|
||||
m.transactions[*params.Customer] = append(m.transactions[*params.Customer], tx)
|
||||
|
||||
for _, v := range m.root.customers.customers {
|
||||
if v.ID == *params.Customer {
|
||||
v.Balance += *params.Amount
|
||||
tx.EndingBalance = v.Balance
|
||||
}
|
||||
}
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user