satellite/{payment,console}: add endpoint to add card by pmID

This change introduces a new endpoint that allows adding credit cards
by payment method ID (pmID). The payment method would've already been
created by the frontend using the stripe payment element for example.

Issue: #6436

Change-Id: If9a3f4c98171e36623607968d1a12f29fa7627e9
This commit is contained in:
Wilfred Asomani 2023-10-25 10:41:12 +00:00 committed by Wilfred Asomani
parent 32e67e5fab
commit f7a95e0077
8 changed files with 247 additions and 30 deletions

View File

@ -234,6 +234,44 @@ func (p *Payments) AddCreditCard(w http.ResponseWriter, r *http.Request) {
}
}
// AddCardByPaymentMethodID is used to save new credit card and attach it to payment account.
// It uses payment method id instead of token.
func (p *Payments) AddCardByPaymentMethodID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var err error
defer mon.Task()(&ctx)(&err)
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
p.serveJSONError(ctx, w, http.StatusBadRequest, err)
return
}
pmID := string(bodyBytes)
_, err = p.service.Payments().AddCardByPaymentMethodID(ctx, pmID)
if err != nil {
if console.ErrUnauthorized.Has(err) {
p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
return
}
if stripe.ErrDuplicateCard.Has(err) {
p.serveJSONError(ctx, w, http.StatusBadRequest, err)
return
}
p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
return
}
err = p.triggerAttemptPayment(ctx)
if err != nil {
p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
return
}
}
// ListCreditCards returns a list of credit cards for a given payment account.
func (p *Payments) ListCreditCards(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@ -605,6 +643,9 @@ func (p *Payments) PurchasePackage(w http.ResponseWriter, r *http.Request) {
var err error
defer mon.Task()(&ctx)(&err)
// whether to use payment method id instead of token for adding card.
usePmID := r.URL.Query().Get("pmID") == "true"
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
p.serveJSONError(ctx, w, http.StatusBadRequest, err)
@ -625,7 +666,14 @@ func (p *Payments) PurchasePackage(w http.ResponseWriter, r *http.Request) {
return
}
card, err := p.service.Payments().AddCreditCard(ctx, token)
var addCardFunc func(context.Context, string) (payments.CreditCard, error)
if usePmID {
addCardFunc = p.service.Payments().AddCardByPaymentMethodID
} else {
addCardFunc = p.service.Payments().AddCreditCard
}
card, err := addCardFunc(ctx, token)
if err != nil {
switch {
case console.ErrUnauthorized.Has(err):

View File

@ -339,6 +339,7 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc
paymentsRouter := router.PathPrefix("/api/v0/payments").Subrouter()
paymentsRouter.Use(server.withCORS)
paymentsRouter.Use(server.withAuth)
paymentsRouter.Handle("/payment-methods", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.AddCardByPaymentMethodID))).Methods(http.MethodPost, http.MethodOptions)
paymentsRouter.Handle("/cards", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.AddCreditCard))).Methods(http.MethodPost, http.MethodOptions)
paymentsRouter.HandleFunc("/cards", paymentController.MakeCreditCardDefault).Methods(http.MethodPatch, http.MethodOptions)
paymentsRouter.HandleFunc("/cards", paymentController.ListCreditCards).Methods(http.MethodGet, http.MethodOptions)

View File

@ -348,43 +348,78 @@ func (payment Payments) AddCreditCard(ctx context.Context, creditCardToken strin
payment.service.analytics.TrackCreditCardAdded(user.ID, user.Email)
if !user.PaidTier {
// put this user into the paid tier and convert projects to upgraded limits.
err = payment.service.store.Users().UpdatePaidTier(ctx, user.ID, true,
payment.service.config.UsageLimits.Bandwidth.Paid,
payment.service.config.UsageLimits.Storage.Paid,
payment.service.config.UsageLimits.Segment.Paid,
payment.service.config.UsageLimits.Project.Paid,
)
err = payment.upgradeToPaidTier(ctx, user)
if err != nil {
return payments.CreditCard{}, Error.Wrap(err)
}
projects, err := payment.service.store.Projects().GetOwn(ctx, user.ID)
if err != nil {
return payments.CreditCard{}, Error.Wrap(err)
}
for _, project := range projects {
if project.StorageLimit == nil || *project.StorageLimit < payment.service.config.UsageLimits.Storage.Paid {
project.StorageLimit = new(memory.Size)
*project.StorageLimit = payment.service.config.UsageLimits.Storage.Paid
}
if project.BandwidthLimit == nil || *project.BandwidthLimit < payment.service.config.UsageLimits.Bandwidth.Paid {
project.BandwidthLimit = new(memory.Size)
*project.BandwidthLimit = payment.service.config.UsageLimits.Bandwidth.Paid
}
if project.SegmentLimit == nil || *project.SegmentLimit < payment.service.config.UsageLimits.Segment.Paid {
*project.SegmentLimit = payment.service.config.UsageLimits.Segment.Paid
}
err = payment.service.store.Projects().Update(ctx, &project)
if err != nil {
return payments.CreditCard{}, Error.Wrap(err)
}
}
}
return card, nil
}
// AddCardByPaymentMethodID is used to save new credit card and attach it to payment account.
func (payment Payments) AddCardByPaymentMethodID(ctx context.Context, pmID string) (card payments.CreditCard, err error) {
defer mon.Task()(&ctx, pmID)(&err)
user, err := payment.service.getUserAndAuditLog(ctx, "add credit card")
if err != nil {
return payments.CreditCard{}, Error.Wrap(err)
}
card, err = payment.service.accounts.CreditCards().AddByPaymentMethodID(ctx, user.ID, pmID)
if err != nil {
return payments.CreditCard{}, Error.Wrap(err)
}
payment.service.analytics.TrackCreditCardAdded(user.ID, user.Email)
if !user.PaidTier {
err = payment.upgradeToPaidTier(ctx, user)
if err != nil {
return payments.CreditCard{}, Error.Wrap(err)
}
}
return card, nil
}
func (payment Payments) upgradeToPaidTier(ctx context.Context, user *User) (err error) {
// put this user into the paid tier and convert projects to upgraded limits.
err = payment.service.store.Users().UpdatePaidTier(ctx, user.ID, true,
payment.service.config.UsageLimits.Bandwidth.Paid,
payment.service.config.UsageLimits.Storage.Paid,
payment.service.config.UsageLimits.Segment.Paid,
payment.service.config.UsageLimits.Project.Paid,
)
if err != nil {
return Error.Wrap(err)
}
projects, err := payment.service.store.Projects().GetOwn(ctx, user.ID)
if err != nil {
return Error.Wrap(err)
}
for _, project := range projects {
if project.StorageLimit == nil || *project.StorageLimit < payment.service.config.UsageLimits.Storage.Paid {
project.StorageLimit = new(memory.Size)
*project.StorageLimit = payment.service.config.UsageLimits.Storage.Paid
}
if project.BandwidthLimit == nil || *project.BandwidthLimit < payment.service.config.UsageLimits.Bandwidth.Paid {
project.BandwidthLimit = new(memory.Size)
*project.BandwidthLimit = payment.service.config.UsageLimits.Bandwidth.Paid
}
if project.SegmentLimit == nil || *project.SegmentLimit < payment.service.config.UsageLimits.Segment.Paid {
*project.SegmentLimit = payment.service.config.UsageLimits.Segment.Paid
}
err = payment.service.store.Projects().Update(ctx, &project)
if err != nil {
return Error.Wrap(err)
}
}
return nil
}
// MakeCreditCardDefault makes a credit card default payment method.
func (payment Payments) MakeCreditCardDefault(ctx context.Context, cardID string) (err error) {
defer mon.Task()(&ctx, cardID)(&err)

View File

@ -19,6 +19,11 @@ type CreditCards interface {
// Add is used to save new credit card and attach it to payment account.
Add(ctx context.Context, userID uuid.UUID, cardToken string) (CreditCard, error)
// AddByPaymentMethodID is used to save new credit card, attach it to payment account and make it default
// using the payment method id instead of the token. In this case, the payment method should already be
// created by the frontend using stripe elements for example.
AddByPaymentMethodID(ctx context.Context, userID uuid.UUID, pmID string) (CreditCard, error)
// Remove is used to detach a credit card from payment account.
Remove(ctx context.Context, userID uuid.UUID, cardID string) error

View File

@ -51,6 +51,7 @@ type Customers interface {
type PaymentMethods interface {
List(listParams *stripe.PaymentMethodListParams) *paymentmethod.Iter
New(params *stripe.PaymentMethodParams) (*stripe.PaymentMethod, error)
Get(id string, params *stripe.PaymentMethodParams) (*stripe.PaymentMethod, error)
Attach(id string, params *stripe.PaymentMethodAttachParams) (*stripe.PaymentMethod, error)
Detach(id string, params *stripe.PaymentMethodDetachParams) (*stripe.PaymentMethod, error)
}

View File

@ -152,6 +152,75 @@ func (creditCards *creditCards) Add(ctx context.Context, userID uuid.UUID, cardT
}, Error.Wrap(err)
}
// AddByPaymentMethodID is used to save new credit card, attach it to payment account and make it default
// using the payment method id instead of the token. In this case, the payment method should already be
// created by the frontend using the stripe payment element for example.
func (creditCards *creditCards) AddByPaymentMethodID(ctx context.Context, userID uuid.UUID, pmID string) (_ payments.CreditCard, err error) {
defer mon.Task()(&ctx, userID, pmID)(&err)
customerID, err := creditCards.service.db.Customers().GetCustomerID(ctx, userID)
if err != nil {
return payments.CreditCard{}, payments.ErrAccountNotSetup.Wrap(err)
}
card, err := creditCards.service.stripeClient.PaymentMethods().Get(pmID, &stripe.PaymentMethodParams{
Params: stripe.Params{Context: ctx},
})
if err != nil {
return payments.CreditCard{}, Error.Wrap(err)
}
listParams := &stripe.PaymentMethodListParams{
ListParams: stripe.ListParams{Context: ctx},
Customer: &customerID,
Type: stripe.String(string(stripe.PaymentMethodTypeCard)),
}
paymentMethodsIterator := creditCards.service.stripeClient.PaymentMethods().List(listParams)
for paymentMethodsIterator.Next() {
stripeCard := paymentMethodsIterator.PaymentMethod()
if stripeCard.Card.Fingerprint == card.Card.Fingerprint &&
stripeCard.Card.ExpMonth == card.Card.ExpMonth &&
stripeCard.Card.ExpYear == card.Card.ExpYear {
return payments.CreditCard{}, ErrDuplicateCard.New("this card is already on file for your account.")
}
}
if err = paymentMethodsIterator.Err(); err != nil {
return payments.CreditCard{}, Error.Wrap(err)
}
attachParams := &stripe.PaymentMethodAttachParams{
Params: stripe.Params{Context: ctx},
Customer: &customerID,
}
card, err = creditCards.service.stripeClient.PaymentMethods().Attach(pmID, attachParams)
if err != nil {
return payments.CreditCard{}, Error.Wrap(err)
}
params := &stripe.CustomerParams{
Params: stripe.Params{Context: ctx},
InvoiceSettings: &stripe.CustomerInvoiceSettingsParams{
DefaultPaymentMethod: stripe.String(card.ID),
},
}
_, err = creditCards.service.stripeClient.Customers().Update(customerID, params)
// TODO: handle created but not attached card manually?
return payments.CreditCard{
ID: card.ID,
ExpMonth: int(card.Card.ExpMonth),
ExpYear: int(card.Card.ExpYear),
Brand: string(card.Card.Brand),
Last4: card.Card.Last4,
IsDefault: true,
}, Error.Wrap(err)
}
// MakeDefault makes a credit card default payment method.
// this credit card should be attached to account before make it default.
func (creditCards *creditCards) MakeDefault(ctx context.Context, userID uuid.UUID, cardID string) (err error) {

View File

@ -8,6 +8,7 @@ import (
"testing"
"github.com/stretchr/testify/require"
stripeLib "github.com/stripe/stripe-go/v75"
"storj.io/common/testcontext"
"storj.io/storj/private/testplanet"
@ -85,6 +86,42 @@ func TestCreditCards_Add(t *testing.T) {
})
}
func TestCreditCards_AddByPaymentMethodID(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 1,
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
satellite := planet.Satellites[0]
u, err := satellite.AddUser(ctx, console.CreateUser{
FullName: "Test User",
Email: "test@storj.test",
}, 1)
require.NoError(t, err)
_, err = satellite.API.Payments.Accounts.CreditCards().AddByPaymentMethodID(ctx, u.ID, "non-existent")
require.Error(t, err)
pm, err := satellite.API.Payments.StripeClient.PaymentMethods().New(&stripeLib.PaymentMethodParams{
Type: stripeLib.String(string(stripeLib.PaymentMethodTypeCard)),
Card: &stripeLib.PaymentMethodCardParams{
Token: stripeLib.String("test"),
},
})
require.NoError(t, err)
_, err = satellite.API.Payments.Accounts.CreditCards().AddByPaymentMethodID(ctx, u.ID, pm.ID)
require.NoError(t, err)
_, err = satellite.API.Payments.Accounts.CreditCards().AddByPaymentMethodID(ctx, u.ID, pm.ID)
require.Error(t, err)
require.True(t, stripe.ErrDuplicateCard.Has(err))
cards, err := satellite.API.Payments.Accounts.CreditCards().List(ctx, u.ID)
require.NoError(t, err)
require.Len(t, cards, 1)
})
}
func TestCreditCards_AddDuplicateCard(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 1,

View File

@ -424,6 +424,27 @@ func (m *mockPaymentMethods) List(listParams *stripe.PaymentMethodListParams) *p
return &paymentmethod.Iter{Iter: stripe.GetIter(nil, query)}
}
func (m *mockPaymentMethods) Get(id string, params *stripe.PaymentMethodParams) (*stripe.PaymentMethod, error) {
m.root.mu.Lock()
defer m.root.mu.Unlock()
for _, method := range m.unattached {
if method.ID == id {
return method, nil
}
}
for _, methods := range m.attached {
for _, method := range methods {
if method.ID == id {
return method, nil
}
}
}
return nil, errors.New("payment method not found")
}
func (m *mockPaymentMethods) New(params *stripe.PaymentMethodParams) (*stripe.PaymentMethod, error) {
randID := testrand.BucketName()
id := fmt.Sprintf("pm_card_%s", randID)