satellite/satellitedb: Add GetByEmailWithUnverified to users table

Allows us to handle duplicate emails better.

Change-Id: I266057900725e50d1c47977da307714fd32d9081
This commit is contained in:
Jeremy Wharton 2021-11-18 13:54:39 -05:00 committed by Maximillian von Briesen
parent 9c1129b4c4
commit 984792fd1e
5 changed files with 185 additions and 2 deletions

View File

@ -18,7 +18,9 @@ import (
type Users interface {
// Get is a method for querying user from the database by id.
Get(ctx context.Context, id uuid.UUID) (*User, error)
// GetByEmail is a method for querying user by email from the database.
// GetByEmailWithUnverified is a method for querying users by email from the database.
GetByEmailWithUnverified(ctx context.Context, email string) (*User, []User, error)
// GetByEmail is a method for querying user by verified email from the database.
GetByEmail(ctx context.Context, email string) (*User, error)
// Insert is a method for inserting user into the database.
Insert(ctx context.Context, user *User) (*User, error)

View File

@ -5,6 +5,7 @@ package console_test
import (
"context"
"database/sql"
"testing"
"time"
@ -283,3 +284,47 @@ func testUsers(ctx context.Context, t *testing.T, repository console.Users, user
assert.Error(t, err)
})
}
func TestGetUserByEmail(t *testing.T) {
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
usersRepo := db.Console().Users()
email := "test@mail.test"
inactiveUser := console.User{
ID: testrand.UUID(),
FullName: "Inactive User",
Email: email,
PasswordHash: []byte("123a123"),
}
_, err := usersRepo.Insert(ctx, &inactiveUser)
require.NoError(t, err)
_, err = usersRepo.GetByEmail(ctx, email)
require.ErrorIs(t, sql.ErrNoRows, err)
verified, unverified, err := usersRepo.GetByEmailWithUnverified(ctx, email)
require.NoError(t, err)
require.Nil(t, verified)
require.Equal(t, inactiveUser.ID, unverified[0].ID)
activeUser := console.User{
ID: testrand.UUID(),
FullName: "Active User",
Email: email,
Status: console.Active,
PasswordHash: []byte("123a123"),
}
_, err = usersRepo.Insert(ctx, &activeUser)
require.NoError(t, err)
// Required to set the active status.
err = usersRepo.Update(ctx, &activeUser)
require.NoError(t, err)
dbUser, err := usersRepo.GetByEmail(ctx, email)
require.NoError(t, err)
require.Equal(t, activeUser.ID, dbUser.ID)
})
}

View File

@ -312,6 +312,10 @@ create user ( )
update user ( where user.id = ? )
delete user ( where user.id = ? )
read all (
select user
where user.normalized_email = ?
)
read one (
select user
where user.normalized_email = ?

View File

@ -11403,6 +11403,51 @@ func (obj *pgxImpl) Get_Reputation_By_Id(ctx context.Context,
}
func (obj *pgxImpl) All_User_By_NormalizedEmail(ctx context.Context,
user_normalized_email User_NormalizedEmail_Field) (
rows []*User, err error) {
defer mon.Task()(&ctx)(&err)
var __embed_stmt = __sqlbundle_Literal("SELECT users.id, users.email, users.normalized_email, users.full_name, users.short_name, users.password_hash, users.status, users.partner_id, users.user_agent, users.created_at, users.project_limit, users.project_bandwidth_limit, users.project_storage_limit, users.paid_tier, users.position, users.company_name, users.company_size, users.working_on, users.is_professional, users.employee_count, users.have_sales_contact, users.mfa_enabled, users.mfa_secret_key, users.mfa_recovery_codes, users.signup_promo_code FROM users WHERE users.normalized_email = ?")
var __values []interface{}
__values = append(__values, user_normalized_email.value())
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
for {
rows, err = func() (rows []*User, err error) {
__rows, err := obj.driver.QueryContext(ctx, __stmt, __values...)
if err != nil {
return nil, err
}
defer __rows.Close()
for __rows.Next() {
user := &User{}
err = __rows.Scan(&user.Id, &user.Email, &user.NormalizedEmail, &user.FullName, &user.ShortName, &user.PasswordHash, &user.Status, &user.PartnerId, &user.UserAgent, &user.CreatedAt, &user.ProjectLimit, &user.ProjectBandwidthLimit, &user.ProjectStorageLimit, &user.PaidTier, &user.Position, &user.CompanyName, &user.CompanySize, &user.WorkingOn, &user.IsProfessional, &user.EmployeeCount, &user.HaveSalesContact, &user.MfaEnabled, &user.MfaSecretKey, &user.MfaRecoveryCodes, &user.SignupPromoCode)
if err != nil {
return nil, err
}
rows = append(rows, user)
}
if err := __rows.Err(); err != nil {
return nil, err
}
return rows, nil
}()
if err != nil {
if obj.shouldRetry(err) {
continue
}
return nil, obj.makeErr(err)
}
return rows, nil
}
}
func (obj *pgxImpl) Get_User_By_NormalizedEmail_And_Status_Not_Number(ctx context.Context,
user_normalized_email User_NormalizedEmail_Field) (
user *User, err error) {
@ -17318,6 +17363,51 @@ func (obj *pgxcockroachImpl) Get_Reputation_By_Id(ctx context.Context,
}
func (obj *pgxcockroachImpl) All_User_By_NormalizedEmail(ctx context.Context,
user_normalized_email User_NormalizedEmail_Field) (
rows []*User, err error) {
defer mon.Task()(&ctx)(&err)
var __embed_stmt = __sqlbundle_Literal("SELECT users.id, users.email, users.normalized_email, users.full_name, users.short_name, users.password_hash, users.status, users.partner_id, users.user_agent, users.created_at, users.project_limit, users.project_bandwidth_limit, users.project_storage_limit, users.paid_tier, users.position, users.company_name, users.company_size, users.working_on, users.is_professional, users.employee_count, users.have_sales_contact, users.mfa_enabled, users.mfa_secret_key, users.mfa_recovery_codes, users.signup_promo_code FROM users WHERE users.normalized_email = ?")
var __values []interface{}
__values = append(__values, user_normalized_email.value())
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
for {
rows, err = func() (rows []*User, err error) {
__rows, err := obj.driver.QueryContext(ctx, __stmt, __values...)
if err != nil {
return nil, err
}
defer __rows.Close()
for __rows.Next() {
user := &User{}
err = __rows.Scan(&user.Id, &user.Email, &user.NormalizedEmail, &user.FullName, &user.ShortName, &user.PasswordHash, &user.Status, &user.PartnerId, &user.UserAgent, &user.CreatedAt, &user.ProjectLimit, &user.ProjectBandwidthLimit, &user.ProjectStorageLimit, &user.PaidTier, &user.Position, &user.CompanyName, &user.CompanySize, &user.WorkingOn, &user.IsProfessional, &user.EmployeeCount, &user.HaveSalesContact, &user.MfaEnabled, &user.MfaSecretKey, &user.MfaRecoveryCodes, &user.SignupPromoCode)
if err != nil {
return nil, err
}
rows = append(rows, user)
}
if err := __rows.Err(); err != nil {
return nil, err
}
return rows, nil
}()
if err != nil {
if obj.shouldRetry(err) {
continue
}
return nil, obj.makeErr(err)
}
return rows, nil
}
}
func (obj *pgxcockroachImpl) Get_User_By_NormalizedEmail_And_Status_Not_Number(ctx context.Context,
user_normalized_email User_NormalizedEmail_Field) (
user *User, err error) {
@ -22267,6 +22357,16 @@ func (rx *Rx) All_StoragenodeStorageTally_By_IntervalEndTime_GreaterOrEqual(ctx
return tx.All_StoragenodeStorageTally_By_IntervalEndTime_GreaterOrEqual(ctx, storagenode_storage_tally_interval_end_time_greater_or_equal)
}
func (rx *Rx) All_User_By_NormalizedEmail(ctx context.Context,
user_normalized_email User_NormalizedEmail_Field) (
rows []*User, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.All_User_By_NormalizedEmail(ctx, user_normalized_email)
}
func (rx *Rx) Count_BucketMetainfo_Name_By_ProjectId(ctx context.Context,
bucket_metainfo_project_id BucketMetainfo_ProjectId_Field) (
count int64, err error) {
@ -23618,6 +23718,10 @@ type Methods interface {
storagenode_storage_tally_interval_end_time_greater_or_equal StoragenodeStorageTally_IntervalEndTime_Field) (
rows []*StoragenodeStorageTally, err error)
All_User_By_NormalizedEmail(ctx context.Context,
user_normalized_email User_NormalizedEmail_Field) (
rows []*User, err error)
Count_BucketMetainfo_Name_By_ProjectId(ctx context.Context,
bucket_metainfo_project_id BucketMetainfo_ProjectId_Field) (
count int64, err error)

View File

@ -28,6 +28,7 @@ type users struct {
func (users *users) Get(ctx context.Context, id uuid.UUID) (_ *console.User, err error) {
defer mon.Task()(&ctx)(&err)
user, err := users.db.Get_User_By_Id(ctx, dbx.User_Id(id[:]))
if err != nil {
return nil, err
}
@ -35,7 +36,34 @@ func (users *users) Get(ctx context.Context, id uuid.UUID) (_ *console.User, err
return userFromDBX(ctx, user)
}
// GetByEmail is a method for querying user by email from the database.
// GetByEmailWithUnverified is a method for querying users by email from the database.
func (users *users) GetByEmailWithUnverified(ctx context.Context, email string) (verified *console.User, unverified []console.User, err error) {
defer mon.Task()(&ctx)(&err)
usersDbx, err := users.db.All_User_By_NormalizedEmail(ctx, dbx.User_NormalizedEmail(normalizeEmail(email)))
if err != nil {
return nil, nil, err
}
var errors errs.Group
for _, userDbx := range usersDbx {
u, err := userFromDBX(ctx, userDbx)
if err != nil {
errors.Add(err)
continue
}
if u.Status == console.Active {
verified = u
} else {
unverified = append(unverified, *u)
}
}
return verified, unverified, errors.Err()
}
// GetByEmail is a method for querying user by verified email from the database.
func (users *users) GetByEmail(ctx context.Context, email string) (_ *console.User, err error) {
defer mon.Task()(&ctx)(&err)
user, err := users.db.Get_User_By_NormalizedEmail_And_Status_Not_Number(ctx, dbx.User_NormalizedEmail(normalizeEmail(email)))