From 984792fd1e85cf84b6f7509afac80f7a6d85e488 Mon Sep 17 00:00:00 2001 From: Jeremy Wharton Date: Thu, 18 Nov 2021 13:54:39 -0500 Subject: [PATCH] satellite/satellitedb: Add GetByEmailWithUnverified to users table Allows us to handle duplicate emails better. Change-Id: I266057900725e50d1c47977da307714fd32d9081 --- satellite/console/users.go | 4 +- satellite/console/users_test.go | 45 ++++++++ satellite/satellitedb/dbx/satellitedb.dbx | 4 + satellite/satellitedb/dbx/satellitedb.dbx.go | 104 +++++++++++++++++++ satellite/satellitedb/users.go | 30 +++++- 5 files changed, 185 insertions(+), 2 deletions(-) diff --git a/satellite/console/users.go b/satellite/console/users.go index a5f89e394..15b067f18 100644 --- a/satellite/console/users.go +++ b/satellite/console/users.go @@ -18,7 +18,9 @@ import ( type Users interface { // Get is a method for querying user from the database by id. Get(ctx context.Context, id uuid.UUID) (*User, error) - // GetByEmail is a method for querying user by email from the database. + // GetByEmailWithUnverified is a method for querying users by email from the database. + GetByEmailWithUnverified(ctx context.Context, email string) (*User, []User, error) + // GetByEmail is a method for querying user by verified email from the database. GetByEmail(ctx context.Context, email string) (*User, error) // Insert is a method for inserting user into the database. Insert(ctx context.Context, user *User) (*User, error) diff --git a/satellite/console/users_test.go b/satellite/console/users_test.go index 94771acb2..196a4190c 100644 --- a/satellite/console/users_test.go +++ b/satellite/console/users_test.go @@ -5,6 +5,7 @@ package console_test import ( "context" + "database/sql" "testing" "time" @@ -283,3 +284,47 @@ func testUsers(ctx context.Context, t *testing.T, repository console.Users, user assert.Error(t, err) }) } + +func TestGetUserByEmail(t *testing.T) { + satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { + usersRepo := db.Console().Users() + email := "test@mail.test" + + inactiveUser := console.User{ + ID: testrand.UUID(), + FullName: "Inactive User", + Email: email, + PasswordHash: []byte("123a123"), + } + + _, err := usersRepo.Insert(ctx, &inactiveUser) + require.NoError(t, err) + + _, err = usersRepo.GetByEmail(ctx, email) + require.ErrorIs(t, sql.ErrNoRows, err) + + verified, unverified, err := usersRepo.GetByEmailWithUnverified(ctx, email) + require.NoError(t, err) + require.Nil(t, verified) + require.Equal(t, inactiveUser.ID, unverified[0].ID) + + activeUser := console.User{ + ID: testrand.UUID(), + FullName: "Active User", + Email: email, + Status: console.Active, + PasswordHash: []byte("123a123"), + } + + _, err = usersRepo.Insert(ctx, &activeUser) + require.NoError(t, err) + + // Required to set the active status. + err = usersRepo.Update(ctx, &activeUser) + require.NoError(t, err) + + dbUser, err := usersRepo.GetByEmail(ctx, email) + require.NoError(t, err) + require.Equal(t, activeUser.ID, dbUser.ID) + }) +} diff --git a/satellite/satellitedb/dbx/satellitedb.dbx b/satellite/satellitedb/dbx/satellitedb.dbx index 00f248186..e1727a1b6 100644 --- a/satellite/satellitedb/dbx/satellitedb.dbx +++ b/satellite/satellitedb/dbx/satellitedb.dbx @@ -312,6 +312,10 @@ create user ( ) update user ( where user.id = ? ) delete user ( where user.id = ? ) +read all ( + select user + where user.normalized_email = ? +) read one ( select user where user.normalized_email = ? diff --git a/satellite/satellitedb/dbx/satellitedb.dbx.go b/satellite/satellitedb/dbx/satellitedb.dbx.go index 9d6d6aa00..2c791e2a3 100644 --- a/satellite/satellitedb/dbx/satellitedb.dbx.go +++ b/satellite/satellitedb/dbx/satellitedb.dbx.go @@ -11403,6 +11403,51 @@ func (obj *pgxImpl) Get_Reputation_By_Id(ctx context.Context, } +func (obj *pgxImpl) All_User_By_NormalizedEmail(ctx context.Context, + user_normalized_email User_NormalizedEmail_Field) ( + rows []*User, err error) { + defer mon.Task()(&ctx)(&err) + + var __embed_stmt = __sqlbundle_Literal("SELECT users.id, users.email, users.normalized_email, users.full_name, users.short_name, users.password_hash, users.status, users.partner_id, users.user_agent, users.created_at, users.project_limit, users.project_bandwidth_limit, users.project_storage_limit, users.paid_tier, users.position, users.company_name, users.company_size, users.working_on, users.is_professional, users.employee_count, users.have_sales_contact, users.mfa_enabled, users.mfa_secret_key, users.mfa_recovery_codes, users.signup_promo_code FROM users WHERE users.normalized_email = ?") + + var __values []interface{} + __values = append(__values, user_normalized_email.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + for { + rows, err = func() (rows []*User, err error) { + __rows, err := obj.driver.QueryContext(ctx, __stmt, __values...) + if err != nil { + return nil, err + } + defer __rows.Close() + + for __rows.Next() { + user := &User{} + err = __rows.Scan(&user.Id, &user.Email, &user.NormalizedEmail, &user.FullName, &user.ShortName, &user.PasswordHash, &user.Status, &user.PartnerId, &user.UserAgent, &user.CreatedAt, &user.ProjectLimit, &user.ProjectBandwidthLimit, &user.ProjectStorageLimit, &user.PaidTier, &user.Position, &user.CompanyName, &user.CompanySize, &user.WorkingOn, &user.IsProfessional, &user.EmployeeCount, &user.HaveSalesContact, &user.MfaEnabled, &user.MfaSecretKey, &user.MfaRecoveryCodes, &user.SignupPromoCode) + if err != nil { + return nil, err + } + rows = append(rows, user) + } + if err := __rows.Err(); err != nil { + return nil, err + } + return rows, nil + }() + if err != nil { + if obj.shouldRetry(err) { + continue + } + return nil, obj.makeErr(err) + } + return rows, nil + } + +} + func (obj *pgxImpl) Get_User_By_NormalizedEmail_And_Status_Not_Number(ctx context.Context, user_normalized_email User_NormalizedEmail_Field) ( user *User, err error) { @@ -17318,6 +17363,51 @@ func (obj *pgxcockroachImpl) Get_Reputation_By_Id(ctx context.Context, } +func (obj *pgxcockroachImpl) All_User_By_NormalizedEmail(ctx context.Context, + user_normalized_email User_NormalizedEmail_Field) ( + rows []*User, err error) { + defer mon.Task()(&ctx)(&err) + + var __embed_stmt = __sqlbundle_Literal("SELECT users.id, users.email, users.normalized_email, users.full_name, users.short_name, users.password_hash, users.status, users.partner_id, users.user_agent, users.created_at, users.project_limit, users.project_bandwidth_limit, users.project_storage_limit, users.paid_tier, users.position, users.company_name, users.company_size, users.working_on, users.is_professional, users.employee_count, users.have_sales_contact, users.mfa_enabled, users.mfa_secret_key, users.mfa_recovery_codes, users.signup_promo_code FROM users WHERE users.normalized_email = ?") + + var __values []interface{} + __values = append(__values, user_normalized_email.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + for { + rows, err = func() (rows []*User, err error) { + __rows, err := obj.driver.QueryContext(ctx, __stmt, __values...) + if err != nil { + return nil, err + } + defer __rows.Close() + + for __rows.Next() { + user := &User{} + err = __rows.Scan(&user.Id, &user.Email, &user.NormalizedEmail, &user.FullName, &user.ShortName, &user.PasswordHash, &user.Status, &user.PartnerId, &user.UserAgent, &user.CreatedAt, &user.ProjectLimit, &user.ProjectBandwidthLimit, &user.ProjectStorageLimit, &user.PaidTier, &user.Position, &user.CompanyName, &user.CompanySize, &user.WorkingOn, &user.IsProfessional, &user.EmployeeCount, &user.HaveSalesContact, &user.MfaEnabled, &user.MfaSecretKey, &user.MfaRecoveryCodes, &user.SignupPromoCode) + if err != nil { + return nil, err + } + rows = append(rows, user) + } + if err := __rows.Err(); err != nil { + return nil, err + } + return rows, nil + }() + if err != nil { + if obj.shouldRetry(err) { + continue + } + return nil, obj.makeErr(err) + } + return rows, nil + } + +} + func (obj *pgxcockroachImpl) Get_User_By_NormalizedEmail_And_Status_Not_Number(ctx context.Context, user_normalized_email User_NormalizedEmail_Field) ( user *User, err error) { @@ -22267,6 +22357,16 @@ func (rx *Rx) All_StoragenodeStorageTally_By_IntervalEndTime_GreaterOrEqual(ctx return tx.All_StoragenodeStorageTally_By_IntervalEndTime_GreaterOrEqual(ctx, storagenode_storage_tally_interval_end_time_greater_or_equal) } +func (rx *Rx) All_User_By_NormalizedEmail(ctx context.Context, + user_normalized_email User_NormalizedEmail_Field) ( + rows []*User, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.All_User_By_NormalizedEmail(ctx, user_normalized_email) +} + func (rx *Rx) Count_BucketMetainfo_Name_By_ProjectId(ctx context.Context, bucket_metainfo_project_id BucketMetainfo_ProjectId_Field) ( count int64, err error) { @@ -23618,6 +23718,10 @@ type Methods interface { storagenode_storage_tally_interval_end_time_greater_or_equal StoragenodeStorageTally_IntervalEndTime_Field) ( rows []*StoragenodeStorageTally, err error) + All_User_By_NormalizedEmail(ctx context.Context, + user_normalized_email User_NormalizedEmail_Field) ( + rows []*User, err error) + Count_BucketMetainfo_Name_By_ProjectId(ctx context.Context, bucket_metainfo_project_id BucketMetainfo_ProjectId_Field) ( count int64, err error) diff --git a/satellite/satellitedb/users.go b/satellite/satellitedb/users.go index bc2538f9c..7a4262ecb 100644 --- a/satellite/satellitedb/users.go +++ b/satellite/satellitedb/users.go @@ -28,6 +28,7 @@ type users struct { func (users *users) Get(ctx context.Context, id uuid.UUID) (_ *console.User, err error) { defer mon.Task()(&ctx)(&err) user, err := users.db.Get_User_By_Id(ctx, dbx.User_Id(id[:])) + if err != nil { return nil, err } @@ -35,7 +36,34 @@ func (users *users) Get(ctx context.Context, id uuid.UUID) (_ *console.User, err return userFromDBX(ctx, user) } -// GetByEmail is a method for querying user by email from the database. +// GetByEmailWithUnverified is a method for querying users by email from the database. +func (users *users) GetByEmailWithUnverified(ctx context.Context, email string) (verified *console.User, unverified []console.User, err error) { + defer mon.Task()(&ctx)(&err) + usersDbx, err := users.db.All_User_By_NormalizedEmail(ctx, dbx.User_NormalizedEmail(normalizeEmail(email))) + + if err != nil { + return nil, nil, err + } + + var errors errs.Group + for _, userDbx := range usersDbx { + u, err := userFromDBX(ctx, userDbx) + if err != nil { + errors.Add(err) + continue + } + + if u.Status == console.Active { + verified = u + } else { + unverified = append(unverified, *u) + } + } + + return verified, unverified, errors.Err() +} + +// GetByEmail is a method for querying user by verified email from the database. func (users *users) GetByEmail(ctx context.Context, email string) (_ *console.User, err error) { defer mon.Task()(&ctx)(&err) user, err := users.db.Get_User_By_NormalizedEmail_And_Status_Not_Number(ctx, dbx.User_NormalizedEmail(normalizeEmail(email)))