diff --git a/satellite/console/service.go b/satellite/console/service.go index bf776fb53..b639ca6b3 100644 --- a/satellite/console/service.go +++ b/satellite/console/service.go @@ -330,7 +330,7 @@ func (payment Payments) AccountBalance(ctx context.Context) (balance payments.Ba return payments.Balance{}, Error.Wrap(err) } - return payment.service.accounts.Balance(ctx, user.ID) + return payment.service.accounts.Balances().Get(ctx, user.ID) } // AddCreditCard is used to save new credit card and attach it to payment account. @@ -3179,6 +3179,38 @@ func (payment Payments) UpdatePackage(ctx context.Context, packagePlan string, p return nil } +// ApplyCredit applies a credit of `amount` with description of `desc` to the user's balance. `amount` is in cents USD. +// If a credit with `desc` already exists, another one will not be created. +func (payment Payments) ApplyCredit(ctx context.Context, amount int64, desc string) (err error) { + defer mon.Task()(&ctx)(&err) + + if desc == "" { + return ErrPurchaseDesc.New("description cannot be empty") + } + user, err := GetUser(ctx) + if err != nil { + return Error.Wrap(err) + } + + btxs, err := payment.service.accounts.Balances().ListTransactions(ctx, user.ID) + if err != nil { + return Error.Wrap(err) + } + + // check for any previously created transaction with the same description. + for _, btx := range btxs { + if btx.Description == desc { + return nil + } + } + + _, err = payment.service.accounts.Balances().ApplyCredit(ctx, user.ID, amount, desc) + if err != nil { + return Error.Wrap(err) + } + return nil +} + // GetProjectUsagePriceModel returns the project usage price model for the user. func (payment Payments) GetProjectUsagePriceModel(ctx context.Context) (_ *payments.ProjectUsagePriceModel, err error) { defer mon.Task()(&ctx)(&err) diff --git a/satellite/console/service_test.go b/satellite/console/service_test.go index 3d27e2ec6..d758096be 100644 --- a/satellite/console/service_test.go +++ b/satellite/console/service_test.go @@ -530,6 +530,37 @@ func TestService(t *testing.T) { check() }) + t.Run("ApplyCredit fails when payments.Balances.ApplyCredit returns an error", func(t *testing.T) { + require.Error(t, service.Payments().ApplyCredit(userCtx1, 1000, stripecoinpayments.MockCBTXsNewFailure)) + btxs, err := sat.API.Payments.Accounts.Balances().ListTransactions(ctx, up1Pro1.OwnerID) + require.NoError(t, err) + require.Zero(t, len(btxs)) + }) + t.Run("ApplyCredit", func(t *testing.T) { + amount := int64(1000) + desc := "test" + require.NoError(t, service.Payments().ApplyCredit(userCtx1, 1000, desc)) + btxs, err := sat.API.Payments.Accounts.Balances().ListTransactions(ctx, up1Pro1.OwnerID) + require.NoError(t, err) + require.Len(t, btxs, 1) + require.Equal(t, amount, btxs[0].Amount) + require.Equal(t, desc, btxs[0].Description) + + // test same description results in no new credit + require.NoError(t, service.Payments().ApplyCredit(userCtx1, 1000, desc)) + btxs, err = sat.API.Payments.Accounts.Balances().ListTransactions(ctx, up1Pro1.OwnerID) + require.NoError(t, err) + require.Len(t, btxs, 1) + + // test different description results in new credit + require.NoError(t, service.Payments().ApplyCredit(userCtx1, 1000, "new desc")) + btxs, err = sat.API.Payments.Accounts.Balances().ListTransactions(ctx, up1Pro1.OwnerID) + require.NoError(t, err) + require.Len(t, btxs, 2) + }) + t.Run("ApplyCredit fails with unknown user", func(t *testing.T) { + require.Error(t, service.Payments().ApplyCredit(ctx, 1000, "test")) + }) }) } diff --git a/satellite/payments/account.go b/satellite/payments/account.go index 1488fd1fd..42703897b 100644 --- a/satellite/payments/account.go +++ b/satellite/payments/account.go @@ -29,8 +29,8 @@ type Accounts interface { // GetPackageInfo returns the package plan and time of purchase for a user. GetPackageInfo(ctx context.Context, userID uuid.UUID) (packagePlan *string, purchaseTime *time.Time, err error) - // Balance returns an object that represents current free credits and coins balance in cents. - Balance(ctx context.Context, userID uuid.UUID) (Balance, error) + // Balances exposes functionality to manage account balances. + Balances() Balances // ProjectCharges returns how much money current user will be charged for each project. ProjectCharges(ctx context.Context, userID uuid.UUID, since, before time.Time) ([]ProjectCharge, error) diff --git a/satellite/payments/balance.go b/satellite/payments/balance.go index 8bef2ae01..ca38d15b1 100644 --- a/satellite/payments/balance.go +++ b/satellite/payments/balance.go @@ -4,9 +4,23 @@ package payments import ( + "context" + "github.com/shopspring/decimal" + + "storj.io/common/uuid" ) +// Balances exposes needed functionality for managing customer balances. +type Balances interface { + // ApplyCredit applies a credit of `amount` to the user's stripe balance with a description of `desc`. + ApplyCredit(ctx context.Context, userID uuid.UUID, amount int64, desc string) (*Balance, error) + // Get returns the customer balance. + Get(ctx context.Context, userID uuid.UUID) (Balance, error) + // ListTransactions returns a list of transactions on the customer's balance. + ListTransactions(ctx context.Context, userID uuid.UUID) ([]BalanceTransaction, error) +} + // Balance is an entity that holds free credits and coins balance of user. // Earned by applying of promotional coupon and coins depositing, respectively. type Balance struct { @@ -19,3 +33,10 @@ type Balance struct { // 4. bonus manually credited for a storjscan payment once a month before invoicing. // 5. any other adjustment we may have to make from time to time manually to the customerĀ“s STORJ balance. } + +// BalanceTransaction represents a single transaction affecting a customer balance. +type BalanceTransaction struct { + ID string + Amount int64 + Description string +} diff --git a/satellite/payments/stripecoinpayments/accounts.go b/satellite/payments/stripecoinpayments/accounts.go index 13cb5936b..d029bcf0b 100644 --- a/satellite/payments/stripecoinpayments/accounts.go +++ b/satellite/payments/stripecoinpayments/accounts.go @@ -7,7 +7,6 @@ import ( "context" "time" - "github.com/shopspring/decimal" "github.com/stripe/stripe-go/v72" "github.com/zeebo/errs" @@ -31,6 +30,11 @@ func (accounts *accounts) CreditCards() payments.CreditCards { return &creditCards{service: accounts.service} } +// Balances exposes all needed functionality to manage account balances. +func (accounts *accounts) Balances() payments.Balances { + return &balances{service: accounts.service} +} + // Invoices exposes all needed functionality to manage account invoices. func (accounts *accounts) Invoices() payments.Invoices { return &invoices{service: accounts.service} @@ -121,34 +125,6 @@ func (accounts *accounts) GetPackageInfo(ctx context.Context, userID uuid.UUID) return } -// Balance returns an integer amount in cents that represents the current balance of payment account. -func (accounts *accounts) Balance(ctx context.Context, userID uuid.UUID) (_ payments.Balance, err error) { - defer mon.Task()(&ctx, userID)(&err) - - balance, err := accounts.service.billingDB.GetBalance(ctx, userID) - if err != nil { - return payments.Balance{}, Error.Wrap(err) - } - - customerID, err := accounts.service.db.Customers().GetCustomerID(ctx, userID) - if err != nil { - return payments.Balance{}, Error.Wrap(err) - } - - params := &stripe.CustomerParams{Params: stripe.Params{Context: ctx}} - customer, err := accounts.service.stripeClient.Customers().Get(customerID, params) - if err != nil { - return payments.Balance{}, Error.Wrap(err) - } - - // customer.Balance is negative if the user has a balance with us. - // https://stripe.com/docs/api/customers/object#customer_object-balance - return payments.Balance{ - Coins: balance.AsDecimal(), - Credits: decimal.NewFromInt(-customer.Balance), - }, nil -} - // ProjectCharges returns how much money current user will be charged for each project. func (accounts *accounts) ProjectCharges(ctx context.Context, userID uuid.UUID, since, before time.Time) (charges []payments.ProjectCharge, err error) { defer mon.Task()(&ctx, userID, since, before)(&err) diff --git a/satellite/payments/stripecoinpayments/balances.go b/satellite/payments/stripecoinpayments/balances.go new file mode 100644 index 000000000..46e2f496d --- /dev/null +++ b/satellite/payments/stripecoinpayments/balances.go @@ -0,0 +1,99 @@ +// Copyright (C) 2023 Storj Labs, Inc. +// See LICENSE for copying information. + +package stripecoinpayments + +import ( + "context" + + "github.com/shopspring/decimal" + "github.com/stripe/stripe-go/v72" + + "storj.io/common/uuid" + "storj.io/storj/satellite/payments" +) + +type balances struct { + service *Service +} + +// ApplyCredit applies a credit of `amount` to the user's stripe balance with a description of `desc`. +func (balances *balances) ApplyCredit(ctx context.Context, userID uuid.UUID, amount int64, desc string) (b *payments.Balance, err error) { + defer mon.Task()(&ctx)(&err) + + customerID, err := balances.service.db.Customers().GetCustomerID(ctx, userID) + if err != nil { + return nil, Error.Wrap(err) + } + + // NB: In stripe a negative amount means the customer is owed money. + cbtx, err := balances.service.stripeClient.CustomerBalanceTransactions().New(&stripe.CustomerBalanceTransactionParams{ + Customer: stripe.String(customerID), + Description: stripe.String(desc), + Amount: stripe.Int64(-amount), + Currency: stripe.String(string(stripe.CurrencyUSD)), + }) + if err != nil { + return nil, Error.Wrap(err) + } + + return &payments.Balance{ + Credits: decimal.NewFromInt(-cbtx.EndingBalance), + }, nil +} + +func (balances *balances) ListTransactions(ctx context.Context, userID uuid.UUID) (_ []payments.BalanceTransaction, err error) { + defer mon.Task()(&ctx)(&err) + + customerID, err := balances.service.db.Customers().GetCustomerID(ctx, userID) + if err != nil { + return nil, Error.Wrap(err) + } + + var list []payments.BalanceTransaction + iter := balances.service.stripeClient.CustomerBalanceTransactions().List(&stripe.CustomerBalanceTransactionListParams{ + Customer: stripe.String(customerID), + }) + for iter.Next() { + stripeCBTX := iter.CustomerBalanceTransaction() + if stripeCBTX != nil { + list = append(list, payments.BalanceTransaction{ + ID: stripeCBTX.ID, + Amount: -stripeCBTX.Amount, + Description: stripeCBTX.Description, + }) + } + } + if err = iter.Err(); err != nil { + return nil, Error.Wrap(err) + } + return list, nil +} + +// Get returns an integer amount in cents that represents the current balance of payment account. +func (balances *balances) Get(ctx context.Context, userID uuid.UUID) (_ payments.Balance, err error) { + defer mon.Task()(&ctx, userID)(&err) + + b, err := balances.service.billingDB.GetBalance(ctx, userID) + if err != nil { + return payments.Balance{}, Error.Wrap(err) + } + + customerID, err := balances.service.db.Customers().GetCustomerID(ctx, userID) + if err != nil { + return payments.Balance{}, Error.Wrap(err) + } + + params := &stripe.CustomerParams{Params: stripe.Params{Context: ctx}} + customer, err := balances.service.stripeClient.Customers().Get(customerID, params) + if err != nil { + return payments.Balance{}, Error.Wrap(err) + } + + // customer.Balance is negative if the user has a balance with us. + // https://stripe.com/docs/api/customers/object#customer_object-balance + return payments.Balance{ + Coins: b.AsDecimal(), + Credits: decimal.NewFromInt(-customer.Balance), + }, nil +} diff --git a/satellite/payments/stripecoinpayments/balances_test.go b/satellite/payments/stripecoinpayments/balances_test.go new file mode 100644 index 000000000..d2766c430 --- /dev/null +++ b/satellite/payments/stripecoinpayments/balances_test.go @@ -0,0 +1,76 @@ +// Copyright (C) 2023 Storj Labs, Inc. +// See LICENSE for copying information. + +package stripecoinpayments_test + +import ( + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + + "storj.io/common/testcontext" + "storj.io/storj/private/testplanet" + "storj.io/storj/satellite/payments/stripecoinpayments" +) + +func TestBalances(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, UplinkCount: 1, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + sat := planet.Satellites[0] + userID := planet.Uplinks[0].Projects[0].Owner.ID + balances := sat.API.Payments.Accounts.Balances() + + tx1Amount := int64(1000) + tx1Desc := "test description 1" + + b, err := balances.ApplyCredit(ctx, userID, tx1Amount, tx1Desc) + require.NoError(t, err) + require.Equal(t, decimal.NewFromInt(tx1Amount), b.Credits) + + bal, err := balances.Get(ctx, userID) + require.NoError(t, err) + require.Equal(t, decimal.NewFromInt(tx1Amount), bal.Credits) + + tx2Amount := int64(-1000) + tx2Desc := "test description 2" + endingBalance := tx1Amount + tx2Amount + b, err = balances.ApplyCredit(ctx, userID, tx2Amount, tx2Desc) + require.NoError(t, err) + require.Equal(t, decimal.NewFromInt(endingBalance), b.Credits) + + bal, err = balances.Get(ctx, userID) + require.NoError(t, err) + require.Equal(t, decimal.NewFromInt(endingBalance), bal.Credits) + + tx3Amount := int64(-1000) + tx3Desc := "test description 3" + endingBalance += tx3Amount + b, err = balances.ApplyCredit(ctx, userID, tx3Amount, tx3Desc) + require.NoError(t, err) + require.Equal(t, decimal.NewFromInt(endingBalance), b.Credits) + + bal, err = balances.Get(ctx, userID) + require.NoError(t, err) + require.Equal(t, decimal.NewFromInt(endingBalance), bal.Credits) + + list, err := balances.ListTransactions(ctx, userID) + require.NoError(t, err) + require.Len(t, list, 3) + require.Equal(t, tx1Amount, list[0].Amount) + require.Equal(t, tx1Desc, list[0].Description) + require.Equal(t, tx2Amount, list[1].Amount) + require.Equal(t, tx2Desc, list[1].Description) + require.Equal(t, tx3Amount, list[2].Amount) + require.Equal(t, tx3Desc, list[2].Description) + + b, err = balances.ApplyCredit(ctx, userID, tx2Amount, stripecoinpayments.MockCBTXsNewFailure) + require.Error(t, err) + require.Nil(t, b) + + list, err = balances.ListTransactions(ctx, userID) + require.NoError(t, err) + require.Equal(t, 3, len(list)) + }) +} diff --git a/satellite/payments/stripecoinpayments/stripemock.go b/satellite/payments/stripecoinpayments/stripemock.go index 049bff1ed..9299bd861 100644 --- a/satellite/payments/stripecoinpayments/stripemock.go +++ b/satellite/payments/stripecoinpayments/stripemock.go @@ -47,6 +47,10 @@ const ( // TestPaymentMethodsAttachFailure can be passed to creditCards.Add as the cardToken arg to cause // mockPaymentMethods.Attach to return an error. TestPaymentMethodsAttachFailure = "test_payment_methods_attach_failure" + + // MockCBTXsNewFailure can be passed to mockCustomerBalanceTransactions.New as the `desc` argument to cause it + // to return an error. + MockCBTXsNewFailure = "mock_cbtxs_new_failure" ) var ( @@ -704,6 +708,14 @@ func newMockCustomerBalanceTransactions(root *mockStripeState) *mockCustomerBala } func (m *mockCustomerBalanceTransactions) New(params *stripe.CustomerBalanceTransactionParams) (*stripe.CustomerBalanceTransaction, error) { + m.root.mu.Lock() + defer m.root.mu.Unlock() + + if params.Description != nil { + if *params.Description == MockCBTXsNewFailure { + return nil, &stripe.Error{} + } + } tx := &stripe.CustomerBalanceTransaction{ Type: stripe.CustomerBalanceTransactionTypeAdjustment, Amount: *params.Amount, @@ -712,11 +724,14 @@ func (m *mockCustomerBalanceTransactions) New(params *stripe.CustomerBalanceTran Created: time.Now().Unix(), } - m.root.mu.Lock() - defer m.root.mu.Unlock() - m.transactions[*params.Customer] = append(m.transactions[*params.Customer], tx) + for _, v := range m.root.customers.customers { + if v.ID == *params.Customer { + v.Balance += *params.Amount + tx.EndingBalance = v.Balance + } + } return tx, nil }