add db implementation for user_credits table (#2169)
* add dbx queries * add migration file * start service * Add TotalReferredCountByUserId and availableCreditsByUserID * implement UserCredits interface and UserCredit struct type * add UserCredits into consoledb * add setupData helper function * add test for update * update lock file * fix lint error * add invalidUserCredits tests * rename method * adds comments * add checks for erros in setupData * change update method to only execute one query per request * rename vairable * should return a signal from Update method if the charge is not fully complete * changes for readability * prevent sql injection * rename * improve readability
This commit is contained in:
parent
c481e071b2
commit
954ca3c6ee
@ -27,6 +27,8 @@ type DB interface {
|
||||
ResetPasswordTokens() ResetPasswordTokens
|
||||
// UsageRollups is a getter for UsageRollups repository
|
||||
UsageRollups() UsageRollups
|
||||
// UserCredits is a getter for UserCredits repository
|
||||
UserCredits() UserCredits
|
||||
// UserPayments is a getter for UserPayments repository
|
||||
UserPayments() UserPayments
|
||||
// ProjectPayments is a getter for ProjectPayments repository
|
||||
|
31
satellite/console/usercredits.go
Normal file
31
satellite/console/usercredits.go
Normal file
@ -0,0 +1,31 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package console
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/skyrings/skyring-common/tools/uuid"
|
||||
)
|
||||
|
||||
// UserCredits holds information to interact with database
|
||||
type UserCredits interface {
|
||||
TotalReferredCount(ctx context.Context, userID uuid.UUID) (int64, error)
|
||||
GetAvailableCredits(ctx context.Context, userID uuid.UUID, expirationEndDate time.Time) ([]UserCredit, error)
|
||||
Create(ctx context.Context, userCredit UserCredit) (*UserCredit, error)
|
||||
UpdateAvailableCredits(ctx context.Context, creditsToCharge int, id uuid.UUID, billingStartDate time.Time) (remainingCharge int, err error)
|
||||
}
|
||||
|
||||
// UserCredit holds information about an user's credit
|
||||
type UserCredit struct {
|
||||
ID int
|
||||
UserID uuid.UUID
|
||||
OfferID int
|
||||
ReferredBy uuid.UUID
|
||||
CreditsEarnedInCents int
|
||||
CreditsUsedInCents int
|
||||
ExpiresAt time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
@ -61,6 +61,11 @@ func (db *ConsoleDB) UsageRollups() console.UsageRollups {
|
||||
return &usagerollups{db.db}
|
||||
}
|
||||
|
||||
// UserCredits is a getter for console.UserCredits repository
|
||||
func (db *ConsoleDB) UserCredits() console.UserCredits {
|
||||
return &usercredits{db: db.db}
|
||||
}
|
||||
|
||||
// UserPayments is a getter for console.UserPayments repository
|
||||
func (db *ConsoleDB) UserPayments() console.UserPayments {
|
||||
return &userpayments{db.methods}
|
||||
|
@ -457,6 +457,43 @@ func (m *lockedUsageRollups) GetProjectTotal(ctx context.Context, projectID uuid
|
||||
return m.db.GetProjectTotal(ctx, projectID, since, before)
|
||||
}
|
||||
|
||||
// UserCredits is a getter for UserCredits repository
|
||||
func (m *lockedConsole) UserCredits() console.UserCredits {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return &lockedUserCredits{m.Locker, m.db.UserCredits()}
|
||||
}
|
||||
|
||||
// lockedUserCredits implements locking wrapper for console.UserCredits
|
||||
type lockedUserCredits struct {
|
||||
sync.Locker
|
||||
db console.UserCredits
|
||||
}
|
||||
|
||||
func (m *lockedUserCredits) Create(ctx context.Context, userCredit console.UserCredit) (*console.UserCredit, error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.db.Create(ctx, userCredit)
|
||||
}
|
||||
|
||||
func (m *lockedUserCredits) GetAvailableCredits(ctx context.Context, userID uuid.UUID, expirationEndDate time.Time) ([]console.UserCredit, error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.db.GetAvailableCredits(ctx, userID, expirationEndDate)
|
||||
}
|
||||
|
||||
func (m *lockedUserCredits) TotalReferredCount(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.db.TotalReferredCount(ctx, userID)
|
||||
}
|
||||
|
||||
func (m *lockedUserCredits) UpdateAvailableCredits(ctx context.Context, creditsToCharge int, id uuid.UUID, billingStartDate time.Time) (remainingCharge int, err error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.db.UpdateAvailableCredits(ctx, creditsToCharge, id, billingStartDate)
|
||||
}
|
||||
|
||||
// UserPayments is a getter for UserPayments repository
|
||||
func (m *lockedConsole) UserPayments() console.UserPayments {
|
||||
m.Lock()
|
||||
|
186
satellite/satellitedb/usercredits.go
Normal file
186
satellite/satellitedb/usercredits.go
Normal file
@ -0,0 +1,186 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package satellitedb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"github.com/skyrings/skyring-common/tools/uuid"
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
"storj.io/storj/satellite/console"
|
||||
dbx "storj.io/storj/satellite/satellitedb/dbx"
|
||||
)
|
||||
|
||||
type usercredits struct {
|
||||
db *dbx.DB
|
||||
}
|
||||
|
||||
// TotalReferredCount returns the total amount of referral a user has made based on user id
|
||||
func (c *usercredits) TotalReferredCount(ctx context.Context, id uuid.UUID) (int64, error) {
|
||||
totalReferred, err := c.db.Count_UserCredit_By_ReferredBy(ctx, dbx.UserCredit_ReferredBy(id[:]))
|
||||
if err != nil {
|
||||
return totalReferred, errs.Wrap(err)
|
||||
}
|
||||
|
||||
return totalReferred, nil
|
||||
}
|
||||
|
||||
// GetAvailableCredits returns all records of user credit that are not expired or used
|
||||
func (c *usercredits) GetAvailableCredits(ctx context.Context, referrerID uuid.UUID, expirationEndDate time.Time) ([]console.UserCredit, error) {
|
||||
availableCredits, err := c.db.All_UserCredit_By_UserId_And_ExpiresAt_Greater_And_CreditsUsedInCents_Less_CreditsEarnedInCents_OrderBy_Asc_ExpiresAt(ctx,
|
||||
dbx.UserCredit_UserId(referrerID[:]),
|
||||
dbx.UserCredit_ExpiresAt(expirationEndDate),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
|
||||
return userCreditsFromDBX(availableCredits)
|
||||
}
|
||||
|
||||
// Create insert a new record of user credit
|
||||
func (c *usercredits) Create(ctx context.Context, userCredit console.UserCredit) (*console.UserCredit, error) {
|
||||
credit, err := c.db.Create_UserCredit(ctx,
|
||||
dbx.UserCredit_UserId(userCredit.UserID[:]),
|
||||
dbx.UserCredit_OfferId(userCredit.OfferID),
|
||||
dbx.UserCredit_CreditsEarnedInCents(userCredit.CreditsEarnedInCents),
|
||||
dbx.UserCredit_ExpiresAt(userCredit.ExpiresAt),
|
||||
dbx.UserCredit_Create_Fields{
|
||||
ReferredBy: dbx.UserCredit_ReferredBy(userCredit.ReferredBy[:]),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
|
||||
return convertDBCredit(credit)
|
||||
}
|
||||
|
||||
// UpdateAvailableCredits updates user's available credits based on their spending and the time of their spending
|
||||
func (c *usercredits) UpdateAvailableCredits(ctx context.Context, creditsToCharge int, id uuid.UUID, expirationEndDate time.Time) (remainingCharge int, err error) {
|
||||
tx, err := c.db.Open(ctx)
|
||||
if err != nil {
|
||||
return creditsToCharge, errs.Wrap(err)
|
||||
}
|
||||
|
||||
availableCredits, err := tx.All_UserCredit_By_UserId_And_ExpiresAt_Greater_And_CreditsUsedInCents_Less_CreditsEarnedInCents_OrderBy_Asc_ExpiresAt(ctx,
|
||||
dbx.UserCredit_UserId(id[:]),
|
||||
dbx.UserCredit_ExpiresAt(expirationEndDate),
|
||||
)
|
||||
if err != nil {
|
||||
return creditsToCharge, errs.Wrap(errs.Combine(err, tx.Rollback()))
|
||||
}
|
||||
if len(availableCredits) == 0 {
|
||||
return creditsToCharge, errs.Combine(errs.New("No available credits"), tx.Commit())
|
||||
}
|
||||
|
||||
values := make([]interface{}, len(availableCredits)*2)
|
||||
rowIds := make([]interface{}, len(availableCredits))
|
||||
|
||||
remainingCharge = creditsToCharge
|
||||
for i, credit := range availableCredits {
|
||||
if remainingCharge == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
creditsForUpdate := credit.CreditsEarnedInCents - credit.CreditsUsedInCents
|
||||
|
||||
if remainingCharge < creditsForUpdate {
|
||||
creditsForUpdate = remainingCharge
|
||||
}
|
||||
|
||||
values[i%2] = credit.Id
|
||||
values[(i%2 + 1)] = creditsForUpdate
|
||||
rowIds[i] = credit.Id
|
||||
|
||||
remainingCharge -= creditsForUpdate
|
||||
}
|
||||
|
||||
values = append(values, rowIds...)
|
||||
|
||||
var statement string
|
||||
switch t := c.db.Driver().(type) {
|
||||
case *sqlite3.SQLiteDriver:
|
||||
statement = generateQuery(len(availableCredits), false)
|
||||
case *pq.Driver:
|
||||
statement = generateQuery(len(availableCredits), true)
|
||||
default:
|
||||
return creditsToCharge, errs.New("Unsupported database %t", t)
|
||||
}
|
||||
|
||||
_, err = tx.Tx.ExecContext(ctx, c.db.Rebind(`UPDATE user_credits SET
|
||||
credits_used_in_cents = CASE `+statement), values...)
|
||||
if err != nil {
|
||||
return creditsToCharge, errs.Wrap(errs.Combine(err, tx.Rollback()))
|
||||
}
|
||||
return remainingCharge, errs.Wrap(tx.Commit())
|
||||
}
|
||||
|
||||
func generateQuery(totalRows int, toInt bool) (query string) {
|
||||
whereClause := `WHERE id IN (`
|
||||
condition := `WHEN id=? THEN ? `
|
||||
if toInt {
|
||||
condition = `WHEN id=? THEN ?::int `
|
||||
}
|
||||
|
||||
for i := 0; i < totalRows; i++ {
|
||||
query += condition
|
||||
|
||||
if i == totalRows-1 {
|
||||
query += ` END ` + whereClause + ` ?);`
|
||||
break
|
||||
}
|
||||
whereClause += `?, `
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func userCreditsFromDBX(userCreditsDBX []*dbx.UserCredit) ([]console.UserCredit, error) {
|
||||
var userCredits []console.UserCredit
|
||||
errList := new(errs.Group)
|
||||
|
||||
for _, credit := range userCreditsDBX {
|
||||
|
||||
uc, err := convertDBCredit(credit)
|
||||
if err != nil {
|
||||
errList.Add(err)
|
||||
continue
|
||||
}
|
||||
userCredits = append(userCredits, *uc)
|
||||
}
|
||||
|
||||
return userCredits, errList.Err()
|
||||
}
|
||||
|
||||
func convertDBCredit(userCreditDBX *dbx.UserCredit) (*console.UserCredit, error) {
|
||||
if userCreditDBX == nil {
|
||||
return nil, errs.New("userCreditDBX parameter is nil")
|
||||
}
|
||||
|
||||
userID, err := bytesToUUID(userCreditDBX.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
referredByID, err := bytesToUUID(userCreditDBX.ReferredBy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &console.UserCredit{
|
||||
ID: userCreditDBX.Id,
|
||||
UserID: userID,
|
||||
OfferID: userCreditDBX.OfferId,
|
||||
ReferredBy: referredByID,
|
||||
CreditsEarnedInCents: userCreditDBX.CreditsEarnedInCents,
|
||||
CreditsUsedInCents: userCreditDBX.CreditsUsedInCents,
|
||||
ExpiresAt: userCreditDBX.ExpiresAt,
|
||||
CreatedAt: userCreditDBX.CreatedAt,
|
||||
}, nil
|
||||
}
|
213
satellite/satellitedb/usercredits_test.go
Normal file
213
satellite/satellitedb/usercredits_test.go
Normal file
@ -0,0 +1,213 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package satellitedb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/skyrings/skyring-common/tools/uuid"
|
||||
|
||||
"storj.io/storj/satellite/marketing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"storj.io/storj/satellite/console"
|
||||
|
||||
"storj.io/storj/internal/testcontext"
|
||||
"storj.io/storj/satellite"
|
||||
"storj.io/storj/satellite/satellitedb/satellitedbtest"
|
||||
)
|
||||
|
||||
func TestUsercredits(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
satellitedbtest.Run(t, func(t *testing.T, db satellite.DB) {
|
||||
ctx := testcontext.New(t)
|
||||
defer ctx.Cleanup()
|
||||
|
||||
consoleDB := db.Console()
|
||||
|
||||
user, referrer, offer := setupData(ctx, t, db)
|
||||
randomID, err := uuid.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
// test foreign key constraint for inserting a new user credit entry with randomID
|
||||
var invalidUserCredits = []console.UserCredit{
|
||||
{
|
||||
UserID: *randomID,
|
||||
OfferID: offer.ID,
|
||||
ReferredBy: referrer.ID,
|
||||
CreditsEarnedInCents: 100,
|
||||
ExpiresAt: time.Now().UTC().AddDate(0, 1, 0),
|
||||
},
|
||||
{
|
||||
UserID: user.ID,
|
||||
OfferID: 10,
|
||||
ReferredBy: referrer.ID,
|
||||
CreditsEarnedInCents: 100,
|
||||
ExpiresAt: time.Now().UTC().AddDate(0, 1, 0),
|
||||
},
|
||||
{
|
||||
UserID: user.ID,
|
||||
OfferID: offer.ID,
|
||||
ReferredBy: *randomID,
|
||||
CreditsEarnedInCents: 100,
|
||||
ExpiresAt: time.Now().UTC().AddDate(0, 1, 0),
|
||||
},
|
||||
}
|
||||
|
||||
for _, ivc := range invalidUserCredits {
|
||||
_, err := consoleDB.UserCredits().Create(ctx, ivc)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
type result struct {
|
||||
remainingCharge int
|
||||
availableCredits int
|
||||
hasErr bool
|
||||
}
|
||||
|
||||
var validUserCredits = []struct {
|
||||
userCredit console.UserCredit
|
||||
chargedCredits int
|
||||
expected result
|
||||
}{
|
||||
{
|
||||
userCredit: console.UserCredit{
|
||||
UserID: user.ID,
|
||||
OfferID: offer.ID,
|
||||
ReferredBy: referrer.ID,
|
||||
CreditsEarnedInCents: 100,
|
||||
ExpiresAt: time.Now().UTC().AddDate(0, 1, 0),
|
||||
},
|
||||
chargedCredits: 120,
|
||||
expected: result{
|
||||
remainingCharge: 20,
|
||||
availableCredits: 0,
|
||||
hasErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
// simulate a credit that's already expired
|
||||
userCredit: console.UserCredit{
|
||||
UserID: user.ID,
|
||||
OfferID: offer.ID,
|
||||
ReferredBy: referrer.ID,
|
||||
CreditsEarnedInCents: 100,
|
||||
ExpiresAt: time.Now().UTC().AddDate(0, 0, -5),
|
||||
},
|
||||
chargedCredits: 60,
|
||||
expected: result{
|
||||
remainingCharge: 60,
|
||||
availableCredits: 0,
|
||||
hasErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
// simulate a credit that's not expired
|
||||
userCredit: console.UserCredit{
|
||||
UserID: user.ID,
|
||||
OfferID: offer.ID,
|
||||
ReferredBy: referrer.ID,
|
||||
CreditsEarnedInCents: 100,
|
||||
ExpiresAt: time.Now().UTC().AddDate(0, 0, 5),
|
||||
},
|
||||
chargedCredits: 80,
|
||||
expected: result{
|
||||
remainingCharge: 0,
|
||||
availableCredits: 20,
|
||||
hasErr: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, vc := range validUserCredits {
|
||||
_, err = consoleDB.UserCredits().Create(ctx, vc.userCredit)
|
||||
require.NoError(t, err)
|
||||
|
||||
{
|
||||
referredCount, err := consoleDB.UserCredits().TotalReferredCount(ctx, vc.userCredit.ReferredBy)
|
||||
if err != nil {
|
||||
require.True(t, uuid.Equal(*randomID, vc.userCredit.ReferredBy))
|
||||
continue
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(i+1), referredCount)
|
||||
}
|
||||
|
||||
{
|
||||
remainingCharge, err := consoleDB.UserCredits().UpdateAvailableCredits(ctx, vc.chargedCredits, vc.userCredit.UserID, time.Now().UTC())
|
||||
if vc.expected.hasErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, vc.expected.remainingCharge, remainingCharge)
|
||||
}
|
||||
|
||||
{
|
||||
availableCredits, err := consoleDB.UserCredits().GetAvailableCredits(ctx, vc.userCredit.UserID, time.Now().UTC())
|
||||
require.NoError(t, err)
|
||||
var sum int
|
||||
for i := range availableCredits {
|
||||
sum += availableCredits[i].CreditsEarnedInCents - availableCredits[i].CreditsUsedInCents
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, vc.expected.availableCredits, sum)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func setupData(ctx context.Context, t *testing.T, db satellite.DB) (user *console.User, referrer *console.User, offer *marketing.Offer) {
|
||||
consoleDB := db.Console()
|
||||
marketingDB := db.Marketing()
|
||||
// create user
|
||||
var userPassHash [8]byte
|
||||
_, err := rand.Read(userPassHash[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var referrerPassHash [8]byte
|
||||
_, err = rand.Read(referrerPassHash[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// create an user
|
||||
user, err = consoleDB.Users().Insert(ctx, &console.User{
|
||||
FullName: "John Doe",
|
||||
Email: "john@mail.test",
|
||||
PasswordHash: userPassHash[:],
|
||||
Status: console.Active,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
//create an user as referrer
|
||||
referrer, err = consoleDB.Users().Insert(ctx, &console.User{
|
||||
FullName: "referrer",
|
||||
Email: "referrer@mail.test",
|
||||
PasswordHash: referrerPassHash[:],
|
||||
Status: console.Active,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// create offer
|
||||
offer, err = marketingDB.Offers().Create(ctx, &marketing.NewOffer{
|
||||
Name: "test",
|
||||
Description: "test offer 1",
|
||||
AwardCreditInCents: 100,
|
||||
InviteeCreditInCents: 50,
|
||||
AwardCreditDurationDays: 60,
|
||||
InviteeCreditDurationDays: 30,
|
||||
RedeemableCap: 50,
|
||||
ExpiresAt: time.Now().UTC().Add(time.Hour * 1),
|
||||
Status: marketing.Active,
|
||||
Type: marketing.Referral,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return user, referrer, offer
|
||||
}
|
Loading…
Reference in New Issue
Block a user