satellite/console: use new type UpdateUserRequest as arg to db users.Update
The users.Update method in the satellitedb package takes a console.User as an argument. It reads some of the fields on this struct and assigns the value to dbx.User_Update_Fields. However, you cannot optionally update only some of the fields. They all will always be updated. This means that if you only want to update FullName, you still need to read the user info from the DB to avoid updating the rest of the fields to zero. This is not good because concurrent updates can overwrite each other. This change introduces a new struct type, UpdateUserRequest, which contains pointers for all the fields that are updated by satellite db users.Update. Now the update method will check if a field is nil before assigning the value to be updated in the db, so you only need to set the field you want updated. For nullable columns, the respective field is a double pointer. This allows us to update a column to NULL if the outer pointer is not nil, but the inner pointer is. Change-Id: I27f842d283c2711e24d51dcab622e57eeb9157f1
This commit is contained in:
parent
3ae325462c
commit
240b70b828
@ -104,8 +104,9 @@ func (server *Server) addUser(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Set User Status to be activated, as we manually created it
|
||||
newUser.Status = console.Active
|
||||
newUser.PasswordHash = nil
|
||||
err = server.db.Console().Users().Update(ctx, newUser)
|
||||
err = server.db.Console().Users().Update(ctx, userID, console.UpdateUserRequest{
|
||||
Status: &newUser.Status,
|
||||
})
|
||||
if err != nil {
|
||||
sendJSONError(w, "failed to activate user",
|
||||
err.Error(), http.StatusInternalServerError)
|
||||
@ -240,32 +241,32 @@ func (server *Server) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
updateRequest := console.UpdateUserRequest{}
|
||||
|
||||
if input.FullName != "" {
|
||||
user.FullName = input.FullName
|
||||
updateRequest.FullName = &input.FullName
|
||||
}
|
||||
if input.ShortName != "" {
|
||||
user.ShortName = input.ShortName
|
||||
shortNamePtr := &input.ShortName
|
||||
updateRequest.ShortName = &shortNamePtr
|
||||
}
|
||||
if input.Email != "" {
|
||||
user.Email = input.Email
|
||||
}
|
||||
if !input.PartnerID.IsZero() {
|
||||
user.PartnerID = input.PartnerID
|
||||
updateRequest.Email = &input.Email
|
||||
}
|
||||
if len(input.PasswordHash) > 0 {
|
||||
user.PasswordHash = input.PasswordHash
|
||||
updateRequest.PasswordHash = input.PasswordHash
|
||||
}
|
||||
if input.ProjectLimit > 0 {
|
||||
user.ProjectLimit = input.ProjectLimit
|
||||
updateRequest.ProjectLimit = &input.ProjectLimit
|
||||
}
|
||||
if input.ProjectStorageLimit > 0 {
|
||||
user.ProjectStorageLimit = input.ProjectStorageLimit
|
||||
updateRequest.ProjectStorageLimit = &input.ProjectStorageLimit
|
||||
}
|
||||
if input.ProjectBandwidthLimit > 0 {
|
||||
user.ProjectBandwidthLimit = input.ProjectBandwidthLimit
|
||||
updateRequest.ProjectBandwidthLimit = &input.ProjectBandwidthLimit
|
||||
}
|
||||
if input.ProjectSegmentLimit > 0 {
|
||||
user.ProjectSegmentLimit = input.ProjectSegmentLimit
|
||||
updateRequest.ProjectSegmentLimit = &input.ProjectSegmentLimit
|
||||
}
|
||||
if input.PaidTierStr != "" {
|
||||
status, err := strconv.ParseBool(input.PaidTierStr)
|
||||
@ -275,10 +276,10 @@ func (server *Server) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user.PaidTier = status
|
||||
updateRequest.PaidTier = &status
|
||||
}
|
||||
|
||||
err = server.db.Console().Users().Update(ctx, user)
|
||||
err = server.db.Console().Users().Update(ctx, user.ID, updateRequest)
|
||||
if err != nil {
|
||||
sendJSONError(w, "failed to update user",
|
||||
err.Error(), http.StatusInternalServerError)
|
||||
@ -310,9 +311,14 @@ func (server *Server) disableUserMFA(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
user.MFAEnabled = false
|
||||
user.MFASecretKey = ""
|
||||
user.MFARecoveryCodes = nil
|
||||
mfaSecretKeyPtr := &user.MFASecretKey
|
||||
var mfaRecoveryCodes []string
|
||||
|
||||
err = server.db.Console().Users().Update(ctx, user)
|
||||
err = server.db.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{
|
||||
MFAEnabled: &user.MFAEnabled,
|
||||
MFASecretKey: &mfaSecretKeyPtr,
|
||||
MFARecoveryCodes: &mfaRecoveryCodes,
|
||||
})
|
||||
if err != nil {
|
||||
sendJSONError(w, "failed to disable mfa",
|
||||
err.Error(), http.StatusInternalServerError)
|
||||
@ -402,15 +408,17 @@ func (server *Server) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userInfo := &console.User{
|
||||
ID: user.ID,
|
||||
FullName: "",
|
||||
ShortName: "",
|
||||
Email: fmt.Sprintf("deactivated+%s@storj.io", user.ID.String()),
|
||||
Status: console.Deleted,
|
||||
}
|
||||
emptyName := ""
|
||||
emptyNamePtr := &emptyName
|
||||
deactivatedEmail := fmt.Sprintf("deactivated+%s@storj.io", user.ID.String())
|
||||
status := console.Deleted
|
||||
|
||||
err = server.db.Console().Users().Update(ctx, userInfo)
|
||||
err = server.db.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{
|
||||
FullName: &emptyName,
|
||||
ShortName: &emptyNamePtr,
|
||||
Email: &deactivatedEmail,
|
||||
Status: &status,
|
||||
})
|
||||
if err != nil {
|
||||
sendJSONError(w, "unable to delete user",
|
||||
err.Error(), http.StatusInternalServerError)
|
||||
|
@ -233,7 +233,14 @@ func TestDisableMFA(t *testing.T) {
|
||||
user.MFAEnabled = true
|
||||
user.MFASecretKey = "randomtext"
|
||||
user.MFARecoveryCodes = []string{"0123456789"}
|
||||
err = planet.Satellites[0].DB.Console().Users().Update(ctx, user)
|
||||
|
||||
secretKeyPtr := &user.MFASecretKey
|
||||
|
||||
err = planet.Satellites[0].DB.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{
|
||||
MFAEnabled: &user.MFAEnabled,
|
||||
MFASecretKey: &secretKeyPtr,
|
||||
MFARecoveryCodes: &user.MFARecoveryCodes,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure MFA is enabled.
|
||||
|
@ -695,7 +695,9 @@ func TestRegistrationEmail(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
user.Status = console.Active
|
||||
require.NoError(t, sat.DB.Console().Users().Update(ctx, user))
|
||||
require.NoError(t, sat.DB.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{
|
||||
Status: &user.Status,
|
||||
}))
|
||||
|
||||
newUserID = register()
|
||||
require.Equal(t, userID, newUserID)
|
||||
@ -740,7 +742,9 @@ func TestResendActivationEmail(t *testing.T) {
|
||||
|
||||
// Expect activation e-mail to be sent when using unverified e-mail address.
|
||||
user.Status = console.Inactive
|
||||
require.NoError(t, usersRepo.Update(ctx, user))
|
||||
require.NoError(t, usersRepo.Update(ctx, user.ID, console.UpdateUserRequest{
|
||||
Status: &user.Status,
|
||||
}))
|
||||
|
||||
resendEmail()
|
||||
body, err = sender.Data.Get(ctx)
|
||||
|
@ -77,7 +77,9 @@ func TestEmailChoreUpdatesVerificationReminders(t *testing.T) {
|
||||
require.Zero(t, user3.VerificationReminders)
|
||||
|
||||
user1.Status = 1
|
||||
err = users.Update(ctx, user1)
|
||||
err = users.Update(ctx, user1.ID, console.UpdateUserRequest{
|
||||
Status: &user1.Status,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chore.Loop.TriggerWait()
|
||||
@ -95,7 +97,9 @@ func TestEmailChoreUpdatesVerificationReminders(t *testing.T) {
|
||||
require.Equal(t, 1, user3.VerificationReminders)
|
||||
|
||||
user2.Status = 1
|
||||
err = users.Update(ctx, user2)
|
||||
err = users.Update(ctx, user2.ID, console.UpdateUserRequest{
|
||||
Status: &user2.Status,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chore.Loop.TriggerWait()
|
||||
|
@ -100,7 +100,9 @@ func (s *Service) EnableUserMFA(ctx context.Context, passcode string, t time.Tim
|
||||
}
|
||||
|
||||
user.MFAEnabled = true
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
|
||||
MFAEnabled: &user.MFAEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
@ -151,7 +153,14 @@ func (s *Service) DisableUserMFA(ctx context.Context, passcode string, t time.Ti
|
||||
user.MFAEnabled = false
|
||||
user.MFASecretKey = ""
|
||||
user.MFARecoveryCodes = nil
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
|
||||
secretKeyPtr := &user.MFASecretKey
|
||||
|
||||
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
|
||||
MFAEnabled: &user.MFAEnabled,
|
||||
MFASecretKey: &secretKeyPtr,
|
||||
MFARecoveryCodes: &user.MFARecoveryCodes,
|
||||
})
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
@ -194,7 +203,11 @@ func (s *Service) ResetMFASecretKey(ctx context.Context) (key string, err error)
|
||||
}
|
||||
|
||||
user.MFASecretKey = key
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
mfaSecretKeyPtr := &user.MFASecretKey
|
||||
|
||||
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
|
||||
MFASecretKey: &mfaSecretKeyPtr,
|
||||
})
|
||||
if err != nil {
|
||||
return "", Error.Wrap(err)
|
||||
}
|
||||
@ -223,9 +236,10 @@ func (s *Service) ResetMFARecoveryCodes(ctx context.Context) (codes []string, er
|
||||
}
|
||||
codes[i] = code
|
||||
}
|
||||
user.MFARecoveryCodes = codes
|
||||
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
|
||||
MFARecoveryCodes: &codes,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
|
@ -830,8 +830,10 @@ func (s *Service) ActivateAccount(ctx context.Context, activationToken string) (
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
|
||||
user.Status = Active
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
status := Active
|
||||
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
|
||||
Status: &status,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
@ -898,16 +900,22 @@ func (s *Service) ResetPassword(ctx context.Context, resetPasswordToken, passwor
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
||||
user.PasswordHash = hash
|
||||
if user.FailedLoginCount != 0 {
|
||||
user.FailedLoginCount = 0
|
||||
user.LoginLockoutExpiration = time.Time{}
|
||||
updateRequest := UpdateUserRequest{
|
||||
PasswordHash: hash,
|
||||
}
|
||||
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
if user.FailedLoginCount != 0 {
|
||||
resetFailedLoginCount := 0
|
||||
resetLoginLockoutExpirationPtr := &time.Time{}
|
||||
updateRequest.FailedLoginCount = &resetFailedLoginCount
|
||||
updateRequest.LoginLockoutExpiration = &resetLoginLockoutExpirationPtr
|
||||
}
|
||||
|
||||
err = s.store.Users().Update(ctx, user.ID, updateRequest)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
||||
s.auditLog(ctx, "password reset", &user.ID, user.Email)
|
||||
|
||||
if err = s.store.ResetPasswordTokens().Delete(ctx, token.Secret); err != nil {
|
||||
@ -954,7 +962,7 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token consoleaut
|
||||
|
||||
lockoutExpDate := now.Add(time.Duration(math.Pow(s.config.FailedLoginPenalty, float64(user.FailedLoginCount-1))) * time.Minute)
|
||||
handleLockAccount := func() error {
|
||||
err = s.UpdateUsersFailedLoginState(ctx, user, lockoutExpDate)
|
||||
err = s.UpdateUsersFailedLoginState(ctx, user, &lockoutExpDate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1014,7 +1022,9 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token consoleaut
|
||||
|
||||
user.MFARecoveryCodes = append(user.MFARecoveryCodes[:codeIndex], user.MFARecoveryCodes[codeIndex+1:]...)
|
||||
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
|
||||
MFARecoveryCodes: &user.MFARecoveryCodes,
|
||||
})
|
||||
if err != nil {
|
||||
return consoleauth.Token{}, err
|
||||
}
|
||||
@ -1045,8 +1055,11 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token consoleaut
|
||||
|
||||
if user.FailedLoginCount != 0 {
|
||||
user.FailedLoginCount = 0
|
||||
user.LoginLockoutExpiration = time.Time{}
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
loginLockoutExpirationPtr := &time.Time{}
|
||||
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
|
||||
FailedLoginCount: &user.FailedLoginCount,
|
||||
LoginLockoutExpiration: &loginLockoutExpirationPtr,
|
||||
})
|
||||
if err != nil {
|
||||
return consoleauth.Token{}, err
|
||||
}
|
||||
@ -1063,13 +1076,15 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token consoleaut
|
||||
}
|
||||
|
||||
// UpdateUsersFailedLoginState updates User's failed login state.
|
||||
func (s *Service) UpdateUsersFailedLoginState(ctx context.Context, user *User, lockoutExpDate time.Time) error {
|
||||
func (s *Service) UpdateUsersFailedLoginState(ctx context.Context, user *User, lockoutExpDate *time.Time) error {
|
||||
updateRequest := UpdateUserRequest{}
|
||||
if user.FailedLoginCount >= s.config.LoginAttemptsWithoutPenalty-1 {
|
||||
user.LoginLockoutExpiration = lockoutExpDate
|
||||
updateRequest.LoginLockoutExpiration = &lockoutExpDate
|
||||
}
|
||||
user.FailedLoginCount++
|
||||
|
||||
return s.store.Users().Update(ctx, user)
|
||||
updateRequest.FailedLoginCount = &user.FailedLoginCount
|
||||
return s.store.Users().Update(ctx, user.ID, updateRequest)
|
||||
}
|
||||
|
||||
// GetUser returns User by id.
|
||||
@ -1161,7 +1176,11 @@ func (s *Service) UpdateAccount(ctx context.Context, fullName string, shortName
|
||||
|
||||
user.FullName = fullName
|
||||
user.ShortName = shortName
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
shortNamePtr := &user.ShortName
|
||||
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
|
||||
FullName: &user.FullName,
|
||||
ShortName: &shortNamePtr,
|
||||
})
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
@ -1190,7 +1209,9 @@ func (s *Service) ChangeEmail(ctx context.Context, newEmail string) (err error)
|
||||
}
|
||||
|
||||
user.Email = newEmail
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
|
||||
Email: &user.Email,
|
||||
})
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
@ -1221,7 +1242,9 @@ func (s *Service) ChangePassword(ctx context.Context, pass, newPass string) (err
|
||||
}
|
||||
|
||||
user.PasswordHash = hash
|
||||
err = s.store.Users().Update(ctx, user)
|
||||
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
|
||||
PasswordHash: hash,
|
||||
})
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
@ -765,7 +765,7 @@ func TestLockAccount(t *testing.T) {
|
||||
// lock account once again and check if lockout expiration time increased.
|
||||
expDuration := time.Duration(math.Pow(consoleConfig.FailedLoginPenalty, float64(lockedUser.FailedLoginCount-1))) * time.Minute
|
||||
lockoutExpDate := now.Add(expDuration)
|
||||
err = service.UpdateUsersFailedLoginState(userCtx, lockedUser, lockoutExpDate)
|
||||
err = service.UpdateUsersFailedLoginState(userCtx, lockedUser, &lockoutExpDate)
|
||||
require.NoError(t, err)
|
||||
|
||||
lockedUser, err = service.GetUser(userCtx, user.ID)
|
||||
@ -777,7 +777,10 @@ func TestLockAccount(t *testing.T) {
|
||||
|
||||
// unlock account by successful login
|
||||
lockedUser.LoginLockoutExpiration = now.Add(-time.Second)
|
||||
err = usersDB.Update(userCtx, lockedUser)
|
||||
lockoutExpirationPtr := &lockedUser.LoginLockoutExpiration
|
||||
err = usersDB.Update(userCtx, lockedUser.ID, console.UpdateUserRequest{
|
||||
LoginLockoutExpiration: &lockoutExpirationPtr,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
authUser.Password = newUser.FullName
|
||||
@ -808,8 +811,12 @@ func TestLockAccount(t *testing.T) {
|
||||
|
||||
// unlock account
|
||||
lockedUser.LoginLockoutExpiration = time.Time{}
|
||||
lockoutExpirationPtr = &lockedUser.LoginLockoutExpiration
|
||||
lockedUser.FailedLoginCount = 0
|
||||
err = usersDB.Update(userCtx, lockedUser)
|
||||
err = usersDB.Update(userCtx, lockedUser.ID, console.UpdateUserRequest{
|
||||
LoginLockoutExpiration: &lockoutExpirationPtr,
|
||||
FailedLoginCount: &lockedUser.FailedLoginCount,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// check if user's account gets locked because of providing wrong mfa recovery code.
|
||||
|
@ -33,7 +33,7 @@ type Users interface {
|
||||
// Delete is a method for deleting user by Id from the database.
|
||||
Delete(ctx context.Context, id uuid.UUID) error
|
||||
// Update is a method for updating user entity.
|
||||
Update(ctx context.Context, user *User) error
|
||||
Update(ctx context.Context, userID uuid.UUID, request UpdateUserRequest) error
|
||||
// UpdatePaidTier sets whether the user is in the paid tier.
|
||||
UpdatePaidTier(ctx context.Context, id uuid.UUID, paidTier bool, projectBandwidthLimit, projectStorageLimit memory.Size, projectSegmentLimit int64, projectLimit int) error
|
||||
// GetProjectLimit is a method to get the users project limit
|
||||
@ -226,3 +226,32 @@ func GetUser(ctx context.Context) (*User, error) {
|
||||
|
||||
return nil, Error.New("user is not in context")
|
||||
}
|
||||
|
||||
// UpdateUserRequest contains all columns which are optionally updatable by users.Update.
|
||||
type UpdateUserRequest struct {
|
||||
FullName *string
|
||||
ShortName **string
|
||||
|
||||
Email *string
|
||||
PasswordHash []byte
|
||||
|
||||
Status *UserStatus
|
||||
|
||||
ProjectLimit *int
|
||||
ProjectStorageLimit *int64
|
||||
ProjectBandwidthLimit *int64
|
||||
ProjectSegmentLimit *int64
|
||||
PaidTier *bool
|
||||
|
||||
MFAEnabled *bool
|
||||
MFASecretKey **string
|
||||
MFARecoveryCodes *[]string
|
||||
|
||||
LastVerificationReminder **time.Time
|
||||
|
||||
// failed_login_count is nullable, but we don't really have a reason
|
||||
// to set it to NULL, so it doesn't need to be a double pointer here.
|
||||
FailedLoginCount *int
|
||||
|
||||
LoginLockoutExpiration **time.Time
|
||||
}
|
||||
|
@ -114,7 +114,9 @@ func TestUserEmailCase(t *testing.T) {
|
||||
|
||||
createdUser.Status = console.Active
|
||||
|
||||
err = db.Console().Users().Update(ctx, createdUser)
|
||||
err = db.Console().Users().Update(ctx, createdUser.ID, console.UpdateUserRequest{
|
||||
Status: &createdUser.Status,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
retrievedUser, err := db.Console().Users().GetByEmail(ctx, testCase.email)
|
||||
@ -178,7 +180,9 @@ func testUsers(ctx context.Context, t *testing.T, repository console.Users, user
|
||||
|
||||
insertedUser.Status = console.Active
|
||||
|
||||
err = repository.Update(ctx, insertedUser)
|
||||
err = repository.Update(ctx, insertedUser.ID, console.UpdateUserRequest{
|
||||
Status: &insertedUser.Status,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
@ -266,7 +270,22 @@ func testUsers(ctx context.Context, t *testing.T, repository console.Users, user
|
||||
LastVerificationReminder: date,
|
||||
}
|
||||
|
||||
err = repository.Update(ctx, newUserInfo)
|
||||
shortNamePtr := &newUserInfo.ShortName
|
||||
secretKeyPtr := &newUserInfo.MFASecretKey
|
||||
lastVerificationReminderPtr := &newUserInfo.LastVerificationReminder
|
||||
|
||||
err = repository.Update(ctx, newUserInfo.ID, console.UpdateUserRequest{
|
||||
FullName: &newUserInfo.FullName,
|
||||
ShortName: &shortNamePtr,
|
||||
Email: &newUserInfo.Email,
|
||||
Status: &newUserInfo.Status,
|
||||
PaidTier: &newUserInfo.PaidTier,
|
||||
MFAEnabled: &newUserInfo.MFAEnabled,
|
||||
MFASecretKey: &secretKeyPtr,
|
||||
MFARecoveryCodes: &newUserInfo.MFARecoveryCodes,
|
||||
PasswordHash: newUserInfo.PasswordHash,
|
||||
LastVerificationReminder: &lastVerificationReminderPtr,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
newUser, err := repository.Get(ctx, oldUser.ID)
|
||||
@ -333,7 +352,9 @@ func TestGetUserByEmail(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Required to set the active status.
|
||||
err = usersRepo.Update(ctx, &activeUser)
|
||||
err = usersRepo.Update(ctx, activeUser.ID, console.UpdateUserRequest{
|
||||
Status: &activeUser.Status,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
dbUser, err := usersRepo.GetByEmail(ctx, email)
|
||||
|
@ -181,17 +181,17 @@ func (users *users) Delete(ctx context.Context, id uuid.UUID) (err error) {
|
||||
}
|
||||
|
||||
// Update is a method for updating user entity.
|
||||
func (users *users) Update(ctx context.Context, user *console.User) (err error) {
|
||||
func (users *users) Update(ctx context.Context, userID uuid.UUID, updateRequest console.UpdateUserRequest) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
updateFields, err := toUpdateUser(user)
|
||||
updateFields, err := toUpdateUser(updateRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = users.db.Update_User_By_Id(
|
||||
ctx,
|
||||
dbx.User_Id(user.ID[:]),
|
||||
dbx.User_Id(userID[:]),
|
||||
*updateFields,
|
||||
)
|
||||
|
||||
@ -251,34 +251,82 @@ func (users *users) GetUserPaidTier(ctx context.Context, id uuid.UUID) (isPaid b
|
||||
}
|
||||
|
||||
// toUpdateUser creates dbx.User_Update_Fields with only non-empty fields as updatable.
|
||||
func toUpdateUser(user *console.User) (*dbx.User_Update_Fields, error) {
|
||||
update := dbx.User_Update_Fields{
|
||||
FullName: dbx.User_FullName(user.FullName),
|
||||
ShortName: dbx.User_ShortName(user.ShortName),
|
||||
Email: dbx.User_Email(user.Email),
|
||||
NormalizedEmail: dbx.User_NormalizedEmail(normalizeEmail(user.Email)),
|
||||
Status: dbx.User_Status(int(user.Status)),
|
||||
ProjectLimit: dbx.User_ProjectLimit(user.ProjectLimit),
|
||||
ProjectStorageLimit: dbx.User_ProjectStorageLimit(user.ProjectStorageLimit),
|
||||
ProjectBandwidthLimit: dbx.User_ProjectBandwidthLimit(user.ProjectBandwidthLimit),
|
||||
ProjectSegmentLimit: dbx.User_ProjectSegmentLimit(user.ProjectSegmentLimit),
|
||||
PaidTier: dbx.User_PaidTier(user.PaidTier),
|
||||
MfaEnabled: dbx.User_MfaEnabled(user.MFAEnabled),
|
||||
LastVerificationReminder: dbx.User_LastVerificationReminder(user.LastVerificationReminder),
|
||||
FailedLoginCount: dbx.User_FailedLoginCount(user.FailedLoginCount),
|
||||
LoginLockoutExpiration: dbx.User_LoginLockoutExpiration(user.LoginLockoutExpiration),
|
||||
func toUpdateUser(request console.UpdateUserRequest) (*dbx.User_Update_Fields, error) {
|
||||
update := dbx.User_Update_Fields{}
|
||||
if request.FullName != nil {
|
||||
update.FullName = dbx.User_FullName(*request.FullName)
|
||||
}
|
||||
|
||||
recoveryBytes, err := json.Marshal(user.MFARecoveryCodes)
|
||||
if request.ShortName != nil {
|
||||
if *request.ShortName == nil {
|
||||
update.ShortName = dbx.User_ShortName_Null()
|
||||
} else {
|
||||
update.ShortName = dbx.User_ShortName(**request.ShortName)
|
||||
}
|
||||
}
|
||||
if request.Email != nil {
|
||||
update.Email = dbx.User_Email(*request.Email)
|
||||
update.NormalizedEmail = dbx.User_NormalizedEmail(normalizeEmail(*request.Email))
|
||||
}
|
||||
if request.PasswordHash != nil {
|
||||
if len(request.PasswordHash) > 0 {
|
||||
update.PasswordHash = dbx.User_PasswordHash(request.PasswordHash)
|
||||
}
|
||||
}
|
||||
if request.Status != nil {
|
||||
update.Status = dbx.User_Status(int(*request.Status))
|
||||
}
|
||||
if request.ProjectLimit != nil {
|
||||
update.ProjectLimit = dbx.User_ProjectLimit(*request.ProjectLimit)
|
||||
}
|
||||
if request.ProjectStorageLimit != nil {
|
||||
update.ProjectStorageLimit = dbx.User_ProjectStorageLimit(*request.ProjectStorageLimit)
|
||||
}
|
||||
if request.ProjectBandwidthLimit != nil {
|
||||
update.ProjectBandwidthLimit = dbx.User_ProjectBandwidthLimit(*request.ProjectBandwidthLimit)
|
||||
}
|
||||
if request.ProjectSegmentLimit != nil {
|
||||
update.ProjectSegmentLimit = dbx.User_ProjectSegmentLimit(*request.ProjectSegmentLimit)
|
||||
}
|
||||
if request.PaidTier != nil {
|
||||
update.PaidTier = dbx.User_PaidTier(*request.PaidTier)
|
||||
}
|
||||
if request.MFAEnabled != nil {
|
||||
update.MfaEnabled = dbx.User_MfaEnabled(*request.MFAEnabled)
|
||||
}
|
||||
if request.MFASecretKey != nil {
|
||||
if *request.MFASecretKey == nil {
|
||||
update.MfaSecretKey = dbx.User_MfaSecretKey_Null()
|
||||
} else {
|
||||
update.MfaSecretKey = dbx.User_MfaSecretKey(**request.MFASecretKey)
|
||||
}
|
||||
}
|
||||
if request.MFARecoveryCodes != nil {
|
||||
if *request.MFARecoveryCodes == nil {
|
||||
update.MfaRecoveryCodes = dbx.User_MfaRecoveryCodes_Null()
|
||||
} else {
|
||||
recoveryBytes, err := json.Marshal(*request.MFARecoveryCodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
update.MfaRecoveryCodes = dbx.User_MfaRecoveryCodes(string(recoveryBytes))
|
||||
update.MfaSecretKey = dbx.User_MfaSecretKey(user.MFASecretKey)
|
||||
|
||||
// extra password check to update only calculated hash from service
|
||||
if len(user.PasswordHash) != 0 {
|
||||
update.PasswordHash = dbx.User_PasswordHash(user.PasswordHash)
|
||||
}
|
||||
}
|
||||
if request.LastVerificationReminder != nil {
|
||||
if *request.LastVerificationReminder == nil {
|
||||
update.LastVerificationReminder = dbx.User_LastVerificationReminder_Null()
|
||||
} else {
|
||||
update.LastVerificationReminder = dbx.User_LastVerificationReminder(**request.LastVerificationReminder)
|
||||
}
|
||||
}
|
||||
if request.FailedLoginCount != nil {
|
||||
update.FailedLoginCount = dbx.User_FailedLoginCount(*request.FailedLoginCount)
|
||||
}
|
||||
if request.LoginLockoutExpiration != nil {
|
||||
if *request.LoginLockoutExpiration == nil {
|
||||
update.LoginLockoutExpiration = dbx.User_LoginLockoutExpiration_Null()
|
||||
} else {
|
||||
update.LoginLockoutExpiration = dbx.User_LoginLockoutExpiration(**request.LoginLockoutExpiration)
|
||||
}
|
||||
}
|
||||
|
||||
return &update, nil
|
||||
|
@ -54,3 +54,250 @@ func TestGetUnverifiedNeedingReminderCutoff(t *testing.T) {
|
||||
require.Len(t, needingReminder, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
|
||||
users := db.Console().Users()
|
||||
id := testrand.UUID()
|
||||
u, err := users.Insert(ctx, &console.User{
|
||||
ID: id,
|
||||
FullName: "testFullName",
|
||||
ShortName: "testShortName",
|
||||
Email: "test@storj.test",
|
||||
PasswordHash: []byte("testPasswordHash"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
newInfo := console.User{
|
||||
FullName: "updatedFullName",
|
||||
ShortName: "updatedShortName",
|
||||
PasswordHash: []byte("updatedPasswordHash"),
|
||||
ProjectLimit: 1,
|
||||
ProjectBandwidthLimit: 1,
|
||||
ProjectStorageLimit: 1,
|
||||
ProjectSegmentLimit: 1,
|
||||
PaidTier: true,
|
||||
MFAEnabled: true,
|
||||
MFASecretKey: "secretKey",
|
||||
MFARecoveryCodes: []string{"code1", "code2"},
|
||||
LastVerificationReminder: time.Now().Truncate(time.Second),
|
||||
FailedLoginCount: 1,
|
||||
LoginLockoutExpiration: time.Now().Truncate(time.Second),
|
||||
}
|
||||
|
||||
require.NotEqual(t, u.FullName, newInfo.FullName)
|
||||
require.NotEqual(t, u.ShortName, newInfo.ShortName)
|
||||
require.NotEqual(t, u.PasswordHash, newInfo.PasswordHash)
|
||||
require.NotEqual(t, u.ProjectLimit, newInfo.ProjectLimit)
|
||||
require.NotEqual(t, u.ProjectBandwidthLimit, newInfo.ProjectBandwidthLimit)
|
||||
require.NotEqual(t, u.ProjectStorageLimit, newInfo.ProjectStorageLimit)
|
||||
require.NotEqual(t, u.ProjectSegmentLimit, newInfo.ProjectSegmentLimit)
|
||||
require.NotEqual(t, u.PaidTier, newInfo.PaidTier)
|
||||
require.NotEqual(t, u.MFAEnabled, newInfo.MFAEnabled)
|
||||
require.NotEqual(t, u.MFASecretKey, newInfo.MFASecretKey)
|
||||
require.NotEqual(t, u.MFARecoveryCodes, newInfo.MFARecoveryCodes)
|
||||
require.NotEqual(t, u.LastVerificationReminder, newInfo.LastVerificationReminder)
|
||||
require.NotEqual(t, u.FailedLoginCount, newInfo.FailedLoginCount)
|
||||
require.NotEqual(t, u.LoginLockoutExpiration, newInfo.LoginLockoutExpiration)
|
||||
|
||||
// update just fullname
|
||||
updateReq := console.UpdateUserRequest{
|
||||
FullName: &newInfo.FullName,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err := users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.FullName = newInfo.FullName
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just shortname
|
||||
shortNamePtr := &newInfo.ShortName
|
||||
updateReq = console.UpdateUserRequest{
|
||||
ShortName: &shortNamePtr,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.ShortName = newInfo.ShortName
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just password hash
|
||||
updateReq = console.UpdateUserRequest{
|
||||
PasswordHash: newInfo.PasswordHash,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.PasswordHash = newInfo.PasswordHash
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just project limit
|
||||
updateReq = console.UpdateUserRequest{
|
||||
ProjectLimit: &newInfo.ProjectLimit,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.ProjectLimit = newInfo.ProjectLimit
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just project bw limit
|
||||
updateReq = console.UpdateUserRequest{
|
||||
ProjectBandwidthLimit: &newInfo.ProjectBandwidthLimit,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.ProjectBandwidthLimit = newInfo.ProjectBandwidthLimit
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just project storage limit
|
||||
updateReq = console.UpdateUserRequest{
|
||||
ProjectStorageLimit: &newInfo.ProjectStorageLimit,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.ProjectStorageLimit = newInfo.ProjectStorageLimit
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just project segment limit
|
||||
updateReq = console.UpdateUserRequest{
|
||||
ProjectSegmentLimit: &newInfo.ProjectSegmentLimit,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.ProjectSegmentLimit = newInfo.ProjectSegmentLimit
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just paid tier
|
||||
updateReq = console.UpdateUserRequest{
|
||||
PaidTier: &newInfo.PaidTier,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.PaidTier = newInfo.PaidTier
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just mfa enabled
|
||||
updateReq = console.UpdateUserRequest{
|
||||
MFAEnabled: &newInfo.MFAEnabled,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.MFAEnabled = newInfo.MFAEnabled
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just mfa secret key
|
||||
secretKeyPtr := &newInfo.MFASecretKey
|
||||
updateReq = console.UpdateUserRequest{
|
||||
MFASecretKey: &secretKeyPtr,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.MFASecretKey = newInfo.MFASecretKey
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just mfa recovery codes
|
||||
updateReq = console.UpdateUserRequest{
|
||||
MFARecoveryCodes: &newInfo.MFARecoveryCodes,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.MFARecoveryCodes = newInfo.MFARecoveryCodes
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just last verification reminder
|
||||
lastReminderPtr := &newInfo.LastVerificationReminder
|
||||
updateReq = console.UpdateUserRequest{
|
||||
LastVerificationReminder: &lastReminderPtr,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.LastVerificationReminder = newInfo.LastVerificationReminder
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just failed login count
|
||||
updateReq = console.UpdateUserRequest{
|
||||
FailedLoginCount: &newInfo.FailedLoginCount,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.FailedLoginCount = newInfo.FailedLoginCount
|
||||
require.Equal(t, u, updatedUser)
|
||||
|
||||
// update just login lockout expiration
|
||||
loginLockoutExpPtr := &newInfo.LoginLockoutExpiration
|
||||
updateReq = console.UpdateUserRequest{
|
||||
LoginLockoutExpiration: &loginLockoutExpPtr,
|
||||
}
|
||||
|
||||
err = users.Update(ctx, id, updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err = users.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.LoginLockoutExpiration = newInfo.LoginLockoutExpiration
|
||||
require.Equal(t, u, updatedUser)
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user