diff --git a/satellite/admin/user.go b/satellite/admin/user.go index 72c73cb85..ca87d56b2 100644 --- a/satellite/admin/user.go +++ b/satellite/admin/user.go @@ -834,12 +834,7 @@ func (server *Server) setGeofenceForUser(w http.ResponseWriter, r *http.Request, return } - err = server.db.Console().Users().Update(ctx, user.ID, console.UpdateUserRequest{ - Email: &user.Email, - DefaultPlacement: placement, - }) - - if err != nil { + if err = server.db.Console().Users().UpdateDefaultPlacement(ctx, user.ID, placement); err != nil { sendJSONError(w, "unable to set geofence for user", err.Error(), http.StatusInternalServerError) return diff --git a/satellite/console/users.go b/satellite/console/users.go index cae4b9a88..b7ada3d69 100644 --- a/satellite/console/users.go +++ b/satellite/console/users.go @@ -46,6 +46,8 @@ type Users interface { UpdateUserAgent(ctx context.Context, id uuid.UUID, userAgent []byte) error // UpdateUserProjectLimits is a method to update the user's usage limits for new projects. UpdateUserProjectLimits(ctx context.Context, id uuid.UUID, limits UsageLimits) error + // UpdateDefaultPlacement is a method to update the user's default placement for new projects. + UpdateDefaultPlacement(ctx context.Context, id uuid.UUID, placement storj.PlacementConstraint) error // GetProjectLimit is a method to get the users project limit GetProjectLimit(ctx context.Context, id uuid.UUID) (limit int, err error) // GetUserProjectLimits is a method to get the users storage and bandwidth limits for new projects. diff --git a/satellite/satellitedb/users.go b/satellite/satellitedb/users.go index d7bfecb6e..5f460ae04 100644 --- a/satellite/satellitedb/users.go +++ b/satellite/satellitedb/users.go @@ -180,6 +180,10 @@ func (users *users) Insert(ctx context.Context, user *console.User) (_ *console. optional.SignupCaptcha = dbx.User_SignupCaptcha(*user.SignupCaptcha) } + if user.DefaultPlacement > 0 { + optional.DefaultPlacement = dbx.User_DefaultPlacement(int(user.DefaultPlacement)) + } + createdUser, err := users.db.Create_User(ctx, dbx.User_Id(user.ID[:]), dbx.User_Email(user.Email), @@ -339,6 +343,21 @@ func (users *users) UpdateUserProjectLimits(ctx context.Context, id uuid.UUID, l return err } +// UpdateDefaultPlacement is a method to update the user's default placement for new projects. +func (users *users) UpdateDefaultPlacement(ctx context.Context, id uuid.UUID, placement storj.PlacementConstraint) (err error) { + defer mon.Task()(&ctx)(&err) + + _, err = users.db.Update_User_By_Id( + ctx, + dbx.User_Id(id[:]), + dbx.User_Update_Fields{ + DefaultPlacement: dbx.User_DefaultPlacement(int(placement)), + }, + ) + + return err +} + // GetProjectLimit is a method to get the users project limit. func (users *users) GetProjectLimit(ctx context.Context, id uuid.UUID) (limit int, err error) { defer mon.Task()(&ctx)(&err) @@ -525,7 +544,10 @@ func toUpdateUser(request console.UpdateUserRequest) (*dbx.User_Update_Fields, e update.LoginLockoutExpiration = dbx.User_LoginLockoutExpiration(**request.LoginLockoutExpiration) } } - update.DefaultPlacement = dbx.User_DefaultPlacement(int(request.DefaultPlacement)) + + if request.DefaultPlacement > 0 { + update.DefaultPlacement = dbx.User_DefaultPlacement(int(request.DefaultPlacement)) + } return &update, nil } diff --git a/satellite/satellitedb/users_test.go b/satellite/satellitedb/users_test.go index 0b5edbc28..cbca324ff 100644 --- a/satellite/satellitedb/users_test.go +++ b/satellite/satellitedb/users_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" + "storj.io/common/storj" "storj.io/common/testcontext" "storj.io/common/testrand" "storj.io/common/uuid" @@ -63,11 +64,12 @@ func TestUpdateUser(t *testing.T) { users := db.Console().Users() id := testrand.UUID() u, err := users.Insert(ctx, &console.User{ - ID: id, - FullName: "testFullName", - ShortName: "testShortName", - Email: "test@storj.test", - PasswordHash: []byte("testPasswordHash"), + ID: id, + FullName: "testFullName", + ShortName: "testShortName", + Email: "test@storj.test", + PasswordHash: []byte("testPasswordHash"), + DefaultPlacement: 12, }) require.NoError(t, err) @@ -85,6 +87,7 @@ func TestUpdateUser(t *testing.T) { MFARecoveryCodes: []string{"code1", "code2"}, FailedLoginCount: 1, LoginLockoutExpiration: time.Now().Truncate(time.Second), + DefaultPlacement: 13, } require.NotEqual(t, u.FullName, newInfo.FullName) @@ -100,6 +103,7 @@ func TestUpdateUser(t *testing.T) { require.NotEqual(t, u.MFARecoveryCodes, newInfo.MFARecoveryCodes) require.NotEqual(t, u.FailedLoginCount, newInfo.FailedLoginCount) require.NotEqual(t, u.LoginLockoutExpiration, newInfo.LoginLockoutExpiration) + require.NotEqual(t, u.DefaultPlacement, newInfo.DefaultPlacement) // update just fullname updateReq := console.UpdateUserRequest{ @@ -285,6 +289,21 @@ func TestUpdateUser(t *testing.T) { u.LoginLockoutExpiration = newInfo.LoginLockoutExpiration require.Equal(t, u, updatedUser) + + // update just the placement + defaultPlacement := &newInfo.DefaultPlacement + updateReq = console.UpdateUserRequest{ + DefaultPlacement: *defaultPlacement, + } + + err = users.Update(ctx, id, updateReq) + require.NoError(t, err) + + updatedUser, err = users.Get(ctx, id) + require.NoError(t, err) + + u.DefaultPlacement = newInfo.DefaultPlacement + require.Equal(t, u, updatedUser) }) } @@ -312,6 +331,34 @@ func TestUpdateUserProjectLimits(t *testing.T) { }) } +func TestUpdateDefaultPlacement(t *testing.T) { + satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { + usersRepo := db.Console().Users() + + user, err := usersRepo.Insert(ctx, &console.User{ + ID: testrand.UUID(), + FullName: "User", + Email: "test@mail.test", + PasswordHash: []byte("123a123"), + }) + require.NoError(t, err) + + err = usersRepo.UpdateDefaultPlacement(ctx, user.ID, 12) + require.NoError(t, err) + + user, err = usersRepo.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, storj.PlacementConstraint(12), user.DefaultPlacement) + + err = usersRepo.UpdateDefaultPlacement(ctx, user.ID, storj.EveryCountry) + require.NoError(t, err) + + user, err = usersRepo.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, storj.EveryCountry, user.DefaultPlacement) + }) +} + func TestUserSettings(t *testing.T) { satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { users := db.Console().Users()