From bc7f6210737d5f098fd25f544323580fd8396ef0 Mon Sep 17 00:00:00 2001 From: Sean Harvey Date: Wed, 4 Oct 2023 14:31:39 +1300 Subject: [PATCH] satellite/satellitedb: fix DefaultPlacement overwritten on user this fixes cases where it's possible to update a user and the DefaultPlacement field gets overwritten to the zero value. it also adds UpdateDefaultPlacement which can be used to set DefaultPlacement directly. This is needed for the geofencing endpoints in satellite admin to set the DefaultPlacement back to zero to delete geofencing for a user. Change-Id: If2c798dabfa6773ed6023fb8257bf00ec7bc2e68 --- satellite/admin/user.go | 7 +--- satellite/console/users.go | 2 + satellite/satellitedb/users.go | 24 +++++++++++- satellite/satellitedb/users_test.go | 57 ++++++++++++++++++++++++++--- 4 files changed, 78 insertions(+), 12 deletions(-) diff --git a/satellite/admin/user.go b/satellite/admin/user.go index 72c73cb85..ca87d56b2 100644 --- a/satellite/admin/user.go +++ b/satellite/admin/user.go @@ -834,12 +834,7 @@ func (server *Server) setGeofenceForUser(w http.ResponseWriter, r *http.Request, return } - err = server.db.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{ - Email: &user.Email, - DefaultPlacement: placement, - }) - - if err != nil { + if err = server.db.Console().Users().UpdateDefaultPlacement(ctx, user.ID, placement); err != nil { sendJSONError(w, "unable to set geofence for user", err.Error(), http.StatusInternalServerError) return diff --git a/satellite/console/users.go b/satellite/console/users.go index cae4b9a88..b7ada3d69 100644 --- a/satellite/console/users.go +++ b/satellite/console/users.go @@ -46,6 +46,8 @@ type Users interface { UpdateUserAgent(ctx context.Context, id uuid.UUID, userAgent []byte) error // UpdateUserProjectLimits is a method to update the user's usage limits for new projects. UpdateUserProjectLimits(ctx context.Context, id uuid.UUID, limits UsageLimits) error + // UpdateDefaultPlacement is a method to update the user's default placement for new projects. + UpdateDefaultPlacement(ctx context.Context, id uuid.UUID, placement storj.PlacementConstraint) error // GetProjectLimit is a method to get the users project limit GetProjectLimit(ctx context.Context, id uuid.UUID) (limit int, err error) // GetUserProjectLimits is a method to get the users storage and bandwidth limits for new projects. diff --git a/satellite/satellitedb/users.go b/satellite/satellitedb/users.go index d7bfecb6e..5f460ae04 100644 --- a/satellite/satellitedb/users.go +++ b/satellite/satellitedb/users.go @@ -180,6 +180,10 @@ func (users *users) Insert(ctx context.Context, user *console.User) (_ *console. optional.SignupCaptcha = dbx.User_SignupCaptcha(*user.SignupCaptcha) } + if user.DefaultPlacement > 0 { + optional.DefaultPlacement = dbx.User_DefaultPlacement(int(user.DefaultPlacement)) + } + createdUser, err := users.db.Create_User(ctx, dbx.User_Id(user.ID[:]), dbx.User_Email(user.Email), @@ -339,6 +343,21 @@ func (users *users) UpdateUserProjectLimits(ctx context.Context, id uuid.UUID, l return err } +// UpdateDefaultPlacement is a method to update the user's default placement for new projects. +func (users *users) UpdateDefaultPlacement(ctx context.Context, id uuid.UUID, placement storj.PlacementConstraint) (err error) { + defer mon.Task()(&ctx)(&err) + + _, err = users.db.Update_User_By_Id( + ctx, + dbx.User_Id(id[:]), + dbx.User_Update_Fields{ + DefaultPlacement: dbx.User_DefaultPlacement(int(placement)), + }, + ) + + return err +} + // GetProjectLimit is a method to get the users project limit. func (users *users) GetProjectLimit(ctx context.Context, id uuid.UUID) (limit int, err error) { defer mon.Task()(&ctx)(&err) @@ -525,7 +544,10 @@ func toUpdateUser(request console.UpdateUserRequest) (*dbx.User_Update_Fields, e update.LoginLockoutExpiration = dbx.User_LoginLockoutExpiration(**request.LoginLockoutExpiration) } } - update.DefaultPlacement = dbx.User_DefaultPlacement(int(request.DefaultPlacement)) + + if request.DefaultPlacement > 0 { + update.DefaultPlacement = dbx.User_DefaultPlacement(int(request.DefaultPlacement)) + } return &update, nil } diff --git a/satellite/satellitedb/users_test.go b/satellite/satellitedb/users_test.go index 0b5edbc28..cbca324ff 100644 --- a/satellite/satellitedb/users_test.go +++ b/satellite/satellitedb/users_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" + "storj.io/common/storj" "storj.io/common/testcontext" "storj.io/common/testrand" "storj.io/common/uuid" @@ -63,11 +64,12 @@ func TestUpdateUser(t *testing.T) { users := db.Console().Users() id := testrand.UUID() u, err := users.Insert(ctx, &console.User{ - ID: id, - FullName: "testFullName", - ShortName: "testShortName", - Email: "test@storj.test", - PasswordHash: []byte("testPasswordHash"), + ID: id, + FullName: "testFullName", + ShortName: "testShortName", + Email: "test@storj.test", + PasswordHash: []byte("testPasswordHash"), + DefaultPlacement: 12, }) require.NoError(t, err) @@ -85,6 +87,7 @@ func TestUpdateUser(t *testing.T) { MFARecoveryCodes: []string{"code1", "code2"}, FailedLoginCount: 1, LoginLockoutExpiration: time.Now().Truncate(time.Second), + DefaultPlacement: 13, } require.NotEqual(t, u.FullName, newInfo.FullName) @@ -100,6 +103,7 @@ func TestUpdateUser(t *testing.T) { require.NotEqual(t, u.MFARecoveryCodes, newInfo.MFARecoveryCodes) require.NotEqual(t, u.FailedLoginCount, newInfo.FailedLoginCount) require.NotEqual(t, u.LoginLockoutExpiration, newInfo.LoginLockoutExpiration) + require.NotEqual(t, u.DefaultPlacement, newInfo.DefaultPlacement) // update just fullname updateReq := console.UpdateUserRequest{ @@ -285,6 +289,21 @@ func TestUpdateUser(t *testing.T) { u.LoginLockoutExpiration = newInfo.LoginLockoutExpiration require.Equal(t, u, updatedUser) + + // update just the placement + defaultPlacement := &newInfo.DefaultPlacement + updateReq = console.UpdateUserRequest{ + DefaultPlacement: *defaultPlacement, + } + + err = users.Update(ctx, id, updateReq) + require.NoError(t, err) + + updatedUser, err = users.Get(ctx, id) + require.NoError(t, err) + + u.DefaultPlacement = newInfo.DefaultPlacement + require.Equal(t, u, updatedUser) }) } @@ -312,6 +331,34 @@ func TestUpdateUserProjectLimits(t *testing.T) { }) } +func TestUpdateDefaultPlacement(t *testing.T) { + satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { + usersRepo := db.Console().Users() + + user, err := usersRepo.Insert(ctx, &console.User{ + ID: testrand.UUID(), + FullName: "User", + Email: "test@mail.test", + PasswordHash: []byte("123a123"), + }) + require.NoError(t, err) + + err = usersRepo.UpdateDefaultPlacement(ctx, user.ID, 12) + require.NoError(t, err) + + user, err = usersRepo.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, storj.PlacementConstraint(12), user.DefaultPlacement) + + err = usersRepo.UpdateDefaultPlacement(ctx, user.ID, storj.EveryCountry) + require.NoError(t, err) + + user, err = usersRepo.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, storj.EveryCountry, user.DefaultPlacement) + }) +} + func TestUserSettings(t *testing.T) { satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { users := db.Console().Users()