web/satellite: increment login failed_login_count in sql

This change increments users' failed_login_count in the database layer to avoid potential data race.
It also updates the login_lockout_expiration as well in one operation.

see: https://github.com/storj/storj/issues/4986

Change-Id: I74624f1bee31667b269cb205d74d16e79daabcb6
This commit is contained in:
Wilfred Asomani 2022-09-23 22:23:32 +00:00 committed by Storj Robot
parent d632f23950
commit 903ea38c86
3 changed files with 25 additions and 8 deletions

View File

@ -1161,13 +1161,10 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (response *TokenI
func (s *Service) UpdateUsersFailedLoginState(ctx context.Context, user *User) (err error) { func (s *Service) UpdateUsersFailedLoginState(ctx context.Context, user *User) (err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
updateRequest := UpdateUserRequest{} var failedLoginPenalty *float64
if user.FailedLoginCount >= s.config.LoginAttemptsWithoutPenalty-1 { if user.FailedLoginCount >= s.config.LoginAttemptsWithoutPenalty-1 {
lockoutDuration := time.Duration(math.Pow(s.config.FailedLoginPenalty, float64(user.FailedLoginCount-1))) * time.Minute lockoutDuration := time.Duration(math.Pow(s.config.FailedLoginPenalty, float64(user.FailedLoginCount-1))) * time.Minute
lockoutExpTime := time.Now().Add(lockoutDuration) failedLoginPenalty = &s.config.FailedLoginPenalty
lockoutExpTimePtr := &lockoutExpTime
updateRequest.LoginLockoutExpiration = &lockoutExpTimePtr
address := s.satelliteAddress address := s.satelliteAddress
if !strings.HasSuffix(address, "/") { if !strings.HasSuffix(address, "/") {
@ -1184,10 +1181,8 @@ func (s *Service) UpdateUsersFailedLoginState(ctx context.Context, user *User) (
}, },
) )
} }
user.FailedLoginCount++
updateRequest.FailedLoginCount = &user.FailedLoginCount return s.store.Users().UpdateFailedLoginCountAndExpiration(ctx, failedLoginPenalty, user.ID)
return s.store.Users().Update(ctx, user.ID, updateRequest)
} }
// GetUser returns User by id. // GetUser returns User by id.

View File

@ -25,6 +25,8 @@ type Users interface {
GetUnverifiedNeedingReminder(ctx context.Context, firstReminder, secondReminder, cutoff time.Time) ([]*User, error) GetUnverifiedNeedingReminder(ctx context.Context, firstReminder, secondReminder, cutoff time.Time) ([]*User, error)
// UpdateVerificationReminders increments verification_reminders. // UpdateVerificationReminders increments verification_reminders.
UpdateVerificationReminders(ctx context.Context, id uuid.UUID) error UpdateVerificationReminders(ctx context.Context, id uuid.UUID) error
// UpdateFailedLoginCountAndExpiration increments failed_login_count and sets login_lockout_expiration appropriately.
UpdateFailedLoginCountAndExpiration(ctx context.Context, failedLoginPenalty *float64, id uuid.UUID) error
// GetByEmailWithUnverified is a method for querying users by email from the database. // GetByEmailWithUnverified is a method for querying users by email from the database.
GetByEmailWithUnverified(ctx context.Context, email string) (*User, []User, error) GetByEmailWithUnverified(ctx context.Context, email string) (*User, []User, error)
// GetByEmail is a method for querying user by verified email from the database. // GetByEmail is a method for querying user by verified email from the database.

View File

@ -25,6 +25,26 @@ type users struct {
db *satelliteDB db *satelliteDB
} }
// UpdateFailedLoginCountAndExpiration increments failed_login_count and sets login_lockout_expiration appropriately.
func (users *users) UpdateFailedLoginCountAndExpiration(ctx context.Context, failedLoginPenalty *float64, id uuid.UUID) (err error) {
if failedLoginPenalty != nil {
// failed_login_count exceeded config.FailedLoginPenalty
_, err = users.db.ExecContext(ctx, users.db.Rebind(`
UPDATE users
SET failed_login_count = COALESCE(failed_login_count, 0) + 1,
login_lockout_expiration = CURRENT_TIMESTAMP + POWER(?, failed_login_count-1) * INTERVAL '1 minute'
WHERE id = ?
`), failedLoginPenalty, id.Bytes())
} else {
_, err = users.db.ExecContext(ctx, users.db.Rebind(`
UPDATE users
SET failed_login_count = COALESCE(failed_login_count, 0) + 1
WHERE id = ?
`), id.Bytes())
}
return
}
// Get is a method for querying user from the database by id. // Get is a method for querying user from the database by id.
func (users *users) Get(ctx context.Context, id uuid.UUID) (_ *console.User, err error) { func (users *users) Get(ctx context.Context, id uuid.UUID) (_ *console.User, err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)