satellite/console: use new type UpdateUserRequest as arg to db users.Update

The users.Update method in the satellitedb package takes a console.User
as an argument. It reads some of the fields on this struct and assigns
the value to dbx.User_Update_Fields. However, you cannot optionally
update only some of the fields. They all will always be updated. This means
that if you only want to update FullName, you still need to read the
user info from the DB to avoid updating the rest of the fields to zero.
This is not good because concurrent updates can overwrite each other.

This change introduces a new struct type, UpdateUserRequest, which
contains pointers for all the fields that are updated by satellite db
users.Update. Now the update method will check if a field is nil before
assigning the value to be updated in the db, so you only need to set the
field you want updated. For nullable columns, the respective field is a
double pointer. This allows us to update a column to NULL if the outer
pointer is not nil, but the inner pointer is.

Change-Id: I27f842d283c2711e24d51dcab622e57eeb9157f1
This commit is contained in:
Cameron 2022-06-01 17:15:37 -04:00
parent 3ae325462c
commit 240b70b828
11 changed files with 501 additions and 89 deletions

View File

@ -104,8 +104,9 @@ func (server *Server) addUser(w http.ResponseWriter, r *http.Request) {
// Set User Status to be activated, as we manually created it
newUser.Status = console.Active
newUser.PasswordHash = nil
err = server.db.Console().Users().Update(ctx, newUser)
err = server.db.Console().Users().Update(ctx, userID, console.UpdateUserRequest{
Status: &newUser.Status,
})
if err != nil {
sendJSONError(w, "failed to activate user",
err.Error(), http.StatusInternalServerError)
@ -240,32 +241,32 @@ func (server *Server) updateUser(w http.ResponseWriter, r *http.Request) {
return
}
updateRequest := console.UpdateUserRequest{}
if input.FullName != "" {
user.FullName = input.FullName
updateRequest.FullName = &input.FullName
}
if input.ShortName != "" {
user.ShortName = input.ShortName
shortNamePtr := &input.ShortName
updateRequest.ShortName = &shortNamePtr
}
if input.Email != "" {
user.Email = input.Email
}
if !input.PartnerID.IsZero() {
user.PartnerID = input.PartnerID
updateRequest.Email = &input.Email
}
if len(input.PasswordHash) > 0 {
user.PasswordHash = input.PasswordHash
updateRequest.PasswordHash = input.PasswordHash
}
if input.ProjectLimit > 0 {
user.ProjectLimit = input.ProjectLimit
updateRequest.ProjectLimit = &input.ProjectLimit
}
if input.ProjectStorageLimit > 0 {
user.ProjectStorageLimit = input.ProjectStorageLimit
updateRequest.ProjectStorageLimit = &input.ProjectStorageLimit
}
if input.ProjectBandwidthLimit > 0 {
user.ProjectBandwidthLimit = input.ProjectBandwidthLimit
updateRequest.ProjectBandwidthLimit = &input.ProjectBandwidthLimit
}
if input.ProjectSegmentLimit > 0 {
user.ProjectSegmentLimit = input.ProjectSegmentLimit
updateRequest.ProjectSegmentLimit = &input.ProjectSegmentLimit
}
if input.PaidTierStr != "" {
status, err := strconv.ParseBool(input.PaidTierStr)
@ -275,10 +276,10 @@ func (server *Server) updateUser(w http.ResponseWriter, r *http.Request) {
return
}
user.PaidTier = status
updateRequest.PaidTier = &status
}
err = server.db.Console().Users().Update(ctx, user)
err = server.db.Console().Users().Update(ctx, user.ID, updateRequest)
if err != nil {
sendJSONError(w, "failed to update user",
err.Error(), http.StatusInternalServerError)
@ -310,9 +311,14 @@ func (server *Server) disableUserMFA(w http.ResponseWriter, r *http.Request) {
user.MFAEnabled = false
user.MFASecretKey = ""
user.MFARecoveryCodes = nil
mfaSecretKeyPtr := &user.MFASecretKey
var mfaRecoveryCodes []string
err = server.db.Console().Users().Update(ctx, user)
err = server.db.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{
MFAEnabled: &user.MFAEnabled,
MFASecretKey: &mfaSecretKeyPtr,
MFARecoveryCodes: &mfaRecoveryCodes,
})
if err != nil {
sendJSONError(w, "failed to disable mfa",
err.Error(), http.StatusInternalServerError)
@ -402,15 +408,17 @@ func (server *Server) deleteUser(w http.ResponseWriter, r *http.Request) {
return
}
userInfo := &console.User{
ID: user.ID,
FullName: "",
ShortName: "",
Email: fmt.Sprintf("deactivated+%s@storj.io", user.ID.String()),
Status: console.Deleted,
}
emptyName := ""
emptyNamePtr := &emptyName
deactivatedEmail := fmt.Sprintf("deactivated+%s@storj.io", user.ID.String())
status := console.Deleted
err = server.db.Console().Users().Update(ctx, userInfo)
err = server.db.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{
FullName: &emptyName,
ShortName: &emptyNamePtr,
Email: &deactivatedEmail,
Status: &status,
})
if err != nil {
sendJSONError(w, "unable to delete user",
err.Error(), http.StatusInternalServerError)

View File

@ -233,7 +233,14 @@ func TestDisableMFA(t *testing.T) {
user.MFAEnabled = true
user.MFASecretKey = "randomtext"
user.MFARecoveryCodes = []string{"0123456789"}
err = planet.Satellites[0].DB.Console().Users().Update(ctx, user)
secretKeyPtr := &user.MFASecretKey
err = planet.Satellites[0].DB.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{
MFAEnabled: &user.MFAEnabled,
MFASecretKey: &secretKeyPtr,
MFARecoveryCodes: &user.MFARecoveryCodes,
})
require.NoError(t, err)
// Ensure MFA is enabled.

View File

@ -695,7 +695,9 @@ func TestRegistrationEmail(t *testing.T) {
require.NoError(t, err)
user.Status = console.Active
require.NoError(t, sat.DB.Console().Users().Update(ctx, user))
require.NoError(t, sat.DB.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{
Status: &user.Status,
}))
newUserID = register()
require.Equal(t, userID, newUserID)
@ -740,7 +742,9 @@ func TestResendActivationEmail(t *testing.T) {
// Expect activation e-mail to be sent when using unverified e-mail address.
user.Status = console.Inactive
require.NoError(t, usersRepo.Update(ctx, user))
require.NoError(t, usersRepo.Update(ctx, user.ID, console.UpdateUserRequest{
Status: &user.Status,
}))
resendEmail()
body, err = sender.Data.Get(ctx)

View File

@ -77,7 +77,9 @@ func TestEmailChoreUpdatesVerificationReminders(t *testing.T) {
require.Zero(t, user3.VerificationReminders)
user1.Status = 1
err = users.Update(ctx, user1)
err = users.Update(ctx, user1.ID, console.UpdateUserRequest{
Status: &user1.Status,
})
require.NoError(t, err)
chore.Loop.TriggerWait()
@ -95,7 +97,9 @@ func TestEmailChoreUpdatesVerificationReminders(t *testing.T) {
require.Equal(t, 1, user3.VerificationReminders)
user2.Status = 1
err = users.Update(ctx, user2)
err = users.Update(ctx, user2.ID, console.UpdateUserRequest{
Status: &user2.Status,
})
require.NoError(t, err)
chore.Loop.TriggerWait()

View File

@ -100,7 +100,9 @@ func (s *Service) EnableUserMFA(ctx context.Context, passcode string, t time.Tim
}
user.MFAEnabled = true
err = s.store.Users().Update(ctx, user)
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
MFAEnabled: &user.MFAEnabled,
})
if err != nil {
return Error.Wrap(err)
}
@ -151,7 +153,14 @@ func (s *Service) DisableUserMFA(ctx context.Context, passcode string, t time.Ti
user.MFAEnabled = false
user.MFASecretKey = ""
user.MFARecoveryCodes = nil
err = s.store.Users().Update(ctx, user)
secretKeyPtr := &user.MFASecretKey
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
MFAEnabled: &user.MFAEnabled,
MFASecretKey: &secretKeyPtr,
MFARecoveryCodes: &user.MFARecoveryCodes,
})
if err != nil {
return Error.Wrap(err)
}
@ -194,7 +203,11 @@ func (s *Service) ResetMFASecretKey(ctx context.Context) (key string, err error)
}
user.MFASecretKey = key
err = s.store.Users().Update(ctx, user)
mfaSecretKeyPtr := &user.MFASecretKey
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
MFASecretKey: &mfaSecretKeyPtr,
})
if err != nil {
return "", Error.Wrap(err)
}
@ -223,9 +236,10 @@ func (s *Service) ResetMFARecoveryCodes(ctx context.Context) (codes []string, er
}
codes[i] = code
}
user.MFARecoveryCodes = codes
err = s.store.Users().Update(ctx, user)
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
MFARecoveryCodes: &codes,
})
if err != nil {
return nil, Error.Wrap(err)
}

View File

@ -830,8 +830,10 @@ func (s *Service) ActivateAccount(ctx context.Context, activationToken string) (
return nil, Error.Wrap(err)
}
user.Status = Active
err = s.store.Users().Update(ctx, user)
status := Active
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
Status: &status,
})
if err != nil {
return nil, Error.Wrap(err)
}
@ -898,16 +900,22 @@ func (s *Service) ResetPassword(ctx context.Context, resetPasswordToken, passwor
return Error.Wrap(err)
}
user.PasswordHash = hash
if user.FailedLoginCount != 0 {
user.FailedLoginCount = 0
user.LoginLockoutExpiration = time.Time{}
updateRequest := UpdateUserRequest{
PasswordHash: hash,
}
err = s.store.Users().Update(ctx, user)
if user.FailedLoginCount != 0 {
resetFailedLoginCount := 0
resetLoginLockoutExpirationPtr := &time.Time{}
updateRequest.FailedLoginCount = &resetFailedLoginCount
updateRequest.LoginLockoutExpiration = &resetLoginLockoutExpirationPtr
}
err = s.store.Users().Update(ctx, user.ID, updateRequest)
if err != nil {
return Error.Wrap(err)
}
s.auditLog(ctx, "password reset", &user.ID, user.Email)
if err = s.store.ResetPasswordTokens().Delete(ctx, token.Secret); err != nil {
@ -954,7 +962,7 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token consoleaut
lockoutExpDate := now.Add(time.Duration(math.Pow(s.config.FailedLoginPenalty, float64(user.FailedLoginCount-1))) * time.Minute)
handleLockAccount := func() error {
err = s.UpdateUsersFailedLoginState(ctx, user, lockoutExpDate)
err = s.UpdateUsersFailedLoginState(ctx, user, &lockoutExpDate)
if err != nil {
return err
}
@ -1014,7 +1022,9 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token consoleaut
user.MFARecoveryCodes = append(user.MFARecoveryCodes[:codeIndex], user.MFARecoveryCodes[codeIndex+1:]...)
err = s.store.Users().Update(ctx, user)
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
MFARecoveryCodes: &user.MFARecoveryCodes,
})
if err != nil {
return consoleauth.Token{}, err
}
@ -1045,8 +1055,11 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token consoleaut
if user.FailedLoginCount != 0 {
user.FailedLoginCount = 0
user.LoginLockoutExpiration = time.Time{}
err = s.store.Users().Update(ctx, user)
loginLockoutExpirationPtr := &time.Time{}
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
FailedLoginCount: &user.FailedLoginCount,
LoginLockoutExpiration: &loginLockoutExpirationPtr,
})
if err != nil {
return consoleauth.Token{}, err
}
@ -1063,13 +1076,15 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token consoleaut
}
// UpdateUsersFailedLoginState updates User's failed login state.
func (s *Service) UpdateUsersFailedLoginState(ctx context.Context, user *User, lockoutExpDate time.Time) error {
func (s *Service) UpdateUsersFailedLoginState(ctx context.Context, user *User, lockoutExpDate *time.Time) error {
updateRequest := UpdateUserRequest{}
if user.FailedLoginCount >= s.config.LoginAttemptsWithoutPenalty-1 {
user.LoginLockoutExpiration = lockoutExpDate
updateRequest.LoginLockoutExpiration = &lockoutExpDate
}
user.FailedLoginCount++
return s.store.Users().Update(ctx, user)
updateRequest.FailedLoginCount = &user.FailedLoginCount
return s.store.Users().Update(ctx, user.ID, updateRequest)
}
// GetUser returns User by id.
@ -1161,7 +1176,11 @@ func (s *Service) UpdateAccount(ctx context.Context, fullName string, shortName
user.FullName = fullName
user.ShortName = shortName
err = s.store.Users().Update(ctx, user)
shortNamePtr := &user.ShortName
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
FullName: &user.FullName,
ShortName: &shortNamePtr,
})
if err != nil {
return Error.Wrap(err)
}
@ -1190,7 +1209,9 @@ func (s *Service) ChangeEmail(ctx context.Context, newEmail string) (err error)
}
user.Email = newEmail
err = s.store.Users().Update(ctx, user)
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
Email: &user.Email,
})
if err != nil {
return Error.Wrap(err)
}
@ -1221,7 +1242,9 @@ func (s *Service) ChangePassword(ctx context.Context, pass, newPass string) (err
}
user.PasswordHash = hash
err = s.store.Users().Update(ctx, user)
err = s.store.Users().Update(ctx, user.ID, UpdateUserRequest{
PasswordHash: hash,
})
if err != nil {
return Error.Wrap(err)
}

View File

@ -765,7 +765,7 @@ func TestLockAccount(t *testing.T) {
// lock account once again and check if lockout expiration time increased.
expDuration := time.Duration(math.Pow(consoleConfig.FailedLoginPenalty, float64(lockedUser.FailedLoginCount-1))) * time.Minute
lockoutExpDate := now.Add(expDuration)
err = service.UpdateUsersFailedLoginState(userCtx, lockedUser, lockoutExpDate)
err = service.UpdateUsersFailedLoginState(userCtx, lockedUser, &lockoutExpDate)
require.NoError(t, err)
lockedUser, err = service.GetUser(userCtx, user.ID)
@ -777,7 +777,10 @@ func TestLockAccount(t *testing.T) {
// unlock account by successful login
lockedUser.LoginLockoutExpiration = now.Add(-time.Second)
err = usersDB.Update(userCtx, lockedUser)
lockoutExpirationPtr := &lockedUser.LoginLockoutExpiration
err = usersDB.Update(userCtx, lockedUser.ID, console.UpdateUserRequest{
LoginLockoutExpiration: &lockoutExpirationPtr,
})
require.NoError(t, err)
authUser.Password = newUser.FullName
@ -808,8 +811,12 @@ func TestLockAccount(t *testing.T) {
// unlock account
lockedUser.LoginLockoutExpiration = time.Time{}
lockoutExpirationPtr = &lockedUser.LoginLockoutExpiration
lockedUser.FailedLoginCount = 0
err = usersDB.Update(userCtx, lockedUser)
err = usersDB.Update(userCtx, lockedUser.ID, console.UpdateUserRequest{
LoginLockoutExpiration: &lockoutExpirationPtr,
FailedLoginCount: &lockedUser.FailedLoginCount,
})
require.NoError(t, err)
// check if user's account gets locked because of providing wrong mfa recovery code.

View File

@ -33,7 +33,7 @@ type Users interface {
// Delete is a method for deleting user by Id from the database.
Delete(ctx context.Context, id uuid.UUID) error
// Update is a method for updating user entity.
Update(ctx context.Context, user *User) error
Update(ctx context.Context, userID uuid.UUID, request UpdateUserRequest) error
// UpdatePaidTier sets whether the user is in the paid tier.
UpdatePaidTier(ctx context.Context, id uuid.UUID, paidTier bool, projectBandwidthLimit, projectStorageLimit memory.Size, projectSegmentLimit int64, projectLimit int) error
// GetProjectLimit is a method to get the users project limit
@ -226,3 +226,32 @@ func GetUser(ctx context.Context) (*User, error) {
return nil, Error.New("user is not in context")
}
// UpdateUserRequest contains all columns which are optionally updatable by users.Update.
type UpdateUserRequest struct {
FullName *string
ShortName **string
Email *string
PasswordHash []byte
Status *UserStatus
ProjectLimit *int
ProjectStorageLimit *int64
ProjectBandwidthLimit *int64
ProjectSegmentLimit *int64
PaidTier *bool
MFAEnabled *bool
MFASecretKey **string
MFARecoveryCodes *[]string
LastVerificationReminder **time.Time
// failed_login_count is nullable, but we don't really have a reason
// to set it to NULL, so it doesn't need to be a double pointer here.
FailedLoginCount *int
LoginLockoutExpiration **time.Time
}

View File

@ -114,7 +114,9 @@ func TestUserEmailCase(t *testing.T) {
createdUser.Status = console.Active
err = db.Console().Users().Update(ctx, createdUser)
err = db.Console().Users().Update(ctx, createdUser.ID, console.UpdateUserRequest{
Status: &createdUser.Status,
})
assert.NoError(t, err)
retrievedUser, err := db.Console().Users().GetByEmail(ctx, testCase.email)
@ -178,7 +180,9 @@ func testUsers(ctx context.Context, t *testing.T, repository console.Users, user
insertedUser.Status = console.Active
err = repository.Update(ctx, insertedUser)
err = repository.Update(ctx, insertedUser.ID, console.UpdateUserRequest{
Status: &insertedUser.Status,
})
assert.NoError(t, err)
})
@ -266,7 +270,22 @@ func testUsers(ctx context.Context, t *testing.T, repository console.Users, user
LastVerificationReminder: date,
}
err = repository.Update(ctx, newUserInfo)
shortNamePtr := &newUserInfo.ShortName
secretKeyPtr := &newUserInfo.MFASecretKey
lastVerificationReminderPtr := &newUserInfo.LastVerificationReminder
err = repository.Update(ctx, newUserInfo.ID, console.UpdateUserRequest{
FullName: &newUserInfo.FullName,
ShortName: &shortNamePtr,
Email: &newUserInfo.Email,
Status: &newUserInfo.Status,
PaidTier: &newUserInfo.PaidTier,
MFAEnabled: &newUserInfo.MFAEnabled,
MFASecretKey: &secretKeyPtr,
MFARecoveryCodes: &newUserInfo.MFARecoveryCodes,
PasswordHash: newUserInfo.PasswordHash,
LastVerificationReminder: &lastVerificationReminderPtr,
})
assert.NoError(t, err)
newUser, err := repository.Get(ctx, oldUser.ID)
@ -333,7 +352,9 @@ func TestGetUserByEmail(t *testing.T) {
require.NoError(t, err)
// Required to set the active status.
err = usersRepo.Update(ctx, &activeUser)
err = usersRepo.Update(ctx, activeUser.ID, console.UpdateUserRequest{
Status: &activeUser.Status,
})
require.NoError(t, err)
dbUser, err := usersRepo.GetByEmail(ctx, email)

View File

@ -181,17 +181,17 @@ func (users *users) Delete(ctx context.Context, id uuid.UUID) (err error) {
}
// Update is a method for updating user entity.
func (users *users) Update(ctx context.Context, user *console.User) (err error) {
func (users *users) Update(ctx context.Context, userID uuid.UUID, updateRequest console.UpdateUserRequest) (err error) {
defer mon.Task()(&ctx)(&err)
updateFields, err := toUpdateUser(user)
updateFields, err := toUpdateUser(updateRequest)
if err != nil {
return err
}
_, err = users.db.Update_User_By_Id(
ctx,
dbx.User_Id(user.ID[:]),
dbx.User_Id(userID[:]),
*updateFields,
)
@ -251,34 +251,82 @@ func (users *users) GetUserPaidTier(ctx context.Context, id uuid.UUID) (isPaid b
}
// toUpdateUser creates dbx.User_Update_Fields with only non-empty fields as updatable.
func toUpdateUser(user *console.User) (*dbx.User_Update_Fields, error) {
update := dbx.User_Update_Fields{
FullName: dbx.User_FullName(user.FullName),
ShortName: dbx.User_ShortName(user.ShortName),
Email: dbx.User_Email(user.Email),
NormalizedEmail: dbx.User_NormalizedEmail(normalizeEmail(user.Email)),
Status: dbx.User_Status(int(user.Status)),
ProjectLimit: dbx.User_ProjectLimit(user.ProjectLimit),
ProjectStorageLimit: dbx.User_ProjectStorageLimit(user.ProjectStorageLimit),
ProjectBandwidthLimit: dbx.User_ProjectBandwidthLimit(user.ProjectBandwidthLimit),
ProjectSegmentLimit: dbx.User_ProjectSegmentLimit(user.ProjectSegmentLimit),
PaidTier: dbx.User_PaidTier(user.PaidTier),
MfaEnabled: dbx.User_MfaEnabled(user.MFAEnabled),
LastVerificationReminder: dbx.User_LastVerificationReminder(user.LastVerificationReminder),
FailedLoginCount: dbx.User_FailedLoginCount(user.FailedLoginCount),
LoginLockoutExpiration: dbx.User_LoginLockoutExpiration(user.LoginLockoutExpiration),
func toUpdateUser(request console.UpdateUserRequest) (*dbx.User_Update_Fields, error) {
update := dbx.User_Update_Fields{}
if request.FullName != nil {
update.FullName = dbx.User_FullName(*request.FullName)
}
recoveryBytes, err := json.Marshal(user.MFARecoveryCodes)
if request.ShortName != nil {
if *request.ShortName == nil {
update.ShortName = dbx.User_ShortName_Null()
} else {
update.ShortName = dbx.User_ShortName(**request.ShortName)
}
}
if request.Email != nil {
update.Email = dbx.User_Email(*request.Email)
update.NormalizedEmail = dbx.User_NormalizedEmail(normalizeEmail(*request.Email))
}
if request.PasswordHash != nil {
if len(request.PasswordHash) > 0 {
update.PasswordHash = dbx.User_PasswordHash(request.PasswordHash)
}
}
if request.Status != nil {
update.Status = dbx.User_Status(int(*request.Status))
}
if request.ProjectLimit != nil {
update.ProjectLimit = dbx.User_ProjectLimit(*request.ProjectLimit)
}
if request.ProjectStorageLimit != nil {
update.ProjectStorageLimit = dbx.User_ProjectStorageLimit(*request.ProjectStorageLimit)
}
if request.ProjectBandwidthLimit != nil {
update.ProjectBandwidthLimit = dbx.User_ProjectBandwidthLimit(*request.ProjectBandwidthLimit)
}
if request.ProjectSegmentLimit != nil {
update.ProjectSegmentLimit = dbx.User_ProjectSegmentLimit(*request.ProjectSegmentLimit)
}
if request.PaidTier != nil {
update.PaidTier = dbx.User_PaidTier(*request.PaidTier)
}
if request.MFAEnabled != nil {
update.MfaEnabled = dbx.User_MfaEnabled(*request.MFAEnabled)
}
if request.MFASecretKey != nil {
if *request.MFASecretKey == nil {
update.MfaSecretKey = dbx.User_MfaSecretKey_Null()
} else {
update.MfaSecretKey = dbx.User_MfaSecretKey(**request.MFASecretKey)
}
}
if request.MFARecoveryCodes != nil {
if *request.MFARecoveryCodes == nil {
update.MfaRecoveryCodes = dbx.User_MfaRecoveryCodes_Null()
} else {
recoveryBytes, err := json.Marshal(*request.MFARecoveryCodes)
if err != nil {
return nil, err
}
update.MfaRecoveryCodes = dbx.User_MfaRecoveryCodes(string(recoveryBytes))
update.MfaSecretKey = dbx.User_MfaSecretKey(user.MFASecretKey)
// extra password check to update only calculated hash from service
if len(user.PasswordHash) != 0 {
update.PasswordHash = dbx.User_PasswordHash(user.PasswordHash)
}
}
if request.LastVerificationReminder != nil {
if *request.LastVerificationReminder == nil {
update.LastVerificationReminder = dbx.User_LastVerificationReminder_Null()
} else {
update.LastVerificationReminder = dbx.User_LastVerificationReminder(**request.LastVerificationReminder)
}
}
if request.FailedLoginCount != nil {
update.FailedLoginCount = dbx.User_FailedLoginCount(*request.FailedLoginCount)
}
if request.LoginLockoutExpiration != nil {
if *request.LoginLockoutExpiration == nil {
update.LoginLockoutExpiration = dbx.User_LoginLockoutExpiration_Null()
} else {
update.LoginLockoutExpiration = dbx.User_LoginLockoutExpiration(**request.LoginLockoutExpiration)
}
}
return &update, nil

View File

@ -54,3 +54,250 @@ func TestGetUnverifiedNeedingReminderCutoff(t *testing.T) {
require.Len(t, needingReminder, 1)
})
}
func TestUpdateUser(t *testing.T) {
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
users := db.Console().Users()
id := testrand.UUID()
u, err := users.Insert(ctx, &console.User{
ID: id,
FullName: "testFullName",
ShortName: "testShortName",
Email: "test@storj.test",
PasswordHash: []byte("testPasswordHash"),
})
require.NoError(t, err)
newInfo := console.User{
FullName: "updatedFullName",
ShortName: "updatedShortName",
PasswordHash: []byte("updatedPasswordHash"),
ProjectLimit: 1,
ProjectBandwidthLimit: 1,
ProjectStorageLimit: 1,
ProjectSegmentLimit: 1,
PaidTier: true,
MFAEnabled: true,
MFASecretKey: "secretKey",
MFARecoveryCodes: []string{"code1", "code2"},
LastVerificationReminder: time.Now().Truncate(time.Second),
FailedLoginCount: 1,
LoginLockoutExpiration: time.Now().Truncate(time.Second),
}
require.NotEqual(t, u.FullName, newInfo.FullName)
require.NotEqual(t, u.ShortName, newInfo.ShortName)
require.NotEqual(t, u.PasswordHash, newInfo.PasswordHash)
require.NotEqual(t, u.ProjectLimit, newInfo.ProjectLimit)
require.NotEqual(t, u.ProjectBandwidthLimit, newInfo.ProjectBandwidthLimit)
require.NotEqual(t, u.ProjectStorageLimit, newInfo.ProjectStorageLimit)
require.NotEqual(t, u.ProjectSegmentLimit, newInfo.ProjectSegmentLimit)
require.NotEqual(t, u.PaidTier, newInfo.PaidTier)
require.NotEqual(t, u.MFAEnabled, newInfo.MFAEnabled)
require.NotEqual(t, u.MFASecretKey, newInfo.MFASecretKey)
require.NotEqual(t, u.MFARecoveryCodes, newInfo.MFARecoveryCodes)
require.NotEqual(t, u.LastVerificationReminder, newInfo.LastVerificationReminder)
require.NotEqual(t, u.FailedLoginCount, newInfo.FailedLoginCount)
require.NotEqual(t, u.LoginLockoutExpiration, newInfo.LoginLockoutExpiration)
// update just fullname
updateReq := console.UpdateUserRequest{
FullName: &newInfo.FullName,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err := users.Get(ctx, id)
require.NoError(t, err)
u.FullName = newInfo.FullName
require.Equal(t, u, updatedUser)
// update just shortname
shortNamePtr := &newInfo.ShortName
updateReq = console.UpdateUserRequest{
ShortName: &shortNamePtr,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.ShortName = newInfo.ShortName
require.Equal(t, u, updatedUser)
// update just password hash
updateReq = console.UpdateUserRequest{
PasswordHash: newInfo.PasswordHash,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.PasswordHash = newInfo.PasswordHash
require.Equal(t, u, updatedUser)
// update just project limit
updateReq = console.UpdateUserRequest{
ProjectLimit: &newInfo.ProjectLimit,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.ProjectLimit = newInfo.ProjectLimit
require.Equal(t, u, updatedUser)
// update just project bw limit
updateReq = console.UpdateUserRequest{
ProjectBandwidthLimit: &newInfo.ProjectBandwidthLimit,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.ProjectBandwidthLimit = newInfo.ProjectBandwidthLimit
require.Equal(t, u, updatedUser)
// update just project storage limit
updateReq = console.UpdateUserRequest{
ProjectStorageLimit: &newInfo.ProjectStorageLimit,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.ProjectStorageLimit = newInfo.ProjectStorageLimit
require.Equal(t, u, updatedUser)
// update just project segment limit
updateReq = console.UpdateUserRequest{
ProjectSegmentLimit: &newInfo.ProjectSegmentLimit,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.ProjectSegmentLimit = newInfo.ProjectSegmentLimit
require.Equal(t, u, updatedUser)
// update just paid tier
updateReq = console.UpdateUserRequest{
PaidTier: &newInfo.PaidTier,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.PaidTier = newInfo.PaidTier
require.Equal(t, u, updatedUser)
// update just mfa enabled
updateReq = console.UpdateUserRequest{
MFAEnabled: &newInfo.MFAEnabled,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.MFAEnabled = newInfo.MFAEnabled
require.Equal(t, u, updatedUser)
// update just mfa secret key
secretKeyPtr := &newInfo.MFASecretKey
updateReq = console.UpdateUserRequest{
MFASecretKey: &secretKeyPtr,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.MFASecretKey = newInfo.MFASecretKey
require.Equal(t, u, updatedUser)
// update just mfa recovery codes
updateReq = console.UpdateUserRequest{
MFARecoveryCodes: &newInfo.MFARecoveryCodes,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.MFARecoveryCodes = newInfo.MFARecoveryCodes
require.Equal(t, u, updatedUser)
// update just last verification reminder
lastReminderPtr := &newInfo.LastVerificationReminder
updateReq = console.UpdateUserRequest{
LastVerificationReminder: &lastReminderPtr,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.LastVerificationReminder = newInfo.LastVerificationReminder
require.Equal(t, u, updatedUser)
// update just failed login count
updateReq = console.UpdateUserRequest{
FailedLoginCount: &newInfo.FailedLoginCount,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.FailedLoginCount = newInfo.FailedLoginCount
require.Equal(t, u, updatedUser)
// update just login lockout expiration
loginLockoutExpPtr := &newInfo.LoginLockoutExpiration
updateReq = console.UpdateUserRequest{
LoginLockoutExpiration: &loginLockoutExpPtr,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.LoginLockoutExpiration = newInfo.LoginLockoutExpiration
require.Equal(t, u, updatedUser)
})
}