satellite/satellitedb: fix DefaultPlacement overwritten on user

this fixes cases where it's possible to update a user and the
DefaultPlacement field gets overwritten to the zero value.

it also adds UpdateDefaultPlacement which can be used to set
DefaultPlacement directly. This is needed for the geofencing
endpoints in satellite admin to set the DefaultPlacement back
to zero to delete geofencing for a user.

Change-Id: If2c798dabfa6773ed6023fb8257bf00ec7bc2e68
This commit is contained in:
Sean Harvey 2023-10-04 14:31:39 +13:00
parent 6304046e80
commit bc7f621073
No known key found for this signature in database
GPG Key ID: D917C00695250311
4 changed files with 78 additions and 12 deletions

View File

@ -834,12 +834,7 @@ func (server *Server) setGeofenceForUser(w http.ResponseWriter, r *http.Request,
return
}
err = server.db.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{
Email: &user.Email,
DefaultPlacement: placement,
})
if err != nil {
if err = server.db.Console().Users().UpdateDefaultPlacement(ctx, user.ID, placement); err != nil {
sendJSONError(w, "unable to set geofence for user",
err.Error(), http.StatusInternalServerError)
return

View File

@ -46,6 +46,8 @@ type Users interface {
UpdateUserAgent(ctx context.Context, id uuid.UUID, userAgent []byte) error
// UpdateUserProjectLimits is a method to update the user's usage limits for new projects.
UpdateUserProjectLimits(ctx context.Context, id uuid.UUID, limits UsageLimits) error
// UpdateDefaultPlacement is a method to update the user's default placement for new projects.
UpdateDefaultPlacement(ctx context.Context, id uuid.UUID, placement storj.PlacementConstraint) error
// GetProjectLimit is a method to get the users project limit
GetProjectLimit(ctx context.Context, id uuid.UUID) (limit int, err error)
// GetUserProjectLimits is a method to get the users storage and bandwidth limits for new projects.

View File

@ -180,6 +180,10 @@ func (users *users) Insert(ctx context.Context, user *console.User) (_ *console.
optional.SignupCaptcha = dbx.User_SignupCaptcha(*user.SignupCaptcha)
}
if user.DefaultPlacement > 0 {
optional.DefaultPlacement = dbx.User_DefaultPlacement(int(user.DefaultPlacement))
}
createdUser, err := users.db.Create_User(ctx,
dbx.User_Id(user.ID[:]),
dbx.User_Email(user.Email),
@ -339,6 +343,21 @@ func (users *users) UpdateUserProjectLimits(ctx context.Context, id uuid.UUID, l
return err
}
// UpdateDefaultPlacement is a method to update the user's default placement for new projects.
func (users *users) UpdateDefaultPlacement(ctx context.Context, id uuid.UUID, placement storj.PlacementConstraint) (err error) {
defer mon.Task()(&ctx)(&err)
_, err = users.db.Update_User_By_Id(
ctx,
dbx.User_Id(id[:]),
dbx.User_Update_Fields{
DefaultPlacement: dbx.User_DefaultPlacement(int(placement)),
},
)
return err
}
// GetProjectLimit is a method to get the users project limit.
func (users *users) GetProjectLimit(ctx context.Context, id uuid.UUID) (limit int, err error) {
defer mon.Task()(&ctx)(&err)
@ -525,7 +544,10 @@ func toUpdateUser(request console.UpdateUserRequest) (*dbx.User_Update_Fields, e
update.LoginLockoutExpiration = dbx.User_LoginLockoutExpiration(**request.LoginLockoutExpiration)
}
}
update.DefaultPlacement = dbx.User_DefaultPlacement(int(request.DefaultPlacement))
if request.DefaultPlacement > 0 {
update.DefaultPlacement = dbx.User_DefaultPlacement(int(request.DefaultPlacement))
}
return &update, nil
}

View File

@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"storj.io/common/storj"
"storj.io/common/testcontext"
"storj.io/common/testrand"
"storj.io/common/uuid"
@ -63,11 +64,12 @@ func TestUpdateUser(t *testing.T) {
users := db.Console().Users()
id := testrand.UUID()
u, err := users.Insert(ctx, &console.User{
ID: id,
FullName: "testFullName",
ShortName: "testShortName",
Email: "test@storj.test",
PasswordHash: []byte("testPasswordHash"),
ID: id,
FullName: "testFullName",
ShortName: "testShortName",
Email: "test@storj.test",
PasswordHash: []byte("testPasswordHash"),
DefaultPlacement: 12,
})
require.NoError(t, err)
@ -85,6 +87,7 @@ func TestUpdateUser(t *testing.T) {
MFARecoveryCodes: []string{"code1", "code2"},
FailedLoginCount: 1,
LoginLockoutExpiration: time.Now().Truncate(time.Second),
DefaultPlacement: 13,
}
require.NotEqual(t, u.FullName, newInfo.FullName)
@ -100,6 +103,7 @@ func TestUpdateUser(t *testing.T) {
require.NotEqual(t, u.MFARecoveryCodes, newInfo.MFARecoveryCodes)
require.NotEqual(t, u.FailedLoginCount, newInfo.FailedLoginCount)
require.NotEqual(t, u.LoginLockoutExpiration, newInfo.LoginLockoutExpiration)
require.NotEqual(t, u.DefaultPlacement, newInfo.DefaultPlacement)
// update just fullname
updateReq := console.UpdateUserRequest{
@ -285,6 +289,21 @@ func TestUpdateUser(t *testing.T) {
u.LoginLockoutExpiration = newInfo.LoginLockoutExpiration
require.Equal(t, u, updatedUser)
// update just the placement
defaultPlacement := &newInfo.DefaultPlacement
updateReq = console.UpdateUserRequest{
DefaultPlacement: *defaultPlacement,
}
err = users.Update(ctx, id, updateReq)
require.NoError(t, err)
updatedUser, err = users.Get(ctx, id)
require.NoError(t, err)
u.DefaultPlacement = newInfo.DefaultPlacement
require.Equal(t, u, updatedUser)
})
}
@ -312,6 +331,34 @@ func TestUpdateUserProjectLimits(t *testing.T) {
})
}
func TestUpdateDefaultPlacement(t *testing.T) {
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
usersRepo := db.Console().Users()
user, err := usersRepo.Insert(ctx, &console.User{
ID: testrand.UUID(),
FullName: "User",
Email: "test@mail.test",
PasswordHash: []byte("123a123"),
})
require.NoError(t, err)
err = usersRepo.UpdateDefaultPlacement(ctx, user.ID, 12)
require.NoError(t, err)
user, err = usersRepo.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, storj.PlacementConstraint(12), user.DefaultPlacement)
err = usersRepo.UpdateDefaultPlacement(ctx, user.ID, storj.EveryCountry)
require.NoError(t, err)
user, err = usersRepo.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, storj.EveryCountry, user.DefaultPlacement)
})
}
func TestUserSettings(t *testing.T) {
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
users := db.Console().Users()