diff --git a/satellite/console/consoleauth/sessions.go b/satellite/console/consoleauth/sessions.go index e3b736d32..0ba84fc38 100644 --- a/satellite/console/consoleauth/sessions.go +++ b/satellite/console/consoleauth/sessions.go @@ -23,7 +23,9 @@ type WebappSessions interface { // DeleteAllByUserID deletes all webapp sessions by user ID. DeleteAllByUserID(ctx context.Context, userID uuid.UUID) (int64, error) // UpdateExpiration updates the expiration time of the session. - UpdateExpiration(ctx context.Context, sessionID uuid.UUID, expiresAt time.Time) (err error) + UpdateExpiration(ctx context.Context, sessionID uuid.UUID, expiresAt time.Time) error + // DeleteExpired deletes all sessions that have expired before the provided timestamp. + DeleteExpired(ctx context.Context, now time.Time, asOfSystemTimeInterval time.Duration, pageSize int) error } // WebappSession represents a session on the satellite web app. diff --git a/satellite/console/dbcleanup/chore.go b/satellite/console/dbcleanup/chore.go index 70dcc6d48..01ec279f7 100644 --- a/satellite/console/dbcleanup/chore.go +++ b/satellite/console/dbcleanup/chore.go @@ -54,6 +54,12 @@ func (chore *Chore) Run(ctx context.Context) (err error) { if err != nil { chore.log.Error("Error deleting unverified users", zap.Error(err)) } + + err = chore.db.WebappSessions().DeleteExpired(ctx, time.Now(), chore.config.AsOfSystemTimeInterval, chore.config.PageSize) + if err != nil { + chore.log.Error("Error deleting expired webapp sessions", zap.Error(err)) + } + return nil }) } diff --git a/satellite/satellitedb/consoledb.go b/satellite/satellitedb/consoledb.go index c76b798c7..ef66782f9 100644 --- a/satellite/satellitedb/consoledb.go +++ b/satellite/satellitedb/consoledb.go @@ -78,7 +78,7 @@ func (db *ConsoleDB) ResetPasswordTokens() console.ResetPasswordTokens { // WebappSessions is a getter for WebappSessions repository. func (db *ConsoleDB) WebappSessions() consoleauth.WebappSessions { - return &webappSessions{db.methods} + return &webappSessions{db.db} } // AccountFreezeEvents is a getter for AccountFreezeEvents repository. diff --git a/satellite/satellitedb/webappsessions.go b/satellite/satellitedb/webappsessions.go index b5c61eace..5beabd4bd 100644 --- a/satellite/satellitedb/webappsessions.go +++ b/satellite/satellitedb/webappsessions.go @@ -5,6 +5,8 @@ package satellitedb import ( "context" + "database/sql" + "errors" "time" "storj.io/common/uuid" @@ -16,7 +18,7 @@ import ( var _ consoleauth.WebappSessions = (*webappSessions)(nil) type webappSessions struct { - db dbx.Methods + db *satelliteDB } // Create creates a webapp session and returns the session info. @@ -91,6 +93,75 @@ func (db *webappSessions) DeleteAllByUserID(ctx context.Context, userID uuid.UUI return db.db.Delete_WebappSession_By_UserId(ctx, dbx.WebappSession_UserId(userID.Bytes())) } +// DeleteExpired deletes all sessions that have expired before the provided timestamp. +func (db *webappSessions) DeleteExpired(ctx context.Context, now time.Time, asOfSystemTimeInterval time.Duration, pageSize int) (err error) { + defer mon.Task()(&ctx)(&err) + + if pageSize <= 0 { + return Error.New("expected page size to be positive; got %d", pageSize) + } + + var pageCursor, pageEnd uuid.UUID + aost := db.db.impl.AsOfSystemInterval(asOfSystemTimeInterval) + for { + // Select the ID beginning this page of records + err := db.db.QueryRowContext(ctx, ` + SELECT id FROM webapp_sessions + `+aost+` + WHERE id > $1 AND expires_at < $2 + ORDER BY id LIMIT 1 + `, pageCursor, now).Scan(&pageCursor) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil + } + return Error.Wrap(err) + } + + // Select the ID ending this page of records + err = db.db.QueryRowContext(ctx, ` + SELECT id FROM webapp_sessions + `+aost+` + WHERE id > $1 + ORDER BY id LIMIT 1 OFFSET $2 + `, pageCursor, pageSize).Scan(&pageEnd) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return Error.Wrap(err) + } + // Since this is the last page, we want to return all remaining records + _, err = db.db.ExecContext(ctx, ` + DELETE FROM webapp_sessions + WHERE id IN ( + SELECT id FROM webapp_sessions + `+aost+` + WHERE id >= $1 AND expires_at < $2 + ORDER BY id + ) + `, pageCursor, now) + return Error.Wrap(err) + } + + // Delete all expired records in the range between the beginning and ending IDs + _, err = db.db.ExecContext(ctx, ` + DELETE FROM webapp_sessions + WHERE id IN ( + SELECT id FROM webapp_sessions + `+aost+` + WHERE id BETWEEN $1 AND $2 + AND expires_at < $3 + ORDER BY id + ) + `, pageCursor, pageEnd, now) + if err != nil { + return Error.Wrap(err) + } + + // Advance the cursor to the next page + pageCursor = pageEnd + } +} + func getSessionFromDBX(dbxSession *dbx.WebappSession) (consoleauth.WebappSession, error) { id, err := uuid.FromBytes(dbxSession.Id) if err != nil { diff --git a/satellite/satellitedb/webappsessions_test.go b/satellite/satellitedb/webappsessions_test.go index c59e782d0..ae6b81e36 100644 --- a/satellite/satellitedb/webappsessions_test.go +++ b/satellite/satellitedb/webappsessions_test.go @@ -4,6 +4,7 @@ package satellitedb_test import ( + "database/sql" "testing" "time" @@ -186,3 +187,26 @@ func TestWebappSessionsDeleteAllByUserID(t *testing.T) { require.Len(t, allSessions, 0) }) } + +func TestDeleteExpired(t *testing.T) { + satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { + sessionsDB := db.Console().WebappSessions() + now := time.Now() + + // Only positive page sizes should be allowed. + require.Error(t, sessionsDB.DeleteExpired(ctx, time.Time{}, 0, 0)) + require.Error(t, sessionsDB.DeleteExpired(ctx, time.Time{}, 0, -1)) + + newSession, err := sessionsDB.Create(ctx, testrand.UUID(), testrand.UUID(), "", "", now.Add(time.Second)) + require.NoError(t, err) + oldSession, err := sessionsDB.Create(ctx, testrand.UUID(), testrand.UUID(), "", "", now.Add(-time.Second)) + require.NoError(t, err) + require.NoError(t, sessionsDB.DeleteExpired(ctx, now, 0, 1)) + + // Ensure that the old session record was deleted and the other remains. + _, err = sessionsDB.GetBySessionID(ctx, oldSession.ID) + require.ErrorIs(t, err, sql.ErrNoRows) + _, err = sessionsDB.GetBySessionID(ctx, newSession.ID) + require.NoError(t, err) + }) +}