diff --git a/satellite/console/consoleweb/consoleapi/auth.go b/satellite/console/consoleweb/consoleapi/auth.go index 40e08d462..122d76045 100644 --- a/satellite/console/consoleweb/consoleapi/auth.go +++ b/satellite/console/consoleweb/consoleapi/auth.go @@ -125,6 +125,22 @@ func (a *Auth) Token(w http.ResponseWriter, r *http.Request) { } } +// getSessionID gets the session ID from the request. +func (a *Auth) getSessionID(r *http.Request) (id uuid.UUID, err error) { + + tokenInfo, err := a.cookieAuth.GetToken(r) + if err != nil { + return uuid.UUID{}, err + } + + sessionID, err := uuid.FromBytes(tokenInfo.Token.Payload) + if err != nil { + return uuid.UUID{}, err + } + + return sessionID, nil +} + // Logout removes auth cookie. func (a *Auth) Logout(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -132,19 +148,13 @@ func (a *Auth) Logout(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - tokenInfo, err := a.cookieAuth.GetToken(r) + sessionID, err := a.getSessionID(r) if err != nil { a.serveJSONError(w, err) return } - id, err := uuid.FromBytes(tokenInfo.Token.Payload) - if err != nil { - a.serveJSONError(w, err) - return - } - - err = a.service.DeleteSession(ctx, id) + err = a.service.DeleteSession(ctx, sessionID) if err != nil { a.serveJSONError(w, err) return @@ -686,6 +696,24 @@ func (a *Auth) EnableUserMFA(w http.ResponseWriter, r *http.Request) { a.serveJSONError(w, err) return } + + sessionID, err := a.getSessionID(r) + if err != nil { + a.serveJSONError(w, err) + return + } + + consoleUser, err := console.GetUser(ctx) + if err != nil { + a.serveJSONError(w, err) + return + } + + err = a.service.DeleteAllSessionsByUserIDExcept(ctx, consoleUser.ID, sessionID) + if err != nil { + a.serveJSONError(w, err) + return + } } // DisableUserMFA disables multi-factor authentication for the user. @@ -709,6 +737,24 @@ func (a *Auth) DisableUserMFA(w http.ResponseWriter, r *http.Request) { a.serveJSONError(w, err) return } + + sessionID, err := a.getSessionID(r) + if err != nil { + a.serveJSONError(w, err) + return + } + + consoleUser, err := console.GetUser(ctx) + if err != nil { + a.serveJSONError(w, err) + return + } + + err = a.service.DeleteAllSessionsByUserIDExcept(ctx, consoleUser.ID, sessionID) + if err != nil { + a.serveJSONError(w, err) + return + } } // GenerateMFASecretKey creates a new TOTP secret key for the user. diff --git a/satellite/console/service.go b/satellite/console/service.go index 8acd40ab6..a9f5a3ca2 100644 --- a/satellite/console/service.go +++ b/satellite/console/service.go @@ -2931,6 +2931,27 @@ func (s *Service) DeleteSession(ctx context.Context, sessionID uuid.UUID) (err e return Error.Wrap(s.store.WebappSessions().DeleteBySessionID(ctx, sessionID)) } +// DeleteAllSessionsByUserIDExcept removes all sessions except the specified session from the database. +func (s *Service) DeleteAllSessionsByUserIDExcept(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID) (err error) { + defer mon.Task()(&ctx)(&err) + + sessions, err := s.store.WebappSessions().GetAllByUserID(ctx, userID) + if err != nil { + return Error.Wrap(err) + } + + for _, session := range sessions { + if session.ID != sessionID { + err = s.DeleteSession(ctx, session.ID) + if err != nil { + return err + } + } + } + + return nil +} + // RefreshSession resets the expiration time of the session. func (s *Service) RefreshSession(ctx context.Context, sessionID uuid.UUID) (expiresAt time.Time, err error) { defer mon.Task()(&ctx)(&err) diff --git a/satellite/console/service_test.go b/satellite/console/service_test.go index 167e7c381..4a2cccb1e 100644 --- a/satellite/console/service_test.go +++ b/satellite/console/service_test.go @@ -972,6 +972,54 @@ func TestSessionExpiration(t *testing.T) { }) } +func TestDeleteAllSessionsByUserIDExcept(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + sat := planet.Satellites[0] + service := sat.API.Console.Service + + user, err := sat.AddUser(ctx, console.CreateUser{ + FullName: "Test User", + Email: "test@mail.test", + }, 1) + require.NoError(t, err) + + // Session should be added to DB after token request + tokenInfo, err := service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName}) + require.NoError(t, err) + + _, err = service.TokenAuth(ctx, tokenInfo.Token, time.Now()) + require.NoError(t, err) + + sessionID, err := uuid.FromBytes(tokenInfo.Token.Payload) + require.NoError(t, err) + + _, err = sat.DB.Console().WebappSessions().GetBySessionID(ctx, sessionID) + require.NoError(t, err) + + // Session2 should be added to DB after token request + tokenInfo2, err := service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName}) + require.NoError(t, err) + + _, err = service.TokenAuth(ctx, tokenInfo2.Token, time.Now()) + require.NoError(t, err) + + sessionID2, err := uuid.FromBytes(tokenInfo2.Token.Payload) + require.NoError(t, err) + + _, err = sat.DB.Console().WebappSessions().GetBySessionID(ctx, sessionID2) + require.NoError(t, err) + + // Session2 should be removed from DB after calling DeleteAllSessionByUserIDExcept with Session1 + err = service.DeleteAllSessionsByUserIDExcept(ctx, user.ID, sessionID) + require.NoError(t, err) + + _, err = sat.DB.Console().WebappSessions().GetBySessionID(ctx, sessionID2) + require.ErrorIs(t, sql.ErrNoRows, err) + }) +} + func TestPaymentsWalletPayments(t *testing.T) { testplanet.Run(t, testplanet.Config{ SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,