From 22f8b029b93bcff1e89cb384554908261f4551f5 Mon Sep 17 00:00:00 2001 From: Jeremy Wharton Date: Fri, 23 Jun 2023 11:34:44 -0500 Subject: [PATCH] satellite/console: fix transaction error when inviting project members The SQL transaction that inserted project invitations relied on the error result of one of its statements in order to determine whether an invitation should be updated. This was inappropriate since any errors returned from a transaction statement should end the transaction immediately. This change resolves that issue. Change-Id: I354e430df293054d8583fb4faa5dc1bcf9053836 --- satellite/console/projectinvitations.go | 12 +---- satellite/console/service.go | 14 +----- satellite/console/service_test.go | 30 +++++++----- satellite/satellitedb/dbx/project.dbx | 2 +- satellite/satellitedb/dbx/satellitedb.dbx.go | 46 +++++++++---------- satellite/satellitedb/projectinvitations.go | 30 ++---------- .../satellitedb/projectinvitations_test.go | 43 ++++++++--------- satellite/satellitedb/projectmembers_test.go | 2 +- 8 files changed, 70 insertions(+), 109 deletions(-) diff --git a/satellite/console/projectinvitations.go b/satellite/console/projectinvitations.go index 7b5d70eba..12c0a01b2 100644 --- a/satellite/console/projectinvitations.go +++ b/satellite/console/projectinvitations.go @@ -14,16 +14,14 @@ import ( // // architecture: Database type ProjectInvitations interface { - // Insert inserts a project member invitation into the database. - Insert(ctx context.Context, invite *ProjectInvitation) (*ProjectInvitation, error) + // Upsert updates a project member invitation if it exists and inserts it otherwise. + Upsert(ctx context.Context, invite *ProjectInvitation) (*ProjectInvitation, error) // Get returns a project member invitation from the database. Get(ctx context.Context, projectID uuid.UUID, email string) (*ProjectInvitation, error) // GetByProjectID returns all of the project member invitations for the project specified by the given ID. GetByProjectID(ctx context.Context, projectID uuid.UUID) ([]ProjectInvitation, error) // GetByEmail returns all of the project member invitations for the specified email address. GetByEmail(ctx context.Context, email string) ([]ProjectInvitation, error) - // Update updates the project member invitation specified by the given project ID and email address. - Update(ctx context.Context, projectID uuid.UUID, email string, request UpdateProjectInvitationRequest) (*ProjectInvitation, error) // Delete removes a project member invitation from the database. Delete(ctx context.Context, projectID uuid.UUID, email string) error // DeleteBefore deletes project member invitations created prior to some time from the database. @@ -37,9 +35,3 @@ type ProjectInvitation struct { InviterID *uuid.UUID CreatedAt time.Time } - -// UpdateProjectInvitationRequest contains all fields which may be updated by ProjectInvitations.Update. -type UpdateProjectInvitationRequest struct { - CreatedAt *time.Time - InviterID *uuid.UUID -} diff --git a/satellite/console/service.go b/satellite/console/service.go index dc640d0d2..a0346e42a 100644 --- a/satellite/console/service.go +++ b/satellite/console/service.go @@ -3609,23 +3609,13 @@ func (s *Service) InviteProjectMembers(ctx context.Context, projectID uuid.UUID, // add project invites in transaction scope err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { for _, invited := range users { - invite, err := tx.ProjectInvitations().Insert(ctx, &ProjectInvitation{ + invite, err := tx.ProjectInvitations().Upsert(ctx, &ProjectInvitation{ ProjectID: projectID, Email: invited.Email, InviterID: &user.ID, }) if err != nil { - if !dbx.IsConstraintError(err) { - return err - } - now := time.Now() - invite, err = tx.ProjectInvitations().Update(ctx, projectID, invited.Email, UpdateProjectInvitationRequest{ - CreatedAt: &now, - InviterID: &user.ID, - }) - if err != nil { - return err - } + return err } token, err := s.CreateInviteToken(ctx, isMember.project.PublicID, invited.Email, invite.CreatedAt) if err != nil { diff --git a/satellite/console/service_test.go b/satellite/console/service_test.go index 852e7a761..b9c68ebcd 100644 --- a/satellite/console/service_test.go +++ b/satellite/console/service_test.go @@ -11,6 +11,7 @@ import ( "fmt" "math/rand" "sort" + "strings" "testing" "time" @@ -314,7 +315,7 @@ func TestService(t *testing.T) { require.NoError(t, err) for _, id := range []uuid.UUID{up1Proj.ID, up2Proj.ID} { - _, err = sat.DB.Console().ProjectInvitations().Insert(ctx, &console.ProjectInvitation{ + _, err = sat.DB.Console().ProjectInvitations().Upsert(ctx, &console.ProjectInvitation{ ProjectID: id, Email: invitedUser.Email, }) @@ -1975,7 +1976,7 @@ func TestProjectInvitations(t *testing.T) { } addInvite := func(t *testing.T, ctx context.Context, project *console.Project, email string) *console.ProjectInvitation { - invite, err := sat.DB.Console().ProjectInvitations().Insert(ctx, &console.ProjectInvitation{ + invite, err := sat.DB.Console().ProjectInvitations().Upsert(ctx, &console.ProjectInvitation{ ProjectID: project.ID, Email: email, InviterID: &project.OwnerID, @@ -1985,11 +1986,18 @@ func TestProjectInvitations(t *testing.T) { return invite } - expireInvite := func(t *testing.T, ctx context.Context, invite *console.ProjectInvitation) { - createdAt := time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration) - newInvite, err := sat.DB.Console().ProjectInvitations().Update(ctx, invite.ProjectID, invite.Email, console.UpdateProjectInvitationRequest{ - CreatedAt: &createdAt, - }) + setInviteDate := func(t *testing.T, ctx context.Context, invite *console.ProjectInvitation, createdAt time.Time) { + result, err := sat.DB.Testing().RawDB().ExecContext(ctx, + "UPDATE project_invitations SET created_at = $1 WHERE project_id = $2 AND email = $3", + createdAt, invite.ProjectID, strings.ToUpper(invite.Email), + ) + require.NoError(t, err) + + count, err := result.RowsAffected() + require.NoError(t, err) + require.EqualValues(t, 1, count) + + newInvite, err := sat.DB.Console().ProjectInvitations().Get(ctx, invite.ProjectID, invite.Email) require.NoError(t, err) *invite = *newInvite } @@ -2035,7 +2043,7 @@ func TestProjectInvitations(t *testing.T) { // expire the invitation. require.False(t, service.IsProjectInvitationExpired(&user3Invite)) oldCreatedAt := user3Invite.CreatedAt - expireInvite(t, ctx, &user3Invite) + setInviteDate(t, ctx, &user3Invite, time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration)) require.True(t, service.IsProjectInvitationExpired(&user3Invite)) // resending an expired invitation should succeed. @@ -2066,7 +2074,7 @@ func TestProjectInvitations(t *testing.T) { require.Equal(t, invite.InviterID, invites[0].InviterID) require.WithinDuration(t, invite.CreatedAt, invites[0].CreatedAt, time.Second) - expireInvite(t, ctx, &invites[0]) + setInviteDate(t, ctx, &invites[0], time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration)) invites, err = service.GetUserProjectInvitations(ctx) require.NoError(t, err) require.Empty(t, invites) @@ -2155,7 +2163,7 @@ func TestProjectInvitations(t *testing.T) { require.NotNil(t, inviteFromToken) require.Equal(t, invite, inviteFromToken) - expireInvite(t, ctx, invite) + setInviteDate(t, ctx, invite, time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration)) invites, err := service.GetUserProjectInvitations(ctx) require.NoError(t, err) require.Empty(t, invites) @@ -2178,7 +2186,7 @@ func TestProjectInvitations(t *testing.T) { proj := addProject(t, ctx) invite := addInvite(t, ctx, proj, user.Email) - expireInvite(t, ctx, invite) + setInviteDate(t, ctx, invite, time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration)) err := service.RespondToProjectInvitation(ctx, proj.ID, console.ProjectInvitationAccept) require.True(t, console.ErrProjectInviteInvalid.Has(err)) diff --git a/satellite/satellitedb/dbx/project.dbx b/satellite/satellitedb/dbx/project.dbx index 81d76cabb..e2fec3dff 100644 --- a/satellite/satellitedb/dbx/project.dbx +++ b/satellite/satellitedb/dbx/project.dbx @@ -169,7 +169,7 @@ model project_invitation ( field created_at timestamp ( autoinsert, updatable ) ) -create project_invitation ( ) +create project_invitation ( replace ) read one ( select project_invitation diff --git a/satellite/satellitedb/dbx/satellitedb.dbx.go b/satellite/satellitedb/dbx/satellitedb.dbx.go index 2348d2cd1..50dae3a3b 100644 --- a/satellite/satellitedb/dbx/satellitedb.dbx.go +++ b/satellite/satellitedb/dbx/satellitedb.dbx.go @@ -12869,7 +12869,7 @@ func (obj *pgxImpl) Create_ProjectMember(ctx context.Context, } -func (obj *pgxImpl) Create_ProjectInvitation(ctx context.Context, +func (obj *pgxImpl) Replace_ProjectInvitation(ctx context.Context, project_invitation_project_id ProjectInvitation_ProjectId_Field, project_invitation_email ProjectInvitation_Email_Field, optional ProjectInvitation_Create_Fields) ( @@ -12882,7 +12882,7 @@ func (obj *pgxImpl) Create_ProjectInvitation(ctx context.Context, __inviter_id_val := optional.InviterId.value() __created_at_val := __now - var __embed_stmt = __sqlbundle_Literal("INSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at") + var __embed_stmt = __sqlbundle_Literal("INSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) ON CONFLICT ( project_id, email ) DO UPDATE SET project_id = EXCLUDED.project_id, email = EXCLUDED.email, inviter_id = EXCLUDED.inviter_id, created_at = EXCLUDED.created_at RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at") var __values []interface{} __values = append(__values, __project_id_val, __email_val, __inviter_id_val, __created_at_val) @@ -20876,7 +20876,7 @@ func (obj *pgxcockroachImpl) Create_ProjectMember(ctx context.Context, } -func (obj *pgxcockroachImpl) Create_ProjectInvitation(ctx context.Context, +func (obj *pgxcockroachImpl) Replace_ProjectInvitation(ctx context.Context, project_invitation_project_id ProjectInvitation_ProjectId_Field, project_invitation_email ProjectInvitation_Email_Field, optional ProjectInvitation_Create_Fields) ( @@ -20889,7 +20889,7 @@ func (obj *pgxcockroachImpl) Create_ProjectInvitation(ctx context.Context, __inviter_id_val := optional.InviterId.value() __created_at_val := __now - var __embed_stmt = __sqlbundle_Literal("INSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at") + var __embed_stmt = __sqlbundle_Literal("UPSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at") var __values []interface{} __values = append(__values, __project_id_val, __email_val, __inviter_id_val, __created_at_val) @@ -28506,19 +28506,6 @@ func (rx *Rx) Create_Project(ctx context.Context, } -func (rx *Rx) Create_ProjectInvitation(ctx context.Context, - project_invitation_project_id ProjectInvitation_ProjectId_Field, - project_invitation_email ProjectInvitation_Email_Field, - optional ProjectInvitation_Create_Fields) ( - project_invitation *ProjectInvitation, err error) { - var tx *Tx - if tx, err = rx.getTx(ctx); err != nil { - return - } - return tx.Create_ProjectInvitation(ctx, project_invitation_project_id, project_invitation_email, optional) - -} - func (rx *Rx) Create_ProjectMember(ctx context.Context, project_member_member_id ProjectMember_MemberId_Field, project_member_project_id ProjectMember_ProjectId_Field) ( @@ -29707,6 +29694,19 @@ func (rx *Rx) Replace_AccountFreezeEvent(ctx context.Context, } +func (rx *Rx) Replace_ProjectInvitation(ctx context.Context, + project_invitation_project_id ProjectInvitation_ProjectId_Field, + project_invitation_email ProjectInvitation_Email_Field, + optional ProjectInvitation_Create_Fields) ( + project_invitation *ProjectInvitation, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Replace_ProjectInvitation(ctx, project_invitation_project_id, project_invitation_email, optional) + +} + func (rx *Rx) UpdateNoReturn_AccountingTimestamps_By_Name(ctx context.Context, accounting_timestamps_name AccountingTimestamps_Name_Field, update AccountingTimestamps_Update_Fields) ( @@ -30273,12 +30273,6 @@ type Methods interface { optional Project_Create_Fields) ( project *Project, err error) - Create_ProjectInvitation(ctx context.Context, - project_invitation_project_id ProjectInvitation_ProjectId_Field, - project_invitation_email ProjectInvitation_Email_Field, - optional ProjectInvitation_Create_Fields) ( - project_invitation *ProjectInvitation, err error) - Create_ProjectMember(ctx context.Context, project_member_member_id ProjectMember_MemberId_Field, project_member_project_id ProjectMember_ProjectId_Field) ( @@ -30808,6 +30802,12 @@ type Methods interface { optional AccountFreezeEvent_Create_Fields) ( account_freeze_event *AccountFreezeEvent, err error) + Replace_ProjectInvitation(ctx context.Context, + project_invitation_project_id ProjectInvitation_ProjectId_Field, + project_invitation_email ProjectInvitation_Email_Field, + optional ProjectInvitation_Create_Fields) ( + project_invitation *ProjectInvitation, err error) + UpdateNoReturn_AccountingTimestamps_By_Name(ctx context.Context, accounting_timestamps_name AccountingTimestamps_Name_Field, update AccountingTimestamps_Update_Fields) ( diff --git a/satellite/satellitedb/projectinvitations.go b/satellite/satellitedb/projectinvitations.go index ccb44c2b1..4bbd812be 100644 --- a/satellite/satellitedb/projectinvitations.go +++ b/satellite/satellitedb/projectinvitations.go @@ -22,8 +22,8 @@ type projectInvitations struct { db *satelliteDB } -// Insert inserts a project member invitation into the database. -func (invites *projectInvitations) Insert(ctx context.Context, invite *console.ProjectInvitation) (_ *console.ProjectInvitation, err error) { +// Upsert updates a project member invitation if it exists and inserts it otherwise. +func (invites *projectInvitations) Upsert(ctx context.Context, invite *console.ProjectInvitation) (_ *console.ProjectInvitation, err error) { defer mon.Task()(&ctx)(&err) if invite == nil { @@ -36,7 +36,7 @@ func (invites *projectInvitations) Insert(ctx context.Context, invite *console.P createFields.InviterId = dbx.ProjectInvitation_InviterId(id) } - dbxInvite, err := invites.db.Create_ProjectInvitation(ctx, + dbxInvite, err := invites.db.Replace_ProjectInvitation(ctx, dbx.ProjectInvitation_ProjectId(invite.ProjectID[:]), dbx.ProjectInvitation_Email(normalizeEmail(invite.Email)), createFields, @@ -87,30 +87,6 @@ func (invites *projectInvitations) GetByEmail(ctx context.Context, email string) return projectInvitationSliceFromDBX(dbxInvites) } -// Update updates the project member invitation specified by the given project ID and email address. -func (invites *projectInvitations) Update(ctx context.Context, projectID uuid.UUID, email string, request console.UpdateProjectInvitationRequest) (_ *console.ProjectInvitation, err error) { - defer mon.Task()(&ctx)(&err) - - update := dbx.ProjectInvitation_Update_Fields{} - if request.CreatedAt != nil { - update.CreatedAt = dbx.ProjectInvitation_CreatedAt(*request.CreatedAt) - } - if request.InviterID != nil { - update.InviterId = dbx.ProjectInvitation_InviterId((*request.InviterID)[:]) - } - - dbxInvite, err := invites.db.Update_ProjectInvitation_By_ProjectId_And_Email(ctx, - dbx.ProjectInvitation_ProjectId(projectID[:]), - dbx.ProjectInvitation_Email(normalizeEmail(email)), - update, - ) - if err != nil { - return nil, err - } - - return projectInvitationFromDBX(dbxInvite) -} - // Delete removes a project member invitation from the database. func (invites *projectInvitations) Delete(ctx context.Context, projectID uuid.UUID, email string) (err error) { defer mon.Task()(&ctx)(&err) diff --git a/satellite/satellitedb/projectinvitations_test.go b/satellite/satellitedb/projectinvitations_test.go index ae032e173..a98c8953f 100644 --- a/satellite/satellitedb/projectinvitations_test.go +++ b/satellite/satellitedb/projectinvitations_test.go @@ -50,7 +50,7 @@ func TestProjectInvitations(t *testing.T) { if !t.Run("insert invitations", func(t *testing.T) { // Expect failure because no user with inviterID exists. - _, err = invitesDB.Insert(ctx, invite) + _, err = invitesDB.Upsert(ctx, invite) require.Error(t, err) _, err = db.Console().Users().Insert(ctx, &console.User{ @@ -59,19 +59,15 @@ func TestProjectInvitations(t *testing.T) { }) require.NoError(t, err) - invite, err = invitesDB.Insert(ctx, invite) + invite, err = invitesDB.Upsert(ctx, invite) require.NoError(t, err) require.WithinDuration(t, time.Now(), invite.CreatedAt, time.Minute) require.Equal(t, projID, invite.ProjectID) require.Equal(t, strings.ToUpper(email), invite.Email) - // Duplicate invitations should be rejected. - _, err = invitesDB.Insert(ctx, invite) - require.Error(t, err) - - inviteSameEmail, err = invitesDB.Insert(ctx, inviteSameEmail) + inviteSameEmail, err = invitesDB.Upsert(ctx, inviteSameEmail) require.NoError(t, err) - inviteSameProject, err = invitesDB.Insert(ctx, inviteSameProject) + inviteSameProject, err = invitesDB.Upsert(ctx, inviteSameProject) require.NoError(t, err) }) { // None of the following subtests will pass if invitation insertion failed. @@ -126,22 +122,19 @@ func TestProjectInvitations(t *testing.T) { t.Run("update invitation", func(t *testing.T) { ctx := testcontext.New(t) - req := console.UpdateProjectInvitationRequest{} - newCreatedAt := invite.CreatedAt.Add(time.Hour) - req.CreatedAt = &newCreatedAt - newInvite, err := invitesDB.Update(ctx, projID, email, req) - require.NoError(t, err) - require.Equal(t, newCreatedAt, newInvite.CreatedAt) - inviter, err := db.Console().Users().Insert(ctx, &console.User{ ID: testrand.UUID(), PasswordHash: testrand.Bytes(8), }) require.NoError(t, err) - req.InviterID = &inviter.ID - newInvite, err = invitesDB.Update(ctx, projID, email, req) + invite.InviterID = &inviter.ID + + oldCreatedAt := invite.CreatedAt + + invite, err = invitesDB.Upsert(ctx, invite) require.NoError(t, err) - require.Equal(t, inviter.ID, *newInvite.InviterID) + require.Equal(t, inviter.ID, *invite.InviterID) + require.True(t, invite.CreatedAt.After(oldCreatedAt)) }) t.Run("delete invitation", func(t *testing.T) { @@ -187,20 +180,22 @@ func TestDeleteBefore(t *testing.T) { _, err := db.Console().Projects().Insert(ctx, &console.Project{ID: projID}) require.NoError(t, err) - invite, err := invitesDB.Insert(ctx, &console.ProjectInvitation{ProjectID: projID}) + invite, err := invitesDB.Upsert(ctx, &console.ProjectInvitation{ProjectID: projID}) require.NoError(t, err) - return invite } newInvite := createInvite() oldInvite := createInvite() - oldCreatedAt := expiration.Add(-time.Second) - oldInvite, err := invitesDB.Update(ctx, oldInvite.ProjectID, oldInvite.Email, console.UpdateProjectInvitationRequest{ - CreatedAt: &oldCreatedAt, - }) + result, err := db.Testing().RawDB().ExecContext(ctx, + "UPDATE project_invitations SET created_at = $1 WHERE project_id = $2", + expiration.Add(-time.Second), oldInvite.ProjectID, + ) require.NoError(t, err) + count, err := result.RowsAffected() + require.NoError(t, err) + require.EqualValues(t, 1, count) require.NoError(t, invitesDB.DeleteBefore(ctx, expiration, 0, 1)) diff --git a/satellite/satellitedb/projectmembers_test.go b/satellite/satellitedb/projectmembers_test.go index de86624ab..0f77f20e3 100644 --- a/satellite/satellitedb/projectmembers_test.go +++ b/satellite/satellitedb/projectmembers_test.go @@ -36,7 +36,7 @@ func TestGetPagedWithInvitationsByProjectID(t *testing.T) { _, err = db.Console().ProjectMembers().Insert(ctx, memberUser.ID, projectID) require.NoError(t, err) - _, err = db.Console().ProjectInvitations().Insert(ctx, &console.ProjectInvitation{ + _, err = db.Console().ProjectInvitations().Upsert(ctx, &console.ProjectInvitation{ ProjectID: projectID, Email: "bob@mail.test", })