satellite/payments/stripe: avoid full table scan while listing records

Stripe invoice project records while listing are causing full table scan
because of OFFSET caluse. This change is refactoring query to list using
cursor.

Change-Id: I6b73b9b2815173d7ef02cf615408778476eb3b7b
This commit is contained in:
Michal Niewrzal 2023-05-08 13:15:09 +02:00 committed by Storj Robot
parent c64f3f3132
commit 87d0789691
8 changed files with 81 additions and 188 deletions

View File

@ -26,7 +26,8 @@ type ProjectRecordsDB interface {
// Consume consumes invoice project record.
Consume(ctx context.Context, id uuid.UUID) error
// ListUnapplied returns project records page with unapplied project records.
ListUnapplied(ctx context.Context, offset int64, limit int, start, end time.Time) (ProjectRecordsPage, error)
// Cursor is not included into listing results.
ListUnapplied(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (ProjectRecordsPage, error)
}
// CreateProjectRecord holds info needed for creation new invoice
@ -52,9 +53,9 @@ type ProjectRecord struct {
// ProjectRecordsPage holds project records and
// indicates if there is more data available
// and provides next offset.
// and provides cursor for next listing.
type ProjectRecordsPage struct {
Records []ProjectRecord
Next bool
NextOffset int64
Records []ProjectRecord
Next bool
Cursor uuid.UUID
}

View File

@ -4,6 +4,7 @@
package stripe_test
import (
"fmt"
"testing"
"time"
@ -50,7 +51,7 @@ func TestProjectRecords(t *testing.T) {
assert.Equal(t, stripe.ErrProjectRecordExists, err)
})
page, err := projectRecordsDB.ListUnapplied(ctx, 0, 1, start, end)
page, err := projectRecordsDB.ListUnapplied(ctx, uuid.UUID{}, 1, start, end)
require.NoError(t, err)
require.Equal(t, 1, len(page.Records))
@ -59,7 +60,7 @@ func TestProjectRecords(t *testing.T) {
require.NoError(t, err)
})
page, err = projectRecordsDB.ListUnapplied(ctx, 0, 1, start, end)
page, err = projectRecordsDB.ListUnapplied(ctx, uuid.UUID{}, 1, start, end)
require.NoError(t, err)
require.Equal(t, 0, len(page.Records))
})
@ -74,8 +75,7 @@ func TestProjectRecordsList(t *testing.T) {
projectRecordsDB := db.StripeCoinPayments().ProjectRecords()
const limit = 5
const recordsLen = limit * 4
const recordsLen = 20
var createProjectRecords []stripe.CreateProjectRecord
for i := 0; i < recordsLen; i++ {
@ -95,37 +95,42 @@ func TestProjectRecordsList(t *testing.T) {
err := projectRecordsDB.Create(ctx, createProjectRecords, start, end)
require.NoError(t, err)
page, err := projectRecordsDB.ListUnapplied(ctx, 0, limit, start, end)
require.NoError(t, err)
for _, limit := range []int{1, 3, 5, 30} {
t.Run(fmt.Sprintf("limit-%d", limit), func(t *testing.T) {
records := []stripe.ProjectRecord{}
records := page.Records
var page stripe.ProjectRecordsPage
for {
page, err = projectRecordsDB.ListUnapplied(ctx, page.Cursor, limit, start, end)
require.NoError(t, err)
for page.Next {
page, err = projectRecordsDB.ListUnapplied(ctx, page.NextOffset, limit, start, end)
require.NoError(t, err)
records = append(records, page.Records...)
}
require.Equal(t, recordsLen, len(records))
assert.False(t, page.Next)
assert.Equal(t, int64(0), page.NextOffset)
for _, record := range page.Records {
for _, createRecord := range createProjectRecords {
if record.ProjectID != createRecord.ProjectID {
continue
records = append(records, page.Records...)
if !page.Next {
break
}
}
assert.NotNil(t, record.ID)
assert.Equal(t, 16, len(record.ID))
assert.Equal(t, createRecord.ProjectID, record.ProjectID)
assert.Equal(t, createRecord.Storage, record.Storage)
assert.Equal(t, createRecord.Egress, record.Egress)
assert.Equal(t, createRecord.Segments, record.Segments)
assert.True(t, start.Equal(record.PeriodStart))
assert.True(t, end.Equal(record.PeriodEnd))
}
require.Equal(t, recordsLen, len(records))
assert.False(t, page.Next)
assert.Equal(t, uuid.UUID{}, page.Cursor)
for _, record := range page.Records {
for _, createRecord := range createProjectRecords {
if record.ProjectID != createRecord.ProjectID {
continue
}
assert.NotNil(t, record.ID)
assert.Equal(t, 16, len(record.ID))
assert.Equal(t, createRecord.ProjectID, record.ProjectID)
assert.Equal(t, createRecord.Storage, record.Storage)
assert.Equal(t, createRecord.Egress, record.Egress)
assert.Equal(t, createRecord.Segments, record.Segments)
assert.True(t, start.Equal(record.PeriodStart))
assert.True(t, end.Equal(record.PeriodEnd))
}
}
})
}
})
}

View File

@ -243,7 +243,7 @@ func (service *Service) InvoiceApplyProjectRecords(ctx context.Context, period t
}
// we are always starting from offset 0 because applyProjectRecords is changing project record state to applied
recordsPage, err := service.db.ProjectRecords().ListUnapplied(ctx, 0, service.listingLimit, start, end)
recordsPage, err := service.db.ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, service.listingLimit, start, end)
if err != nil {
return Error.Wrap(err)
}

View File

@ -139,7 +139,7 @@ func TestService_InvoiceElementsProcessing(t *testing.T) {
end := time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC)
// check if we have project record for each project
recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end)
recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end)
require.NoError(t, err)
require.Equal(t, numberOfProjects, len(recordsPage.Records))
@ -147,7 +147,7 @@ func TestService_InvoiceElementsProcessing(t *testing.T) {
require.NoError(t, err)
// verify that we applied all unapplied project records
recordsPage, err = satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end)
recordsPage, err = satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end)
require.NoError(t, err)
require.Equal(t, 0, len(recordsPage.Records))
})
@ -284,7 +284,7 @@ func TestService_ProjectsWithMembers(t *testing.T) {
start := time.Date(period.Year(), period.Month(), 1, 0, 0, 0, 0, time.UTC)
end := time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC)
recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end)
recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end)
require.NoError(t, err)
require.Equal(t, len(projects), len(recordsPage.Records))
})

View File

@ -398,3 +398,13 @@ func (dbc *satelliteDBCollectionTesting) ProductionMigration() *migrate.Migratio
func (dbc *satelliteDBCollectionTesting) TestMigration() *migrate.Migration {
return dbc.getByName("").TestMigration()
}
func withRows(rows tagsql.Rows, err error) func(func(tagsql.Rows) error) error {
return func(callback func(tagsql.Rows) error) error {
if err != nil {
return err
}
err := callback(rows)
return errs.Combine(rows.Err(), rows.Close(), err)
}
}

View File

@ -232,12 +232,6 @@ read one (
where stripecoinpayments_invoice_project_record.period_start = ?
where stripecoinpayments_invoice_project_record.period_end = ?
)
read limitoffset (
select stripecoinpayments_invoice_project_record
where stripecoinpayments_invoice_project_record.period_start = ?
where stripecoinpayments_invoice_project_record.period_end = ?
where stripecoinpayments_invoice_project_record.state = ?
)
// stripecoinpayments_tx_conversion_rate contains information about a conversion-rate that was used in a transaction.
model stripecoinpayments_tx_conversion_rate (

View File

@ -14567,57 +14567,6 @@ func (obj *pgxImpl) Get_StripecoinpaymentsInvoiceProjectRecord_By_ProjectId_And_
}
func (obj *pgxImpl) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context,
stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field,
stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field,
stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field,
limit int, offset int64) (
rows []*StripecoinpaymentsInvoiceProjectRecord, err error) {
defer mon.Task()(&ctx)(&err)
var __embed_stmt = __sqlbundle_Literal("SELECT stripecoinpayments_invoice_project_records.id, stripecoinpayments_invoice_project_records.project_id, stripecoinpayments_invoice_project_records.storage, stripecoinpayments_invoice_project_records.egress, stripecoinpayments_invoice_project_records.objects, stripecoinpayments_invoice_project_records.segments, stripecoinpayments_invoice_project_records.period_start, stripecoinpayments_invoice_project_records.period_end, stripecoinpayments_invoice_project_records.state, stripecoinpayments_invoice_project_records.created_at FROM stripecoinpayments_invoice_project_records WHERE stripecoinpayments_invoice_project_records.period_start = ? AND stripecoinpayments_invoice_project_records.period_end = ? AND stripecoinpayments_invoice_project_records.state = ? LIMIT ? OFFSET ?")
var __values []interface{}
__values = append(__values, stripecoinpayments_invoice_project_record_period_start.value(), stripecoinpayments_invoice_project_record_period_end.value(), stripecoinpayments_invoice_project_record_state.value())
__values = append(__values, limit, offset)
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
for {
rows, err = func() (rows []*StripecoinpaymentsInvoiceProjectRecord, err error) {
__rows, err := obj.driver.QueryContext(ctx, __stmt, __values...)
if err != nil {
return nil, err
}
defer __rows.Close()
for __rows.Next() {
stripecoinpayments_invoice_project_record := &StripecoinpaymentsInvoiceProjectRecord{}
err = __rows.Scan(&stripecoinpayments_invoice_project_record.Id, &stripecoinpayments_invoice_project_record.ProjectId, &stripecoinpayments_invoice_project_record.Storage, &stripecoinpayments_invoice_project_record.Egress, &stripecoinpayments_invoice_project_record.Objects, &stripecoinpayments_invoice_project_record.Segments, &stripecoinpayments_invoice_project_record.PeriodStart, &stripecoinpayments_invoice_project_record.PeriodEnd, &stripecoinpayments_invoice_project_record.State, &stripecoinpayments_invoice_project_record.CreatedAt)
if err != nil {
return nil, err
}
rows = append(rows, stripecoinpayments_invoice_project_record)
}
err = __rows.Err()
if err != nil {
return nil, err
}
return rows, nil
}()
if err != nil {
if obj.shouldRetry(err) {
continue
}
return nil, obj.makeErr(err)
}
return rows, nil
}
}
func (obj *pgxImpl) Get_StripecoinpaymentsTxConversionRate_By_TxId(ctx context.Context,
stripecoinpayments_tx_conversion_rate_tx_id StripecoinpaymentsTxConversionRate_TxId_Field) (
stripecoinpayments_tx_conversion_rate *StripecoinpaymentsTxConversionRate, err error) {
@ -22416,57 +22365,6 @@ func (obj *pgxcockroachImpl) Get_StripecoinpaymentsInvoiceProjectRecord_By_Proje
}
func (obj *pgxcockroachImpl) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context,
stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field,
stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field,
stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field,
limit int, offset int64) (
rows []*StripecoinpaymentsInvoiceProjectRecord, err error) {
defer mon.Task()(&ctx)(&err)
var __embed_stmt = __sqlbundle_Literal("SELECT stripecoinpayments_invoice_project_records.id, stripecoinpayments_invoice_project_records.project_id, stripecoinpayments_invoice_project_records.storage, stripecoinpayments_invoice_project_records.egress, stripecoinpayments_invoice_project_records.objects, stripecoinpayments_invoice_project_records.segments, stripecoinpayments_invoice_project_records.period_start, stripecoinpayments_invoice_project_records.period_end, stripecoinpayments_invoice_project_records.state, stripecoinpayments_invoice_project_records.created_at FROM stripecoinpayments_invoice_project_records WHERE stripecoinpayments_invoice_project_records.period_start = ? AND stripecoinpayments_invoice_project_records.period_end = ? AND stripecoinpayments_invoice_project_records.state = ? LIMIT ? OFFSET ?")
var __values []interface{}
__values = append(__values, stripecoinpayments_invoice_project_record_period_start.value(), stripecoinpayments_invoice_project_record_period_end.value(), stripecoinpayments_invoice_project_record_state.value())
__values = append(__values, limit, offset)
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
for {
rows, err = func() (rows []*StripecoinpaymentsInvoiceProjectRecord, err error) {
__rows, err := obj.driver.QueryContext(ctx, __stmt, __values...)
if err != nil {
return nil, err
}
defer __rows.Close()
for __rows.Next() {
stripecoinpayments_invoice_project_record := &StripecoinpaymentsInvoiceProjectRecord{}
err = __rows.Scan(&stripecoinpayments_invoice_project_record.Id, &stripecoinpayments_invoice_project_record.ProjectId, &stripecoinpayments_invoice_project_record.Storage, &stripecoinpayments_invoice_project_record.Egress, &stripecoinpayments_invoice_project_record.Objects, &stripecoinpayments_invoice_project_record.Segments, &stripecoinpayments_invoice_project_record.PeriodStart, &stripecoinpayments_invoice_project_record.PeriodEnd, &stripecoinpayments_invoice_project_record.State, &stripecoinpayments_invoice_project_record.CreatedAt)
if err != nil {
return nil, err
}
rows = append(rows, stripecoinpayments_invoice_project_record)
}
err = __rows.Err()
if err != nil {
return nil, err
}
return rows, nil
}()
if err != nil {
if obj.shouldRetry(err) {
continue
}
return nil, obj.makeErr(err)
}
return rows, nil
}
}
func (obj *pgxcockroachImpl) Get_StripecoinpaymentsTxConversionRate_By_TxId(ctx context.Context,
stripecoinpayments_tx_conversion_rate_tx_id StripecoinpaymentsTxConversionRate_TxId_Field) (
stripecoinpayments_tx_conversion_rate *StripecoinpaymentsTxConversionRate, err error) {
@ -29364,19 +29262,6 @@ func (rx *Rx) Limited_StorjscanPayment_By_ToAddress_OrderBy_Desc_BlockNumber_Des
return tx.Limited_StorjscanPayment_By_ToAddress_OrderBy_Desc_BlockNumber_Desc_LogIndex(ctx, storjscan_payment_to_address, limit, offset)
}
func (rx *Rx) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context,
stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field,
stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field,
stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field,
limit int, offset int64) (
rows []*StripecoinpaymentsInvoiceProjectRecord, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx, stripecoinpayments_invoice_project_record_period_start, stripecoinpayments_invoice_project_record_period_end, stripecoinpayments_invoice_project_record_state, limit, offset)
}
func (rx *Rx) Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual(ctx context.Context,
bucket_bandwidth_rollup_archive_interval_start_greater_or_equal BucketBandwidthRollupArchive_IntervalStart_Field,
limit int, start *Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual_Continuation) (
@ -30513,13 +30398,6 @@ type Methods interface {
limit int, offset int64) (
rows []*StorjscanPayment, err error)
Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context,
stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field,
stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field,
stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field,
limit int, offset int64) (
rows []*StripecoinpaymentsInvoiceProjectRecord, err error)
Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual(ctx context.Context,
bucket_bandwidth_rollup_archive_interval_start_greater_or_equal BucketBandwidthRollupArchive_IntervalStart_Field,
limit int, start *Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual_Continuation) (

View File

@ -12,6 +12,7 @@ import (
"github.com/zeebo/errs"
"storj.io/common/uuid"
"storj.io/private/tagsql"
"storj.io/storj/satellite/payments/stripe"
"storj.io/storj/satellite/satellitedb/dbx"
)
@ -129,36 +130,40 @@ func (db *invoiceProjectRecords) Consume(ctx context.Context, id uuid.UUID) (err
}
// ListUnapplied returns project records page with unapplied project records.
func (db *invoiceProjectRecords) ListUnapplied(ctx context.Context, offset int64, limit int, start, end time.Time) (_ stripe.ProjectRecordsPage, err error) {
// Cursor is not included into listing results.
func (db *invoiceProjectRecords) ListUnapplied(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (page stripe.ProjectRecordsPage, err error) {
defer mon.Task()(&ctx)(&err)
var page stripe.ProjectRecordsPage
err = withRows(db.db.QueryContext(ctx, db.db.Rebind(`
SELECT
id, project_id, storage, egress, segments, period_start, period_end, state
FROM
stripecoinpayments_invoice_project_records
WHERE
id > ? AND period_start = ? AND period_end = ? AND state = ?
LIMIT ?
`), cursor, start, end, invoiceProjectRecordStateUnapplied.Int(), limit+1))(func(rows tagsql.Rows) error {
for rows.Next() {
var record stripe.ProjectRecord
err := rows.Scan(&record.ID, &record.ProjectID, &record.Storage, &record.Egress, &record.Segments, &record.PeriodStart, &record.PeriodEnd, &record.State)
if err != nil {
return Error.New("failed to scan stripe invoice project records: %w", err)
}
dbxRecords, err := db.db.Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx,
dbx.StripecoinpaymentsInvoiceProjectRecord_PeriodStart(start),
dbx.StripecoinpaymentsInvoiceProjectRecord_PeriodEnd(end),
dbx.StripecoinpaymentsInvoiceProjectRecord_State(invoiceProjectRecordStateUnapplied.Int()),
limit+1,
offset,
)
page.Records = append(page.Records, record)
}
return nil
})
if err != nil {
return stripe.ProjectRecordsPage{}, err
}
if len(dbxRecords) == limit+1 {
if len(page.Records) == limit+1 {
page.Next = true
page.NextOffset = offset + int64(limit)
dbxRecords = dbxRecords[:len(dbxRecords)-1]
}
page.Records = page.Records[:len(page.Records)-1]
for _, dbxRecord := range dbxRecords {
record, err := fromDBXInvoiceProjectRecord(dbxRecord)
if err != nil {
return stripe.ProjectRecordsPage{}, err
}
page.Records = append(page.Records, *record)
page.Cursor = page.Records[len(page.Records)-1].ID
}
return page, nil