From 24370964abf701f5139f6cdfb063cebc9587ecb7 Mon Sep 17 00:00:00 2001 From: Wilfred Asomani Date: Mon, 13 Nov 2023 14:03:03 +0000 Subject: [PATCH] satellite/{console/payment}: wrap freeze code in transactions This change wraps account freeze code in DB transactions to prevent freeze inconsistencies resulting from errors that happen in the process of freezing accounts. Change-Id: Ib67fb30dc33248413d3057ceeac5c2f410f551d5 --- satellite/admin.go | 4 +- satellite/admin/user.go | 30 +- satellite/api.go | 4 +- satellite/console/accountfreezes.go | 857 +++++++++--------- satellite/console/accountfreezes_test.go | 18 +- satellite/core.go | 6 +- .../payments/accountfreeze/chore_test.go | 7 +- satellite/payments/billing/chore_test.go | 2 +- 8 files changed, 466 insertions(+), 462 deletions(-) diff --git a/satellite/admin.go b/satellite/admin.go index 7f6d90b25..53bc98eff 100644 --- a/satellite/admin.go +++ b/satellite/admin.go @@ -178,9 +178,7 @@ func NewAdmin(log *zap.Logger, full *identity.FullIdentity, db DB, metabaseDB *m } peer.FreezeAccounts.Service = console.NewAccountFreezeService( - db.Console().AccountFreezeEvents(), - db.Console().Users(), - db.Console().Projects(), + db.Console(), peer.Analytics.Service, config.Console.AccountFreeze, ) diff --git a/satellite/admin/user.go b/satellite/admin/user.go index 713e8b323..4b4e6319a 100644 --- a/satellite/admin/user.go +++ b/satellite/admin/user.go @@ -720,14 +720,8 @@ func (server *Server) billingUnfreezeUser(w http.ResponseWriter, r *http.Request err = server.freezeAccounts.BillingUnfreezeUser(ctx, u.ID) if err != nil { - if errors.Is(err, console.ErrFreezeUserStatusUpdate) { - sendJSONError(w, "User unfrozen but failed to change user status to active. "+ - "Run the command again, but if the error persists, intervene manually.", - err.Error(), http.StatusInternalServerError) - } else { - sendJSONError(w, "failed to violation unfreeze user", - err.Error(), http.StatusInternalServerError) - } + sendJSONError(w, "failed to billing unfreeze user", + err.Error(), http.StatusInternalServerError) return } } @@ -785,14 +779,8 @@ func (server *Server) violationFreezeUser(w http.ResponseWriter, r *http.Request err = server.freezeAccounts.ViolationFreezeUser(ctx, u.ID) if err != nil { - if errors.Is(err, console.ErrFreezeUserStatusUpdate) { - sendJSONError(w, "User frozen but failed to change user status to Pending Deletion. "+ - "Run the command again, but if the error persists, intervene manually.", - err.Error(), http.StatusInternalServerError) - } else { - sendJSONError(w, "failed to violation freeze user", - err.Error(), http.StatusInternalServerError) - } + sendJSONError(w, "failed to violation freeze user", + err.Error(), http.StatusInternalServerError) return } @@ -833,14 +821,8 @@ func (server *Server) violationUnfreezeUser(w http.ResponseWriter, r *http.Reque err = server.freezeAccounts.ViolationUnfreezeUser(ctx, u.ID) if err != nil { - if errors.Is(err, console.ErrFreezeUserStatusUpdate) { - sendJSONError(w, "User unfrozen but failed to change user status to active. "+ - "Run the command again, but if the error persists, intervene manually.", - err.Error(), http.StatusInternalServerError) - } else { - sendJSONError(w, "failed to violation unfreeze user", - err.Error(), http.StatusInternalServerError) - } + sendJSONError(w, "failed to violation unfreeze user", + err.Error(), http.StatusInternalServerError) return } } diff --git a/satellite/api.go b/satellite/api.go index 0716d2ce7..971c0eb0c 100644 --- a/satellite/api.go +++ b/satellite/api.go @@ -597,9 +597,7 @@ func NewAPI(log *zap.Logger, full *identity.FullIdentity, db DB, } accountFreezeService := console.NewAccountFreezeService( - db.Console().AccountFreezeEvents(), - db.Console().Users(), - db.Console().Projects(), + db.Console(), peer.Analytics.Service, consoleConfig.AccountFreeze, ) diff --git a/satellite/console/accountfreezes.go b/satellite/console/accountfreezes.go index 67d6c5f9c..ec5b8b96d 100644 --- a/satellite/console/accountfreezes.go +++ b/satellite/console/accountfreezes.go @@ -18,10 +18,6 @@ import ( // ErrAccountFreeze is the class for errors that occur during operation of the account freeze service. var ErrAccountFreeze = errs.Class("account freeze service") -// ErrFreezeUserStatusUpdate is error returned if updating the user status as part of violation (un)freeze -// fails. -var ErrFreezeUserStatusUpdate = errs.New("user status update failed") - // AccountFreezeEvents exposes methods to manage the account freeze events table in database. // // architecture: Database @@ -116,19 +112,17 @@ type AccountFreezeConfig struct { // AccountFreezeService encapsulates operations concerning account freezes. type AccountFreezeService struct { + store DB freezeEventsDB AccountFreezeEvents - usersDB Users - projectsDB Projects tracker analytics.FreezeTracker config AccountFreezeConfig } // NewAccountFreezeService creates a new account freeze service. -func NewAccountFreezeService(freezeEventsDB AccountFreezeEvents, usersDB Users, projectsDB Projects, tracker analytics.FreezeTracker, config AccountFreezeConfig) *AccountFreezeService { +func NewAccountFreezeService(db DB, tracker analytics.FreezeTracker, config AccountFreezeConfig) *AccountFreezeService { return &AccountFreezeService{ - freezeEventsDB: freezeEventsDB, - usersDB: usersDB, - projectsDB: projectsDB, + store: db, + freezeEventsDB: db.AccountFreezeEvents(), tracker: tracker, config: config, } @@ -164,499 +158,533 @@ func (s *AccountFreezeService) IsUserFrozen(ctx context.Context, userID uuid.UUI func (s *AccountFreezeService) BillingFreezeUser(ctx context.Context, userID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - user, err := s.usersDB.Get(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - - freezes, err := s.freezeEventsDB.GetAll(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - if freezes.ViolationFreeze != nil { - return ErrAccountFreeze.New("User is already frozen due to ToS violation") - } - if freezes.LegalFreeze != nil { - return ErrAccountFreeze.New("User is already frozen for legal review") - } - - userLimits := UsageLimits{ - Storage: user.ProjectStorageLimit, - Bandwidth: user.ProjectBandwidthLimit, - Segment: user.ProjectSegmentLimit, - } - - daysTillEscalation := int(s.config.BillingFreezeGracePeriod.Hours() / 24) - billingFreeze := freezes.BillingFreeze - if billingFreeze == nil { - billingFreeze = &AccountFreezeEvent{ - UserID: userID, - Type: BillingFreeze, - DaysTillEscalation: &daysTillEscalation, - Limits: &AccountFreezeEventLimits{ - User: userLimits, - Projects: make(map[uuid.UUID]UsageLimits), - }, - } - } - - // If user limits have been zeroed already, we should not override what is in the freeze table. - if userLimits != (UsageLimits{}) { - billingFreeze.Limits.User = userLimits - } - - projects, err := s.projectsDB.GetOwn(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - for _, p := range projects { - projLimits := UsageLimits{} - if p.StorageLimit != nil { - projLimits.Storage = p.StorageLimit.Int64() - } - if p.BandwidthLimit != nil { - projLimits.Bandwidth = p.BandwidthLimit.Int64() - } - if p.SegmentLimit != nil { - projLimits.Segment = *p.SegmentLimit - } - // If project limits have been zeroed already, we should not override what is in the freeze table. - if projLimits != (UsageLimits{}) { - billingFreeze.Limits.Projects[p.ID] = projLimits - } - } - - _, err = s.freezeEventsDB.Upsert(ctx, billingFreeze) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - - err = s.usersDB.UpdateUserProjectLimits(ctx, userID, UsageLimits{}) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - - for _, proj := range projects { - err := s.projectsDB.UpdateUsageLimits(ctx, proj.ID, UsageLimits{}) + err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { + user, err := tx.Users().Get(ctx, userID) if err != nil { return ErrAccountFreeze.Wrap(err) } - } - if freezes.BillingWarning != nil { - err = s.freezeEventsDB.DeleteByUserIDAndEvent(ctx, userID, BillingWarning) + freezes, err := tx.AccountFreezeEvents().GetAll(ctx, userID) if err != nil { return ErrAccountFreeze.Wrap(err) } - } + if freezes.ViolationFreeze != nil { + return ErrAccountFreeze.New("User is already frozen due to ToS violation") + } + if freezes.LegalFreeze != nil { + return ErrAccountFreeze.New("User is already frozen for legal review") + } - s.tracker.TrackAccountFrozen(userID, user.Email) - return nil + userLimits := UsageLimits{ + Storage: user.ProjectStorageLimit, + Bandwidth: user.ProjectBandwidthLimit, + Segment: user.ProjectSegmentLimit, + } + + daysTillEscalation := int(s.config.BillingFreezeGracePeriod.Hours() / 24) + billingFreeze := freezes.BillingFreeze + if billingFreeze == nil { + billingFreeze = &AccountFreezeEvent{ + UserID: userID, + Type: BillingFreeze, + DaysTillEscalation: &daysTillEscalation, + Limits: &AccountFreezeEventLimits{ + User: userLimits, + Projects: make(map[uuid.UUID]UsageLimits), + }, + } + } + + // If user limits have been zeroed already, we should not override what is in the freeze table. + if userLimits != (UsageLimits{}) { + billingFreeze.Limits.User = userLimits + } + + projects, err := tx.Projects().GetOwn(ctx, userID) + if err != nil { + return ErrAccountFreeze.Wrap(err) + } + for _, p := range projects { + projLimits := UsageLimits{} + if p.StorageLimit != nil { + projLimits.Storage = p.StorageLimit.Int64() + } + if p.BandwidthLimit != nil { + projLimits.Bandwidth = p.BandwidthLimit.Int64() + } + if p.SegmentLimit != nil { + projLimits.Segment = *p.SegmentLimit + } + // If project limits have been zeroed already, we should not override what is in the freeze table. + if projLimits != (UsageLimits{}) { + billingFreeze.Limits.Projects[p.ID] = projLimits + } + } + + _, err = tx.AccountFreezeEvents().Upsert(ctx, billingFreeze) + if err != nil { + return ErrAccountFreeze.Wrap(err) + } + + err = tx.Users().UpdateUserProjectLimits(ctx, userID, UsageLimits{}) + if err != nil { + return ErrAccountFreeze.Wrap(err) + } + + for _, proj := range projects { + err := tx.Projects().UpdateUsageLimits(ctx, proj.ID, UsageLimits{}) + if err != nil { + return ErrAccountFreeze.Wrap(err) + } + } + + if freezes.BillingWarning != nil { + err = tx.AccountFreezeEvents().DeleteByUserIDAndEvent(ctx, userID, BillingWarning) + if err != nil { + return ErrAccountFreeze.Wrap(err) + } + } + s.tracker.TrackAccountFrozen(userID, user.Email) + + return nil + }) + + return err } // BillingUnfreezeUser reverses the billing freeze placed on the user specified by the given ID. func (s *AccountFreezeService) BillingUnfreezeUser(ctx context.Context, userID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - user, err := s.usersDB.Get(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - - event, err := s.freezeEventsDB.Get(ctx, userID, BillingFreeze) - if errors.Is(err, sql.ErrNoRows) { - return ErrAccountFreeze.New("user is not frozen due to nonpayment of invoices") - } - - if event.Limits == nil { - return ErrAccountFreeze.New("freeze event limits are nil") - } - - for id, limits := range event.Limits.Projects { - err := s.projectsDB.UpdateUsageLimits(ctx, id, limits) + err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { + user, err := tx.Users().Get(ctx, userID) if err != nil { - return ErrAccountFreeze.Wrap(err) + return err } - } - err = s.usersDB.UpdateUserProjectLimits(ctx, userID, event.Limits.User) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } + event, err := tx.AccountFreezeEvents().Get(ctx, userID, BillingFreeze) + if errors.Is(err, sql.ErrNoRows) { + return errs.New("user is not frozen due to nonpayment of invoices") + } - err = ErrAccountFreeze.Wrap(s.freezeEventsDB.DeleteByUserIDAndEvent(ctx, userID, BillingFreeze)) - if err != nil { - return err - } + if event.Limits == nil { + return errs.New("freeze event limits are nil") + } - if user.Status == PendingDeletion { - status := Active - err = s.usersDB.Update(ctx, userID, UpdateUserRequest{ - Status: &status, - }) + for id, limits := range event.Limits.Projects { + err := tx.Projects().UpdateUsageLimits(ctx, id, limits) + if err != nil { + return err + } + } + + err = tx.Users().UpdateUserProjectLimits(ctx, userID, event.Limits.User) if err != nil { - return ErrAccountFreeze.Wrap(errs.Combine(ErrFreezeUserStatusUpdate, err)) + return err } - } - s.tracker.TrackAccountUnfrozen(userID, user.Email) - return nil + err = tx.AccountFreezeEvents().DeleteByUserIDAndEvent(ctx, userID, BillingFreeze) + if err != nil { + return err + } + + if user.Status == PendingDeletion { + status := Active + err = tx.Users().Update(ctx, userID, UpdateUserRequest{ + Status: &status, + }) + if err != nil { + return err + } + } + + s.tracker.TrackAccountUnfrozen(userID, user.Email) + + return nil + }) + + return ErrAccountFreeze.Wrap(err) } // BillingWarnUser adds a billing warning event to the freeze events table. func (s *AccountFreezeService) BillingWarnUser(ctx context.Context, userID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) + err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { + user, err := tx.Users().Get(ctx, userID) + if err != nil { + return err + } - user, err := s.usersDB.Get(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } + freezes, err := tx.AccountFreezeEvents().GetAll(ctx, userID) + if err != nil { + return ErrAccountFreeze.Wrap(err) + } - freezes, err := s.freezeEventsDB.GetAll(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } + if freezes.ViolationFreeze != nil || freezes.BillingFreeze != nil || freezes.LegalFreeze != nil { + return ErrAccountFreeze.New("User is already frozen") + } - if freezes.ViolationFreeze != nil || freezes.BillingFreeze != nil || freezes.LegalFreeze != nil { - return ErrAccountFreeze.New("User is already frozen") - } + if freezes.BillingWarning != nil { + return nil + } + + daysTillEscalation := int(s.config.BillingWarnGracePeriod.Hours() / 24) + _, err = tx.AccountFreezeEvents().Upsert(ctx, &AccountFreezeEvent{ + UserID: userID, + Type: BillingWarning, + DaysTillEscalation: &daysTillEscalation, + }) + if err != nil { + return ErrAccountFreeze.Wrap(err) + } + + s.tracker.TrackAccountFreezeWarning(userID, user.Email) - if freezes.BillingWarning != nil { return nil - } - - daysTillEscalation := int(s.config.BillingWarnGracePeriod.Hours() / 24) - _, err = s.freezeEventsDB.Upsert(ctx, &AccountFreezeEvent{ - UserID: userID, - Type: BillingWarning, - DaysTillEscalation: &daysTillEscalation, }) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - s.tracker.TrackAccountFreezeWarning(userID, user.Email) - return nil + return err } // BillingUnWarnUser reverses the warning placed on the user specified by the given ID. func (s *AccountFreezeService) BillingUnWarnUser(ctx context.Context, userID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - user, err := s.usersDB.Get(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } + err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { + user, err := tx.Users().Get(ctx, userID) + if err != nil { + return err + } - _, err = s.freezeEventsDB.Get(ctx, userID, BillingWarning) - if errors.Is(err, sql.ErrNoRows) { - return ErrAccountFreeze.New("user is not warned") - } + _, err = tx.AccountFreezeEvents().Get(ctx, userID, BillingWarning) + if errors.Is(err, sql.ErrNoRows) { + return ErrAccountFreeze.New("user is not warned") + } - err = ErrAccountFreeze.Wrap(s.freezeEventsDB.DeleteByUserIDAndEvent(ctx, userID, BillingWarning)) - if err != nil { - return err - } + err = ErrAccountFreeze.Wrap(tx.AccountFreezeEvents().DeleteByUserIDAndEvent(ctx, userID, BillingWarning)) + if err != nil { + return err + } - s.tracker.TrackAccountUnwarned(userID, user.Email) - return nil + s.tracker.TrackAccountUnwarned(userID, user.Email) + + return nil + }) + + return err } // ViolationFreezeUser freezes the user specified by the given ID due to ToS violation. func (s *AccountFreezeService) ViolationFreezeUser(ctx context.Context, userID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - user, err := s.usersDB.Get(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } + err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { + user, err := tx.Users().Get(ctx, userID) + if err != nil { + return err + } - freezes, err := s.freezeEventsDB.GetAll(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } + freezes, err := tx.AccountFreezeEvents().GetAll(ctx, userID) + if err != nil { + return err + } - if freezes.LegalFreeze != nil { - return ErrAccountFreeze.New("User is already frozen for legal review") - } + if freezes.LegalFreeze != nil { + return errs.New("User is already frozen for legal review") + } - var limits *AccountFreezeEventLimits - if freezes.BillingFreeze != nil { - limits = freezes.BillingFreeze.Limits - } + var limits *AccountFreezeEventLimits + if freezes.BillingFreeze != nil { + limits = freezes.BillingFreeze.Limits + } - userLimits := UsageLimits{ - Storage: user.ProjectStorageLimit, - Bandwidth: user.ProjectBandwidthLimit, - Segment: user.ProjectSegmentLimit, - } + userLimits := UsageLimits{ + Storage: user.ProjectStorageLimit, + Bandwidth: user.ProjectBandwidthLimit, + Segment: user.ProjectSegmentLimit, + } - violationFreeze := freezes.ViolationFreeze - if violationFreeze == nil { - if limits == nil { - limits = &AccountFreezeEventLimits{ - User: userLimits, - Projects: make(map[uuid.UUID]UsageLimits), + violationFreeze := freezes.ViolationFreeze + if violationFreeze == nil { + if limits == nil { + limits = &AccountFreezeEventLimits{ + User: userLimits, + Projects: make(map[uuid.UUID]UsageLimits), + } + } + violationFreeze = &AccountFreezeEvent{ + UserID: userID, + Type: ViolationFreeze, + Limits: limits, } } - violationFreeze = &AccountFreezeEvent{ - UserID: userID, - Type: ViolationFreeze, - Limits: limits, - } - } - // If user limits have been zeroed already, we should not override what is in the freeze table. - if userLimits != (UsageLimits{}) { - violationFreeze.Limits.User = userLimits - } - - projects, err := s.projectsDB.GetOwn(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - for _, p := range projects { - projLimits := UsageLimits{} - if p.StorageLimit != nil { - projLimits.Storage = p.StorageLimit.Int64() + // If user limits have been zeroed already, we should not override what is in the freeze table. + if userLimits != (UsageLimits{}) { + violationFreeze.Limits.User = userLimits } - if p.BandwidthLimit != nil { - projLimits.Bandwidth = p.BandwidthLimit.Int64() - } - if p.SegmentLimit != nil { - projLimits.Segment = *p.SegmentLimit - } - // If project limits have been zeroed already, we should not override what is in the freeze table. - if projLimits != (UsageLimits{}) { - violationFreeze.Limits.Projects[p.ID] = projLimits - } - } - _, err = s.freezeEventsDB.Upsert(ctx, violationFreeze) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - - err = s.usersDB.UpdateUserProjectLimits(ctx, userID, UsageLimits{}) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - - for _, proj := range projects { - err := s.projectsDB.UpdateUsageLimits(ctx, proj.ID, UsageLimits{}) + projects, err := tx.Projects().GetOwn(ctx, userID) if err != nil { - return ErrAccountFreeze.Wrap(err) + return err + } + for _, p := range projects { + projLimits := UsageLimits{} + if p.StorageLimit != nil { + projLimits.Storage = p.StorageLimit.Int64() + } + if p.BandwidthLimit != nil { + projLimits.Bandwidth = p.BandwidthLimit.Int64() + } + if p.SegmentLimit != nil { + projLimits.Segment = *p.SegmentLimit + } + // If project limits have been zeroed already, we should not override what is in the freeze table. + if projLimits != (UsageLimits{}) { + violationFreeze.Limits.Projects[p.ID] = projLimits + } } - } - status := PendingDeletion - err = s.usersDB.Update(ctx, userID, UpdateUserRequest{ - Status: &status, + _, err = tx.AccountFreezeEvents().Upsert(ctx, violationFreeze) + if err != nil { + return err + } + + err = tx.Users().UpdateUserProjectLimits(ctx, userID, UsageLimits{}) + if err != nil { + return err + } + + for _, proj := range projects { + err := tx.Projects().UpdateUsageLimits(ctx, proj.ID, UsageLimits{}) + if err != nil { + return err + } + } + + status := PendingDeletion + err = tx.Users().Update(ctx, userID, UpdateUserRequest{ + Status: &status, + }) + if err != nil { + return err + } + + if freezes.BillingWarning != nil { + err = tx.AccountFreezeEvents().DeleteByUserIDAndEvent(ctx, userID, BillingWarning) + if err != nil { + return err + } + } + + if freezes.BillingFreeze != nil { + err = tx.AccountFreezeEvents().DeleteByUserIDAndEvent(ctx, userID, BillingFreeze) + if err != nil { + return err + } + } + + return nil }) - if err != nil { - return ErrAccountFreeze.Wrap(errs.Combine(ErrFreezeUserStatusUpdate, err)) - } - if freezes.BillingWarning != nil { - err = s.freezeEventsDB.DeleteByUserIDAndEvent(ctx, userID, BillingWarning) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - } - - if freezes.BillingFreeze != nil { - err = s.freezeEventsDB.DeleteByUserIDAndEvent(ctx, userID, BillingFreeze) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - } - - return nil + return ErrAccountFreeze.Wrap(err) } // ViolationUnfreezeUser reverses the violation freeze placed on the user specified by the given ID. func (s *AccountFreezeService) ViolationUnfreezeUser(ctx context.Context, userID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - event, err := s.freezeEventsDB.Get(ctx, userID, ViolationFreeze) - if errors.Is(err, sql.ErrNoRows) { - return ErrAccountFreeze.New("user is not violation frozen") - } - - if event.Limits == nil { - return ErrAccountFreeze.New("freeze event limits are nil") - } - - for id, limits := range event.Limits.Projects { - err := s.projectsDB.UpdateUsageLimits(ctx, id, limits) - if err != nil { - return ErrAccountFreeze.Wrap(err) + err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { + event, err := tx.AccountFreezeEvents().Get(ctx, userID, ViolationFreeze) + if errors.Is(err, sql.ErrNoRows) { + return errs.New("user is not violation frozen") } - } - err = s.usersDB.UpdateUserProjectLimits(ctx, userID, event.Limits.User) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } + if event.Limits == nil { + return errs.New("freeze event limits are nil") + } - err = s.freezeEventsDB.DeleteByUserIDAndEvent(ctx, userID, ViolationFreeze) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } + for id, limits := range event.Limits.Projects { + err := tx.Projects().UpdateUsageLimits(ctx, id, limits) + if err != nil { + return err + } + } - status := Active - err = s.usersDB.Update(ctx, userID, UpdateUserRequest{ - Status: &status, + err = tx.Users().UpdateUserProjectLimits(ctx, userID, event.Limits.User) + if err != nil { + return err + } + + err = tx.AccountFreezeEvents().DeleteByUserIDAndEvent(ctx, userID, ViolationFreeze) + if err != nil { + return err + } + + status := Active + err = tx.Users().Update(ctx, userID, UpdateUserRequest{ + Status: &status, + }) + if err != nil { + return err + } + + return nil }) - if err != nil { - return ErrAccountFreeze.Wrap(errs.Combine(ErrFreezeUserStatusUpdate, err)) - } - return nil + return ErrAccountFreeze.Wrap(err) } // LegalFreezeUser freezes the user specified by the given ID for legal review. func (s *AccountFreezeService) LegalFreezeUser(ctx context.Context, userID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - user, err := s.usersDB.Get(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - - freezes, err := s.freezeEventsDB.GetAll(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - if freezes.ViolationFreeze != nil { - return ErrAccountFreeze.New("User is already frozen due to ToS violation") - } - - userLimits := UsageLimits{ - Storage: user.ProjectStorageLimit, - Bandwidth: user.ProjectBandwidthLimit, - Segment: user.ProjectSegmentLimit, - } - - legalFreeze := freezes.LegalFreeze - if legalFreeze == nil { - legalFreeze = &AccountFreezeEvent{ - UserID: userID, - Type: LegalFreeze, - Limits: &AccountFreezeEventLimits{ - User: userLimits, - Projects: make(map[uuid.UUID]UsageLimits), - }, - } - } - - // If user limits have been zeroed already, we should not override what is in the freeze table. - if userLimits != (UsageLimits{}) { - legalFreeze.Limits.User = userLimits - } - - projects, err := s.projectsDB.GetOwn(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - for _, p := range projects { - projLimits := UsageLimits{} - if p.StorageLimit != nil { - projLimits.Storage = p.StorageLimit.Int64() - } - if p.BandwidthLimit != nil { - projLimits.Bandwidth = p.BandwidthLimit.Int64() - } - if p.SegmentLimit != nil { - projLimits.Segment = *p.SegmentLimit - } - // If project limits have been zeroed already, we should not override what is in the freeze table. - if projLimits != (UsageLimits{}) { - legalFreeze.Limits.Projects[p.ID] = projLimits - } - } - - _, err = s.freezeEventsDB.Upsert(ctx, legalFreeze) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - - err = s.usersDB.UpdateUserProjectLimits(ctx, userID, UsageLimits{}) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - - for _, proj := range projects { - err := s.projectsDB.UpdateUsageLimits(ctx, proj.ID, UsageLimits{}) + err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { + user, err := tx.Users().Get(ctx, userID) if err != nil { - return ErrAccountFreeze.Wrap(err) + return err } - } - if freezes.BillingWarning != nil { - err = s.freezeEventsDB.DeleteByUserIDAndEvent(ctx, userID, BillingWarning) + freezes, err := tx.AccountFreezeEvents().GetAll(ctx, userID) if err != nil { - return ErrAccountFreeze.Wrap(err) + return err + } + if freezes.ViolationFreeze != nil { + return errs.New("User is already frozen due to ToS violation") } - } - status := LegalHold - err = s.usersDB.Update(ctx, userID, UpdateUserRequest{ - Status: &status, + userLimits := UsageLimits{ + Storage: user.ProjectStorageLimit, + Bandwidth: user.ProjectBandwidthLimit, + Segment: user.ProjectSegmentLimit, + } + + legalFreeze := freezes.LegalFreeze + if legalFreeze == nil { + legalFreeze = &AccountFreezeEvent{ + UserID: userID, + Type: LegalFreeze, + Limits: &AccountFreezeEventLimits{ + User: userLimits, + Projects: make(map[uuid.UUID]UsageLimits), + }, + } + } + + // If user limits have been zeroed already, we should not override what is in the freeze table. + if userLimits != (UsageLimits{}) { + legalFreeze.Limits.User = userLimits + } + + projects, err := tx.Projects().GetOwn(ctx, userID) + if err != nil { + return err + } + for _, p := range projects { + projLimits := UsageLimits{} + if p.StorageLimit != nil { + projLimits.Storage = p.StorageLimit.Int64() + } + if p.BandwidthLimit != nil { + projLimits.Bandwidth = p.BandwidthLimit.Int64() + } + if p.SegmentLimit != nil { + projLimits.Segment = *p.SegmentLimit + } + // If project limits have been zeroed already, we should not override what is in the freeze table. + if projLimits != (UsageLimits{}) { + legalFreeze.Limits.Projects[p.ID] = projLimits + } + } + + _, err = tx.AccountFreezeEvents().Upsert(ctx, legalFreeze) + if err != nil { + return err + } + + err = tx.Users().UpdateUserProjectLimits(ctx, userID, UsageLimits{}) + if err != nil { + return err + } + + for _, proj := range projects { + err := tx.Projects().UpdateUsageLimits(ctx, proj.ID, UsageLimits{}) + if err != nil { + return err + } + } + + if freezes.BillingWarning != nil { + err = tx.AccountFreezeEvents().DeleteByUserIDAndEvent(ctx, userID, BillingWarning) + if err != nil { + return err + } + } + + status := LegalHold + err = tx.Users().Update(ctx, userID, UpdateUserRequest{ + Status: &status, + }) + if err != nil { + return err + } + + return nil }) - if err != nil { - return ErrAccountFreeze.Wrap(errs.Combine(ErrFreezeUserStatusUpdate, err)) - } - return nil + return ErrAccountFreeze.Wrap(err) } // LegalUnfreezeUser reverses the legal freeze placed on the user specified by the given ID. func (s *AccountFreezeService) LegalUnfreezeUser(ctx context.Context, userID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - user, err := s.usersDB.Get(ctx, userID) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - - event, err := s.freezeEventsDB.Get(ctx, userID, LegalFreeze) - if errors.Is(err, sql.ErrNoRows) { - return ErrAccountFreeze.New("user is not legal-frozen") - } - - if event.Limits == nil { - return ErrAccountFreeze.New("freeze event limits are nil") - } - - for id, limits := range event.Limits.Projects { - err = s.projectsDB.UpdateUsageLimits(ctx, id, limits) + err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { + user, err := tx.Users().Get(ctx, userID) if err != nil { - return ErrAccountFreeze.Wrap(err) + return err } - } - err = s.usersDB.UpdateUserProjectLimits(ctx, userID, event.Limits.User) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } + event, err := tx.AccountFreezeEvents().Get(ctx, userID, LegalFreeze) + if errors.Is(err, sql.ErrNoRows) { + return errs.New("user is not legal-frozen") + } - err = ErrAccountFreeze.Wrap(s.freezeEventsDB.DeleteByUserIDAndEvent(ctx, userID, LegalFreeze)) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } + if event.Limits == nil { + return errs.New("freeze event limits are nil") + } - if user.Status == LegalHold { - status := Active - err = s.usersDB.Update(ctx, userID, UpdateUserRequest{ - Status: &status, - }) + for id, limits := range event.Limits.Projects { + err = tx.Projects().UpdateUsageLimits(ctx, id, limits) + if err != nil { + return err + } + } + + err = tx.Users().UpdateUserProjectLimits(ctx, userID, event.Limits.User) if err != nil { - return ErrAccountFreeze.Wrap(errs.Combine(ErrFreezeUserStatusUpdate, err)) + return err } - } - return nil + err = ErrAccountFreeze.Wrap(tx.AccountFreezeEvents().DeleteByUserIDAndEvent(ctx, userID, LegalFreeze)) + if err != nil { + return err + } + + if user.Status == LegalHold { + status := Active + err = tx.Users().Update(ctx, userID, UpdateUserRequest{ + Status: &status, + }) + if err != nil { + return err + } + } + + return nil + }) + + return ErrAccountFreeze.Wrap(err) } // GetAll returns all events for a user. @@ -684,22 +712,29 @@ func (s *AccountFreezeService) GetAllEvents(ctx context.Context, cursor FreezeEv } // EscalateBillingFreeze deactivates escalation for this freeze event and sets the user status to pending deletion. -func (s *AccountFreezeService) EscalateBillingFreeze(ctx context.Context, userID uuid.UUID, event AccountFreezeEvent) error { +func (s *AccountFreezeService) EscalateBillingFreeze(ctx context.Context, userID uuid.UUID, event AccountFreezeEvent) (err error) { + defer mon.Task()(&ctx)(&err) + event.DaysTillEscalation = nil - _, err := s.freezeEventsDB.Upsert(ctx, &event) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - status := PendingDeletion - err = s.usersDB.Update(ctx, userID, UpdateUserRequest{ - Status: &status, + err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { + _, err := tx.AccountFreezeEvents().Upsert(ctx, &event) + if err != nil { + return ErrAccountFreeze.Wrap(err) + } + + status := PendingDeletion + err = tx.Users().Update(ctx, userID, UpdateUserRequest{ + Status: &status, + }) + if err != nil { + return ErrAccountFreeze.Wrap(err) + } + + return nil }) - if err != nil { - return ErrAccountFreeze.Wrap(err) - } - return nil + return err } // TestChangeFreezeTracker changes the freeze tracker service for tests. diff --git a/satellite/console/accountfreezes_test.go b/satellite/console/accountfreezes_test.go index f31607474..87c5b1145 100644 --- a/satellite/console/accountfreezes_test.go +++ b/satellite/console/accountfreezes_test.go @@ -45,7 +45,7 @@ func TestAccountBillingFreeze(t *testing.T) { sat := planet.Satellites[0] usersDB := sat.DB.Console().Users() projectsDB := sat.DB.Console().Projects() - service := console.NewAccountFreezeService(sat.DB.Console().AccountFreezeEvents(), usersDB, projectsDB, sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) + service := console.NewAccountFreezeService(sat.DB.Console(), sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) billingFreezeGracePeriod := int(sat.Config.Console.AccountFreeze.BillingFreezeGracePeriod.Hours() / 24) @@ -116,7 +116,7 @@ func TestAccountBillingUnFreeze(t *testing.T) { sat := planet.Satellites[0] usersDB := sat.DB.Console().Users() projectsDB := sat.DB.Console().Projects() - service := console.NewAccountFreezeService(sat.DB.Console().AccountFreezeEvents(), usersDB, projectsDB, sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) + service := console.NewAccountFreezeService(sat.DB.Console(), sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) userLimits := randUsageLimits() user, err := sat.AddUser(ctx, console.CreateUser{ @@ -168,7 +168,7 @@ func TestAccountViolationFreeze(t *testing.T) { sat := planet.Satellites[0] usersDB := sat.DB.Console().Users() projectsDB := sat.DB.Console().Projects() - service := console.NewAccountFreezeService(sat.DB.Console().AccountFreezeEvents(), usersDB, projectsDB, sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) + service := console.NewAccountFreezeService(sat.DB.Console(), sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) userLimits := randUsageLimits() user, err := sat.AddUser(ctx, console.CreateUser{ @@ -250,7 +250,7 @@ func TestAccountLegalFreeze(t *testing.T) { sat := planet.Satellites[0] usersDB := sat.DB.Console().Users() projectsDB := sat.DB.Console().Projects() - service := console.NewAccountFreezeService(sat.DB.Console().AccountFreezeEvents(), usersDB, projectsDB, sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) + service := console.NewAccountFreezeService(sat.DB.Console(), sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) userLimits := randUsageLimits() user, err := sat.AddUser(ctx, console.CreateUser{ @@ -330,9 +330,7 @@ func TestRemoveAccountBillingWarning(t *testing.T) { SatelliteCount: 1, }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { sat := planet.Satellites[0] - usersDB := sat.DB.Console().Users() - projectsDB := sat.DB.Console().Projects() - service := console.NewAccountFreezeService(sat.DB.Console().AccountFreezeEvents(), usersDB, projectsDB, sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) + service := console.NewAccountFreezeService(sat.DB.Console(), sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) billingWarnGracePeriod := int(sat.Config.Console.AccountFreeze.BillingWarnGracePeriod.Hours() / 24) @@ -399,7 +397,7 @@ func TestAccountFreezeAlreadyFrozen(t *testing.T) { sat := planet.Satellites[0] usersDB := sat.DB.Console().Users() projectsDB := sat.DB.Console().Projects() - service := console.NewAccountFreezeService(sat.DB.Console().AccountFreezeEvents(), usersDB, projectsDB, sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) + service := console.NewAccountFreezeService(sat.DB.Console(), sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) userLimits := randUsageLimits() user, err := sat.AddUser(ctx, console.CreateUser{ @@ -496,10 +494,8 @@ func TestFreezeEffects(t *testing.T) { }, }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { sat := planet.Satellites[0] - usersDB := sat.DB.Console().Users() - projectsDB := sat.DB.Console().Projects() consoleService := sat.API.Console.Service - freezeService := console.NewAccountFreezeService(sat.DB.Console().AccountFreezeEvents(), usersDB, projectsDB, sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) + freezeService := console.NewAccountFreezeService(sat.DB.Console(), sat.API.Analytics.Service, sat.Config.Console.AccountFreeze) uplink1 := planet.Uplinks[0] user1, _, err := consoleService.GetUserByEmailWithUnverified(ctx, uplink1.User[sat.ID()].Email) diff --git a/satellite/core.go b/satellite/core.go index f45ecadf7..ceeec8732 100644 --- a/satellite/core.go +++ b/satellite/core.go @@ -530,9 +530,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, PayInvoices: console.NewInvoiceTokenPaymentObserver( peer.DB.Console(), peer.Payments.Accounts.Invoices(), console.NewAccountFreezeService( - peer.DB.Console().AccountFreezeEvents(), - peer.DB.Console().Users(), - peer.DB.Console().Projects(), + peer.DB.Console(), peer.Analytics.Service, config.Console.AccountFreeze, ), @@ -564,7 +562,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, peer.DB.Console().Users(), peer.DB.Wallets(), peer.DB.StorjscanPayments(), - console.NewAccountFreezeService(db.Console().AccountFreezeEvents(), db.Console().Users(), db.Console().Projects(), peer.Analytics.Service, config.Console.AccountFreeze), + console.NewAccountFreezeService(db.Console(), peer.Analytics.Service, config.Console.AccountFreeze), peer.Analytics.Service, config.AccountFreeze, ) diff --git a/satellite/payments/accountfreeze/chore_test.go b/satellite/payments/accountfreeze/chore_test.go index 9b45ccf21..5e85632fd 100644 --- a/satellite/payments/accountfreeze/chore_test.go +++ b/satellite/payments/accountfreeze/chore_test.go @@ -39,8 +39,7 @@ func TestAutoFreezeChore(t *testing.T) { invoicesDB := sat.Core.Payments.Accounts.Invoices() customerDB := sat.Core.DB.StripeCoinPayments().Customers() usersDB := sat.DB.Console().Users() - projectsDB := sat.DB.Console().Projects() - service := console.NewAccountFreezeService(sat.DB.Console().AccountFreezeEvents(), usersDB, projectsDB, newFreezeTrackerMock(t), sat.Config.Console.AccountFreeze) + service := console.NewAccountFreezeService(sat.DB.Console(), newFreezeTrackerMock(t), sat.Config.Console.AccountFreeze) chore := sat.Core.Payments.AccountFreeze chore.Loop.Pause() @@ -544,9 +543,7 @@ func TestAutoFreezeChore_StorjscanExclusion(t *testing.T) { stripeClient := sat.API.Payments.StripeClient invoicesDB := sat.Core.Payments.Accounts.Invoices() customerDB := sat.Core.DB.StripeCoinPayments().Customers() - usersDB := sat.DB.Console().Users() - projectsDB := sat.DB.Console().Projects() - service := console.NewAccountFreezeService(sat.DB.Console().AccountFreezeEvents(), usersDB, projectsDB, newFreezeTrackerMock(t), sat.Config.Console.AccountFreeze) + service := console.NewAccountFreezeService(sat.DB.Console(), newFreezeTrackerMock(t), sat.Config.Console.AccountFreeze) chore := sat.Core.Payments.AccountFreeze chore.Loop.Pause() diff --git a/satellite/payments/billing/chore_test.go b/satellite/payments/billing/chore_test.go index 2cc39a64e..c64426132 100644 --- a/satellite/payments/billing/chore_test.go +++ b/satellite/payments/billing/chore_test.go @@ -272,7 +272,7 @@ func TestChore_PayInvoiceObserver(t *testing.T) { err = sat.DB.Wallets().Add(ctx, userID, address) require.NoError(t, err) - freezeService := console.NewAccountFreezeService(consoleDB.AccountFreezeEvents(), consoleDB.Users(), consoleDB.Projects(), sat.Core.Analytics.Service, sat.Config.Console.AccountFreeze) + freezeService := console.NewAccountFreezeService(consoleDB, sat.Core.Analytics.Service, sat.Config.Console.AccountFreeze) choreObservers := billing.ChoreObservers{ UpgradeUser: console.NewUpgradeUserObserver(consoleDB, db.Billing(), sat.Config.Console.UsageLimits, sat.Config.Console.UserBalanceForUpgrade),