storj/satellite/satellitedb/customers.go
Michal Niewrzal a21afeddd1 satellite/payments/stripe: avoid full table scan while listing
The query listing stripe customers page by page was doing a full table scan
because an OFFSET clause was used. This refactoring changes listing to
use a cursor instead of OFFSET.

Change-Id: I14688e6c533bc932ba0d209a061562f080b4cf54
2023-04-13 12:36:31 +02:00

178 lines
5.1 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package satellitedb
import (
"context"
"database/sql"
"errors"
"time"
"github.com/zeebo/errs"
"storj.io/common/uuid"
"storj.io/storj/satellite/payments/stripe"
"storj.io/storj/satellite/satellitedb/dbx"
)
// ensures that customers implements stripe.CustomersDB.
var _ stripe.CustomersDB = (*customers)(nil)
// customers is an implementation of stripe.CustomersDB backed by the
// stripe_customers table of the satellite database.
//
// architecture: Database
type customers struct {
	db *satelliteDB
}
// Raw exposes the underlying dbx database handle used by this store.
func (customers *customers) Raw() *dbx.DB {
	satellite := customers.db
	return satellite.DB
}
// Insert records that the given satellite user is represented in stripe by
// customerID, creating a new stripe_customers row.
func (customers *customers) Insert(ctx context.Context, userID uuid.UUID, customerID string) (err error) {
	defer mon.Task()(&ctx, userID, customerID)(&err)

	userField := dbx.StripeCustomer_UserId(userID[:])
	customerField := dbx.StripeCustomer_CustomerId(customerID)

	_, err = customers.db.Create_StripeCustomer(ctx, userField, customerField, dbx.StripeCustomer_Create_Fields{})
	return err
}
// GetCustomerID looks up the stripe customer id belonging to userID.
// It returns stripe.ErrNoCustomer when no stripe customer row exists for the user.
func (customers *customers) GetCustomerID(ctx context.Context, userID uuid.UUID) (_ string, err error) {
	defer mon.Task()(&ctx, userID)(&err)

	row, err := customers.db.Get_StripeCustomer_CustomerId_By_UserId(ctx, dbx.StripeCustomer_UserId(userID[:]))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return "", stripe.ErrNoCustomer
	case err != nil:
		return "", err
	}

	return row.CustomerId, nil
}
// GetUserID resolves the satellite user id associated with a stripe customer id.
// It returns stripe.ErrNoCustomer when no row matches customerID.
func (customers *customers) GetUserID(ctx context.Context, customerID string) (_ uuid.UUID, err error) {
	defer mon.Task()(&ctx)(&err)

	row, err := customers.db.Get_StripeCustomer_UserId_By_CustomerId(ctx, dbx.StripeCustomer_CustomerId(customerID))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return uuid.UUID{}, stripe.ErrNoCustomer
	case err != nil:
		return uuid.UUID{}, err
	}

	// The user id is stored as raw bytes; decode it back into a UUID.
	return uuid.FromBytes(row.UserId)
}
// List returns a page of stripe customers created before the given time,
// ordered by user ID.
//
// Pagination is cursor based: only customers whose user ID sorts strictly
// after userIDCursor are returned, which lets the database seek along the
// user_id ordering instead of reading and discarding rows as OFFSET would.
// Pass the zero UUID to start from the beginning and page.Cursor of the
// previous page to continue; page.Next reports whether more customers remain.
//
// limit must be positive; a non-positive limit returns an error.
func (customers *customers) List(ctx context.Context, userIDCursor uuid.UUID, limit int, before time.Time) (page stripe.CustomersPage, err error) {
	defer mon.Task()(&ctx)(&err)

	// A non-positive limit can never yield a usable page: the look-ahead row
	// would be trimmed away, leaving no element to take the cursor from
	// (previously this panicked with an index out of range).
	if limit <= 0 {
		return stripe.CustomersPage{}, errs.New("invalid limit: %d", limit)
	}

	// Fetch one row more than requested so we can tell whether another page
	// follows without issuing a separate query.
	rows, err := customers.db.QueryContext(ctx, customers.db.Rebind(`
		SELECT
			stripe_customers.user_id, stripe_customers.customer_id
		FROM
			stripe_customers
		WHERE
			stripe_customers.user_id > ? AND
			stripe_customers.created_at < ?
		ORDER BY stripe_customers.user_id ASC
		LIMIT ?
	`), userIDCursor, before, limit+1)
	if err != nil {
		// QueryContext reports "no rows" as an empty result set, not as
		// sql.ErrNoRows, so any error here is a genuine failure.
		return stripe.CustomersPage{}, err
	}
	defer func() { err = errs.Combine(err, rows.Close()) }()

	results := []stripe.Customer{}
	for rows.Next() {
		var customer stripe.Customer
		if err := rows.Scan(&customer.UserID, &customer.ID); err != nil {
			return stripe.CustomersPage{}, errs.New("unable to get stripe customer: %+v", err)
		}
		results = append(results, customer)
	}
	if err := rows.Err(); err != nil {
		return stripe.CustomersPage{}, errs.New("error while listing stripe customers: %+v", err)
	}

	// More rows than requested means another page exists; drop the
	// look-ahead row and continue from the last customer actually returned.
	if len(results) > limit {
		results = results[:limit]
		page.Next = true
		page.Cursor = results[limit-1].UserID
	}

	page.Customers = results
	return page, nil
}
// UpdatePackage updates the customer's package plan and purchase time.
//
// A nil packagePlan or timestamp clears the corresponding column (sets it to
// NULL). It returns the updated customer, or stripe.ErrNoCustomer when no
// stripe customer exists for userID.
func (customers *customers) UpdatePackage(ctx context.Context, userID uuid.UUID, packagePlan *string, timestamp *time.Time) (c *stripe.Customer, err error) {
	defer mon.Task()(&ctx)(&err)

	// Default both columns to NULL so that passing nil explicitly clears any
	// previously stored package information.
	updateFields := dbx.StripeCustomer_Update_Fields{
		PackagePlan:        dbx.StripeCustomer_PackagePlan_Null(),
		PurchasedPackageAt: dbx.StripeCustomer_PurchasedPackageAt_Null(),
	}
	if packagePlan != nil {
		updateFields.PackagePlan = dbx.StripeCustomer_PackagePlan(*packagePlan)
	}
	if timestamp != nil {
		updateFields.PurchasedPackageAt = dbx.StripeCustomer_PurchasedPackageAt(*timestamp)
	}

	dbxCustomer, err := customers.db.Update_StripeCustomer_By_UserId(ctx,
		dbx.StripeCustomer_UserId(userID[:]),
		updateFields,
	)
	if err != nil {
		return nil, err
	}
	// NOTE(review): dbx-generated Update methods report "no matching row" as
	// (nil, nil) rather than sql.ErrNoRows (verify against the generated code);
	// without this guard fromDBXCustomer would dereference nil and panic.
	if dbxCustomer == nil {
		return nil, stripe.ErrNoCustomer
	}
	return fromDBXCustomer(dbxCustomer)
}
// GetPackageInfo returns the package plan and the time it was purchased for
// the stripe customer belonging to userID. Either value is nil when the
// corresponding column is not set. Errors from the lookup, including a
// missing row, are returned unchanged.
func (customers *customers) GetPackageInfo(ctx context.Context, userID uuid.UUID) (_ *string, _ *time.Time, err error) {
	defer mon.Task()(&ctx)(&err)

	row, err := customers.db.Get_StripeCustomer_PackagePlan_StripeCustomer_PurchasedPackageAt_By_UserId(ctx, dbx.StripeCustomer_UserId(userID[:]))
	if err != nil {
		return nil, nil, err
	}
	return row.PackagePlan, row.PurchasedPackageAt, nil
}
// fromDBXCustomer converts *dbx.StripeCustomer to *stripe.Customer.
// It fails only if the stored user id bytes are not a valid UUID.
// dbxCustomer must be non-nil.
func fromDBXCustomer(dbxCustomer *dbx.StripeCustomer) (*stripe.Customer, error) {
	userID, err := uuid.FromBytes(dbxCustomer.UserId)
	if err != nil {
		return nil, err
	}

	return &stripe.Customer{
		ID:                 dbxCustomer.CustomerId,
		UserID:             userID,
		PackagePlan:        dbxCustomer.PackagePlan,
		PackagePurchasedAt: dbxCustomer.PurchasedPackageAt,
	}, nil
}