diff --git a/satellite/console/projectinvitations.go b/satellite/console/projectinvitations.go index 7b5d70eba..12c0a01b2 100644 --- a/satellite/console/projectinvitations.go +++ b/satellite/console/projectinvitations.go @@ -14,16 +14,14 @@ import ( // // architecture: Database type ProjectInvitations interface { - // Insert inserts a project member invitation into the database. - Insert(ctx context.Context, invite *ProjectInvitation) (*ProjectInvitation, error) + // Upsert updates a project member invitation if it exists and inserts it otherwise. + Upsert(ctx context.Context, invite *ProjectInvitation) (*ProjectInvitation, error) // Get returns a project member invitation from the database. Get(ctx context.Context, projectID uuid.UUID, email string) (*ProjectInvitation, error) // GetByProjectID returns all of the project member invitations for the project specified by the given ID. GetByProjectID(ctx context.Context, projectID uuid.UUID) ([]ProjectInvitation, error) // GetByEmail returns all of the project member invitations for the specified email address. GetByEmail(ctx context.Context, email string) ([]ProjectInvitation, error) - // Update updates the project member invitation specified by the given project ID and email address. - Update(ctx context.Context, projectID uuid.UUID, email string, request UpdateProjectInvitationRequest) (*ProjectInvitation, error) // Delete removes a project member invitation from the database. Delete(ctx context.Context, projectID uuid.UUID, email string) error // DeleteBefore deletes project member invitations created prior to some time from the database. @@ -37,9 +35,3 @@ type ProjectInvitation struct { InviterID *uuid.UUID CreatedAt time.Time } - -// UpdateProjectInvitationRequest contains all fields which may be updated by ProjectInvitations.Update. -type UpdateProjectInvitationRequest struct { - CreatedAt *time.Time - InviterID *uuid.UUID -} diff --git a/satellite/console/service.go b/satellite/console/service.go index dc640d0d2..a0346e42a 100644 --- a/satellite/console/service.go +++ b/satellite/console/service.go @@ -3609,23 +3609,13 @@ func (s *Service) InviteProjectMembers(ctx context.Context, projectID uuid.UUID, // add project invites in transaction scope err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error { for _, invited := range users { - invite, err := tx.ProjectInvitations().Insert(ctx, &ProjectInvitation{ + invite, err := tx.ProjectInvitations().Upsert(ctx, &ProjectInvitation{ ProjectID: projectID, Email: invited.Email, InviterID: &user.ID, }) if err != nil { - if !dbx.IsConstraintError(err) { - return err - } - now := time.Now() - invite, err = tx.ProjectInvitations().Update(ctx, projectID, invited.Email, UpdateProjectInvitationRequest{ - CreatedAt: &now, - InviterID: &user.ID, - }) - if err != nil { - return err - } + return err } token, err := s.CreateInviteToken(ctx, isMember.project.PublicID, invited.Email, invite.CreatedAt) if err != nil { diff --git a/satellite/console/service_test.go b/satellite/console/service_test.go index 852e7a761..b9c68ebcd 100644 --- a/satellite/console/service_test.go +++ b/satellite/console/service_test.go @@ -11,6 +11,7 @@ import ( "fmt" "math/rand" "sort" + "strings" "testing" "time" @@ -314,7 +315,7 @@ func TestService(t *testing.T) { require.NoError(t, err) for _, id := range []uuid.UUID{up1Proj.ID, up2Proj.ID} { - _, err = sat.DB.Console().ProjectInvitations().Insert(ctx, &console.ProjectInvitation{ + _, err = sat.DB.Console().ProjectInvitations().Upsert(ctx, &console.ProjectInvitation{ ProjectID: id, Email: invitedUser.Email, }) @@ -1975,7 +1976,7 @@ func TestProjectInvitations(t *testing.T) { } addInvite := func(t *testing.T, ctx context.Context, project *console.Project, email string) *console.ProjectInvitation { - invite, err := sat.DB.Console().ProjectInvitations().Insert(ctx, &console.ProjectInvitation{ + invite, err := sat.DB.Console().ProjectInvitations().Upsert(ctx, &console.ProjectInvitation{ ProjectID: project.ID, Email: email, InviterID: &project.OwnerID, @@ -1985,11 +1986,18 @@ func TestProjectInvitations(t *testing.T) { return invite } - expireInvite := func(t *testing.T, ctx context.Context, invite *console.ProjectInvitation) { - createdAt := time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration) - newInvite, err := sat.DB.Console().ProjectInvitations().Update(ctx, invite.ProjectID, invite.Email, console.UpdateProjectInvitationRequest{ - CreatedAt: &createdAt, - }) + setInviteDate := func(t *testing.T, ctx context.Context, invite *console.ProjectInvitation, createdAt time.Time) { + result, err := sat.DB.Testing().RawDB().ExecContext(ctx, + "UPDATE project_invitations SET created_at = $1 WHERE project_id = $2 AND email = $3", + createdAt, invite.ProjectID, strings.ToUpper(invite.Email), + ) + require.NoError(t, err) + + count, err := result.RowsAffected() + require.NoError(t, err) + require.EqualValues(t, 1, count) + + newInvite, err := sat.DB.Console().ProjectInvitations().Get(ctx, invite.ProjectID, invite.Email) require.NoError(t, err) *invite = *newInvite } @@ -2035,7 +2043,7 @@ func TestProjectInvitations(t *testing.T) { // expire the invitation. require.False(t, service.IsProjectInvitationExpired(&user3Invite)) oldCreatedAt := user3Invite.CreatedAt - expireInvite(t, ctx, &user3Invite) + setInviteDate(t, ctx, &user3Invite, time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration)) require.True(t, service.IsProjectInvitationExpired(&user3Invite)) // resending an expired invitation should succeed. @@ -2066,7 +2074,7 @@ func TestProjectInvitations(t *testing.T) { require.Equal(t, invite.InviterID, invites[0].InviterID) require.WithinDuration(t, invite.CreatedAt, invites[0].CreatedAt, time.Second) - expireInvite(t, ctx, &invites[0]) + setInviteDate(t, ctx, &invites[0], time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration)) invites, err = service.GetUserProjectInvitations(ctx) require.NoError(t, err) require.Empty(t, invites) @@ -2155,7 +2163,7 @@ func TestProjectInvitations(t *testing.T) { require.NotNil(t, inviteFromToken) require.Equal(t, invite, inviteFromToken) - expireInvite(t, ctx, invite) + setInviteDate(t, ctx, invite, time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration)) invites, err := service.GetUserProjectInvitations(ctx) require.NoError(t, err) require.Empty(t, invites) @@ -2178,7 +2186,7 @@ func TestProjectInvitations(t *testing.T) { proj := addProject(t, ctx) invite := addInvite(t, ctx, proj, user.Email) - expireInvite(t, ctx, invite) + setInviteDate(t, ctx, invite, time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration)) err := service.RespondToProjectInvitation(ctx, proj.ID, console.ProjectInvitationAccept) require.True(t, console.ErrProjectInviteInvalid.Has(err)) diff --git a/satellite/satellitedb/dbx/project.dbx b/satellite/satellitedb/dbx/project.dbx index 81d76cabb..e2fec3dff 100644 --- a/satellite/satellitedb/dbx/project.dbx +++ b/satellite/satellitedb/dbx/project.dbx @@ -169,7 +169,7 @@ model project_invitation ( field created_at timestamp ( autoinsert, updatable ) ) -create project_invitation ( ) +create project_invitation ( replace ) read one ( select project_invitation diff --git a/satellite/satellitedb/dbx/satellitedb.dbx.go b/satellite/satellitedb/dbx/satellitedb.dbx.go index 2348d2cd1..50dae3a3b 100644 --- a/satellite/satellitedb/dbx/satellitedb.dbx.go +++ b/satellite/satellitedb/dbx/satellitedb.dbx.go @@ -12869,7 +12869,7 @@ func (obj *pgxImpl) Create_ProjectMember(ctx context.Context, } -func (obj *pgxImpl) Create_ProjectInvitation(ctx context.Context, +func (obj *pgxImpl) Replace_ProjectInvitation(ctx context.Context, project_invitation_project_id ProjectInvitation_ProjectId_Field, project_invitation_email ProjectInvitation_Email_Field, optional ProjectInvitation_Create_Fields) ( @@ -12882,7 +12882,7 @@ func (obj *pgxImpl) Create_ProjectInvitation(ctx context.Context, __inviter_id_val := optional.InviterId.value() __created_at_val := __now - var __embed_stmt = __sqlbundle_Literal("INSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at") + var __embed_stmt = __sqlbundle_Literal("INSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) ON CONFLICT ( project_id, email ) DO UPDATE SET project_id = EXCLUDED.project_id, email = EXCLUDED.email, inviter_id = EXCLUDED.inviter_id, created_at = EXCLUDED.created_at RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at") var __values []interface{} __values = append(__values, __project_id_val, __email_val, __inviter_id_val, __created_at_val) @@ -20876,7 +20876,7 @@ func (obj *pgxcockroachImpl) Create_ProjectMember(ctx context.Context, } -func (obj *pgxcockroachImpl) Create_ProjectInvitation(ctx context.Context, +func (obj *pgxcockroachImpl) Replace_ProjectInvitation(ctx context.Context, project_invitation_project_id ProjectInvitation_ProjectId_Field, project_invitation_email ProjectInvitation_Email_Field, optional ProjectInvitation_Create_Fields) ( @@ -20889,7 +20889,7 @@ func (obj *pgxcockroachImpl) Create_ProjectInvitation(ctx context.Context, __inviter_id_val := optional.InviterId.value() __created_at_val := __now - var __embed_stmt = __sqlbundle_Literal("INSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at") + var __embed_stmt = __sqlbundle_Literal("UPSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at") var __values []interface{} __values = append(__values, __project_id_val, __email_val, __inviter_id_val, __created_at_val) @@ -28506,19 +28506,6 @@ func (rx *Rx) Create_Project(ctx context.Context, } -func (rx *Rx) Create_ProjectInvitation(ctx context.Context, - project_invitation_project_id ProjectInvitation_ProjectId_Field, - project_invitation_email ProjectInvitation_Email_Field, - optional ProjectInvitation_Create_Fields) ( - project_invitation *ProjectInvitation, err error) { - var tx *Tx - if tx, err = rx.getTx(ctx); err != nil { - return - } - return tx.Create_ProjectInvitation(ctx, project_invitation_project_id, project_invitation_email, optional) - -} - func (rx *Rx) Create_ProjectMember(ctx context.Context, project_member_member_id ProjectMember_MemberId_Field, project_member_project_id ProjectMember_ProjectId_Field) ( @@ -29707,6 +29694,19 @@ func (rx *Rx) Replace_AccountFreezeEvent(ctx context.Context, } +func (rx *Rx) Replace_ProjectInvitation(ctx context.Context, + project_invitation_project_id ProjectInvitation_ProjectId_Field, + project_invitation_email ProjectInvitation_Email_Field, + optional ProjectInvitation_Create_Fields) ( + project_invitation *ProjectInvitation, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Replace_ProjectInvitation(ctx, project_invitation_project_id, project_invitation_email, optional) + +} + func (rx *Rx) UpdateNoReturn_AccountingTimestamps_By_Name(ctx context.Context, accounting_timestamps_name AccountingTimestamps_Name_Field, update AccountingTimestamps_Update_Fields) ( @@ -30273,12 +30273,6 @@ type Methods interface { optional Project_Create_Fields) ( project *Project, err error) - Create_ProjectInvitation(ctx context.Context, - project_invitation_project_id ProjectInvitation_ProjectId_Field, - project_invitation_email ProjectInvitation_Email_Field, - optional ProjectInvitation_Create_Fields) ( - project_invitation *ProjectInvitation, err error) - Create_ProjectMember(ctx context.Context, project_member_member_id ProjectMember_MemberId_Field, project_member_project_id ProjectMember_ProjectId_Field) ( @@ -30808,6 +30802,12 @@ type Methods interface { optional AccountFreezeEvent_Create_Fields) ( account_freeze_event *AccountFreezeEvent, err error) + Replace_ProjectInvitation(ctx context.Context, + project_invitation_project_id ProjectInvitation_ProjectId_Field, + project_invitation_email ProjectInvitation_Email_Field, + optional ProjectInvitation_Create_Fields) ( + project_invitation *ProjectInvitation, err error) + UpdateNoReturn_AccountingTimestamps_By_Name(ctx context.Context, accounting_timestamps_name AccountingTimestamps_Name_Field, update AccountingTimestamps_Update_Fields) ( diff --git a/satellite/satellitedb/projectinvitations.go b/satellite/satellitedb/projectinvitations.go index ccb44c2b1..4bbd812be 100644 --- a/satellite/satellitedb/projectinvitations.go +++ b/satellite/satellitedb/projectinvitations.go @@ -22,8 +22,8 @@ type projectInvitations struct { db *satelliteDB } -// Insert inserts a project member invitation into the database. -func (invites *projectInvitations) Insert(ctx context.Context, invite *console.ProjectInvitation) (_ *console.ProjectInvitation, err error) { +// Upsert updates a project member invitation if it exists and inserts it otherwise. +func (invites *projectInvitations) Upsert(ctx context.Context, invite *console.ProjectInvitation) (_ *console.ProjectInvitation, err error) { defer mon.Task()(&ctx)(&err) if invite == nil { @@ -36,7 +36,7 @@ func (invites *projectInvitations) Insert(ctx context.Context, invite *console.P createFields.InviterId = dbx.ProjectInvitation_InviterId(id) } - dbxInvite, err := invites.db.Create_ProjectInvitation(ctx, + dbxInvite, err := invites.db.Replace_ProjectInvitation(ctx, dbx.ProjectInvitation_ProjectId(invite.ProjectID[:]), dbx.ProjectInvitation_Email(normalizeEmail(invite.Email)), createFields, @@ -87,30 +87,6 @@ func (invites *projectInvitations) GetByEmail(ctx context.Context, email string) return projectInvitationSliceFromDBX(dbxInvites) } -// Update updates the project member invitation specified by the given project ID and email address. -func (invites *projectInvitations) Update(ctx context.Context, projectID uuid.UUID, email string, request console.UpdateProjectInvitationRequest) (_ *console.ProjectInvitation, err error) { - defer mon.Task()(&ctx)(&err) - - update := dbx.ProjectInvitation_Update_Fields{} - if request.CreatedAt != nil { - update.CreatedAt = dbx.ProjectInvitation_CreatedAt(*request.CreatedAt) - } - if request.InviterID != nil { - update.InviterId = dbx.ProjectInvitation_InviterId((*request.InviterID)[:]) - } - - dbxInvite, err := invites.db.Update_ProjectInvitation_By_ProjectId_And_Email(ctx, - dbx.ProjectInvitation_ProjectId(projectID[:]), - dbx.ProjectInvitation_Email(normalizeEmail(email)), - update, - ) - if err != nil { - return nil, err - } - - return projectInvitationFromDBX(dbxInvite) -} - // Delete removes a project member invitation from the database. func (invites *projectInvitations) Delete(ctx context.Context, projectID uuid.UUID, email string) (err error) { defer mon.Task()(&ctx)(&err) diff --git a/satellite/satellitedb/projectinvitations_test.go b/satellite/satellitedb/projectinvitations_test.go index ae032e173..a98c8953f 100644 --- a/satellite/satellitedb/projectinvitations_test.go +++ b/satellite/satellitedb/projectinvitations_test.go @@ -50,7 +50,7 @@ func TestProjectInvitations(t *testing.T) { if !t.Run("insert invitations", func(t *testing.T) { // Expect failure because no user with inviterID exists. - _, err = invitesDB.Insert(ctx, invite) + _, err = invitesDB.Upsert(ctx, invite) require.Error(t, err) _, err = db.Console().Users().Insert(ctx, &console.User{ @@ -59,19 +59,15 @@ func TestProjectInvitations(t *testing.T) { }) require.NoError(t, err) - invite, err = invitesDB.Insert(ctx, invite) + invite, err = invitesDB.Upsert(ctx, invite) require.NoError(t, err) require.WithinDuration(t, time.Now(), invite.CreatedAt, time.Minute) require.Equal(t, projID, invite.ProjectID) require.Equal(t, strings.ToUpper(email), invite.Email) - // Duplicate invitations should be rejected. - _, err = invitesDB.Insert(ctx, invite) - require.Error(t, err) - - inviteSameEmail, err = invitesDB.Insert(ctx, inviteSameEmail) + inviteSameEmail, err = invitesDB.Upsert(ctx, inviteSameEmail) require.NoError(t, err) - inviteSameProject, err = invitesDB.Insert(ctx, inviteSameProject) + inviteSameProject, err = invitesDB.Upsert(ctx, inviteSameProject) require.NoError(t, err) }) { // None of the following subtests will pass if invitation insertion failed. @@ -126,22 +122,19 @@ func TestProjectInvitations(t *testing.T) { t.Run("update invitation", func(t *testing.T) { ctx := testcontext.New(t) - req := console.UpdateProjectInvitationRequest{} - newCreatedAt := invite.CreatedAt.Add(time.Hour) - req.CreatedAt = &newCreatedAt - newInvite, err := invitesDB.Update(ctx, projID, email, req) - require.NoError(t, err) - require.Equal(t, newCreatedAt, newInvite.CreatedAt) - inviter, err := db.Console().Users().Insert(ctx, &console.User{ ID: testrand.UUID(), PasswordHash: testrand.Bytes(8), }) require.NoError(t, err) - req.InviterID = &inviter.ID - newInvite, err = invitesDB.Update(ctx, projID, email, req) + invite.InviterID = &inviter.ID + + oldCreatedAt := invite.CreatedAt + + invite, err = invitesDB.Upsert(ctx, invite) require.NoError(t, err) - require.Equal(t, inviter.ID, *newInvite.InviterID) + require.Equal(t, inviter.ID, *invite.InviterID) + require.True(t, invite.CreatedAt.After(oldCreatedAt)) }) t.Run("delete invitation", func(t *testing.T) { @@ -187,20 +180,22 @@ func TestDeleteBefore(t *testing.T) { _, err := db.Console().Projects().Insert(ctx, &console.Project{ID: projID}) require.NoError(t, err) - invite, err := invitesDB.Insert(ctx, &console.ProjectInvitation{ProjectID: projID}) + invite, err := invitesDB.Upsert(ctx, &console.ProjectInvitation{ProjectID: projID}) require.NoError(t, err) - return invite } newInvite := createInvite() oldInvite := createInvite() - oldCreatedAt := expiration.Add(-time.Second) - oldInvite, err := invitesDB.Update(ctx, oldInvite.ProjectID, oldInvite.Email, console.UpdateProjectInvitationRequest{ - CreatedAt: &oldCreatedAt, - }) + result, err := db.Testing().RawDB().ExecContext(ctx, + "UPDATE project_invitations SET created_at = $1 WHERE project_id = $2", + expiration.Add(-time.Second), oldInvite.ProjectID, + ) require.NoError(t, err) + count, err := result.RowsAffected() + require.NoError(t, err) + require.EqualValues(t, 1, count) require.NoError(t, invitesDB.DeleteBefore(ctx, expiration, 0, 1)) diff --git a/satellite/satellitedb/projectmembers_test.go b/satellite/satellitedb/projectmembers_test.go index de86624ab..0f77f20e3 100644 --- a/satellite/satellitedb/projectmembers_test.go +++ b/satellite/satellitedb/projectmembers_test.go @@ -36,7 +36,7 @@ func TestGetPagedWithInvitationsByProjectID(t *testing.T) { _, err = db.Console().ProjectMembers().Insert(ctx, memberUser.ID, projectID) require.NoError(t, err) - _, err = db.Console().ProjectInvitations().Insert(ctx, &console.ProjectInvitation{ + _, err = db.Console().ProjectInvitations().Upsert(ctx, &console.ProjectInvitation{ ProjectID: projectID, Email: "bob@mail.test", })