From 903ea38c86e9c2180bd8277ad213cb23d7ce2ca0 Mon Sep 17 00:00:00 2001 From: Wilfred Asomani Date: Fri, 23 Sep 2022 22:23:32 +0000 Subject: [PATCH] web/satellite: increment login failed_login_count in sql This change increments users' failed_login_count in the database layer to avoid potential data race. It also updates the login_lockout_expiration as well in one operation. see: https://github.com/storj/storj/issues/4986 Change-Id: I74624f1bee31667b269cb205d74d16e79daabcb6 --- satellite/console/service.go | 11 +++-------- satellite/console/users.go | 2 ++ satellite/satellitedb/users.go | 20 ++++++++++++++++++++ 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/satellite/console/service.go b/satellite/console/service.go index cbf8f4ef6..ab3aedb0e 100644 --- a/satellite/console/service.go +++ b/satellite/console/service.go @@ -1161,13 +1161,10 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (response *TokenI func (s *Service) UpdateUsersFailedLoginState(ctx context.Context, user *User) (err error) { defer mon.Task()(&ctx)(&err) - updateRequest := UpdateUserRequest{} + var failedLoginPenalty *float64 if user.FailedLoginCount >= s.config.LoginAttemptsWithoutPenalty-1 { lockoutDuration := time.Duration(math.Pow(s.config.FailedLoginPenalty, float64(user.FailedLoginCount-1))) * time.Minute - lockoutExpTime := time.Now().Add(lockoutDuration) - lockoutExpTimePtr := &lockoutExpTime - - updateRequest.LoginLockoutExpiration = &lockoutExpTimePtr + failedLoginPenalty = &s.config.FailedLoginPenalty address := s.satelliteAddress if !strings.HasSuffix(address, "/") { @@ -1184,10 +1181,8 @@ func (s *Service) UpdateUsersFailedLoginState(ctx context.Context, user *User) ( }, ) } - user.FailedLoginCount++ - updateRequest.FailedLoginCount = &user.FailedLoginCount - return s.store.Users().Update(ctx, user.ID, updateRequest) + return s.store.Users().UpdateFailedLoginCountAndExpiration(ctx, failedLoginPenalty, user.ID) } // GetUser returns User by id. diff --git a/satellite/console/users.go b/satellite/console/users.go index 8914a6f59..7af398837 100644 --- a/satellite/console/users.go +++ b/satellite/console/users.go @@ -25,6 +25,8 @@ type Users interface { GetUnverifiedNeedingReminder(ctx context.Context, firstReminder, secondReminder, cutoff time.Time) ([]*User, error) // UpdateVerificationReminders increments verification_reminders. UpdateVerificationReminders(ctx context.Context, id uuid.UUID) error + // UpdateFailedLoginCountAndExpiration increments failed_login_count and sets login_lockout_expiration appropriately. + UpdateFailedLoginCountAndExpiration(ctx context.Context, failedLoginPenalty *float64, id uuid.UUID) error // GetByEmailWithUnverified is a method for querying users by email from the database. GetByEmailWithUnverified(ctx context.Context, email string) (*User, []User, error) // GetByEmail is a method for querying user by verified email from the database. diff --git a/satellite/satellitedb/users.go b/satellite/satellitedb/users.go index af2fec588..738ef392c 100644 --- a/satellite/satellitedb/users.go +++ b/satellite/satellitedb/users.go @@ -25,6 +25,26 @@ type users struct { db *satelliteDB } +// UpdateFailedLoginCountAndExpiration increments failed_login_count and sets login_lockout_expiration appropriately. +func (users *users) UpdateFailedLoginCountAndExpiration(ctx context.Context, failedLoginPenalty *float64, id uuid.UUID) (err error) { + if failedLoginPenalty != nil { + // failed_login_count exceeded config.FailedLoginPenalty + _, err = users.db.ExecContext(ctx, users.db.Rebind(` + UPDATE users + SET failed_login_count = COALESCE(failed_login_count, 0) + 1, + login_lockout_expiration = CURRENT_TIMESTAMP + POWER(?, failed_login_count-1) * INTERVAL '1 minute' + WHERE id = ? + `), failedLoginPenalty, id.Bytes()) + } else { + _, err = users.db.ExecContext(ctx, users.db.Rebind(` + UPDATE users + SET failed_login_count = COALESCE(failed_login_count, 0) + 1 + WHERE id = ? + `), id.Bytes()) + } + return +} + // Get is a method for querying user from the database by id. func (users *users) Get(ctx context.Context, id uuid.UUID) (_ *console.User, err error) { defer mon.Task()(&ctx)(&err)