satellite/console: fix transaction error when inviting project members

The SQL transaction that inserted project invitations relied on the
error result of one of its statements in order to determine whether an
invitation should be updated. This was inappropriate since any errors
returned from a transaction statement should end the transaction
immediately. This change resolves that issue.

Change-Id: I354e430df293054d8583fb4faa5dc1bcf9053836
This commit is contained in:
Jeremy Wharton 2023-06-23 11:34:44 -05:00 committed by Storj Robot
parent 1b912ec167
commit 22f8b029b9
8 changed files with 70 additions and 109 deletions

View File

@ -14,16 +14,14 @@ import (
//
// architecture: Database
type ProjectInvitations interface {
// Insert inserts a project member invitation into the database.
Insert(ctx context.Context, invite *ProjectInvitation) (*ProjectInvitation, error)
// Upsert updates a project member invitation if it exists and inserts it otherwise.
Upsert(ctx context.Context, invite *ProjectInvitation) (*ProjectInvitation, error)
// Get returns a project member invitation from the database.
Get(ctx context.Context, projectID uuid.UUID, email string) (*ProjectInvitation, error)
// GetByProjectID returns all of the project member invitations for the project specified by the given ID.
GetByProjectID(ctx context.Context, projectID uuid.UUID) ([]ProjectInvitation, error)
// GetByEmail returns all of the project member invitations for the specified email address.
GetByEmail(ctx context.Context, email string) ([]ProjectInvitation, error)
// Update updates the project member invitation specified by the given project ID and email address.
Update(ctx context.Context, projectID uuid.UUID, email string, request UpdateProjectInvitationRequest) (*ProjectInvitation, error)
// Delete removes a project member invitation from the database.
Delete(ctx context.Context, projectID uuid.UUID, email string) error
// DeleteBefore deletes project member invitations created prior to some time from the database.
@ -37,9 +35,3 @@ type ProjectInvitation struct {
InviterID *uuid.UUID
CreatedAt time.Time
}
// UpdateProjectInvitationRequest contains all fields which may be updated by ProjectInvitations.Update.
type UpdateProjectInvitationRequest struct {
CreatedAt *time.Time
InviterID *uuid.UUID
}

View File

@ -3609,23 +3609,13 @@ func (s *Service) InviteProjectMembers(ctx context.Context, projectID uuid.UUID,
// add project invites in transaction scope
err = s.store.WithTx(ctx, func(ctx context.Context, tx DBTx) error {
for _, invited := range users {
invite, err := tx.ProjectInvitations().Insert(ctx, &ProjectInvitation{
invite, err := tx.ProjectInvitations().Upsert(ctx, &ProjectInvitation{
ProjectID: projectID,
Email: invited.Email,
InviterID: &user.ID,
})
if err != nil {
if !dbx.IsConstraintError(err) {
return err
}
now := time.Now()
invite, err = tx.ProjectInvitations().Update(ctx, projectID, invited.Email, UpdateProjectInvitationRequest{
CreatedAt: &now,
InviterID: &user.ID,
})
if err != nil {
return err
}
return err
}
token, err := s.CreateInviteToken(ctx, isMember.project.PublicID, invited.Email, invite.CreatedAt)
if err != nil {

View File

@ -11,6 +11,7 @@ import (
"fmt"
"math/rand"
"sort"
"strings"
"testing"
"time"
@ -314,7 +315,7 @@ func TestService(t *testing.T) {
require.NoError(t, err)
for _, id := range []uuid.UUID{up1Proj.ID, up2Proj.ID} {
_, err = sat.DB.Console().ProjectInvitations().Insert(ctx, &console.ProjectInvitation{
_, err = sat.DB.Console().ProjectInvitations().Upsert(ctx, &console.ProjectInvitation{
ProjectID: id,
Email: invitedUser.Email,
})
@ -1975,7 +1976,7 @@ func TestProjectInvitations(t *testing.T) {
}
addInvite := func(t *testing.T, ctx context.Context, project *console.Project, email string) *console.ProjectInvitation {
invite, err := sat.DB.Console().ProjectInvitations().Insert(ctx, &console.ProjectInvitation{
invite, err := sat.DB.Console().ProjectInvitations().Upsert(ctx, &console.ProjectInvitation{
ProjectID: project.ID,
Email: email,
InviterID: &project.OwnerID,
@ -1985,11 +1986,18 @@ func TestProjectInvitations(t *testing.T) {
return invite
}
expireInvite := func(t *testing.T, ctx context.Context, invite *console.ProjectInvitation) {
createdAt := time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration)
newInvite, err := sat.DB.Console().ProjectInvitations().Update(ctx, invite.ProjectID, invite.Email, console.UpdateProjectInvitationRequest{
CreatedAt: &createdAt,
})
setInviteDate := func(t *testing.T, ctx context.Context, invite *console.ProjectInvitation, createdAt time.Time) {
result, err := sat.DB.Testing().RawDB().ExecContext(ctx,
"UPDATE project_invitations SET created_at = $1 WHERE project_id = $2 AND email = $3",
createdAt, invite.ProjectID, strings.ToUpper(invite.Email),
)
require.NoError(t, err)
count, err := result.RowsAffected()
require.NoError(t, err)
require.EqualValues(t, 1, count)
newInvite, err := sat.DB.Console().ProjectInvitations().Get(ctx, invite.ProjectID, invite.Email)
require.NoError(t, err)
*invite = *newInvite
}
@ -2035,7 +2043,7 @@ func TestProjectInvitations(t *testing.T) {
// expire the invitation.
require.False(t, service.IsProjectInvitationExpired(&user3Invite))
oldCreatedAt := user3Invite.CreatedAt
expireInvite(t, ctx, &user3Invite)
setInviteDate(t, ctx, &user3Invite, time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration))
require.True(t, service.IsProjectInvitationExpired(&user3Invite))
// resending an expired invitation should succeed.
@ -2066,7 +2074,7 @@ func TestProjectInvitations(t *testing.T) {
require.Equal(t, invite.InviterID, invites[0].InviterID)
require.WithinDuration(t, invite.CreatedAt, invites[0].CreatedAt, time.Second)
expireInvite(t, ctx, &invites[0])
setInviteDate(t, ctx, &invites[0], time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration))
invites, err = service.GetUserProjectInvitations(ctx)
require.NoError(t, err)
require.Empty(t, invites)
@ -2155,7 +2163,7 @@ func TestProjectInvitations(t *testing.T) {
require.NotNil(t, inviteFromToken)
require.Equal(t, invite, inviteFromToken)
expireInvite(t, ctx, invite)
setInviteDate(t, ctx, invite, time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration))
invites, err := service.GetUserProjectInvitations(ctx)
require.NoError(t, err)
require.Empty(t, invites)
@ -2178,7 +2186,7 @@ func TestProjectInvitations(t *testing.T) {
proj := addProject(t, ctx)
invite := addInvite(t, ctx, proj, user.Email)
expireInvite(t, ctx, invite)
setInviteDate(t, ctx, invite, time.Now().Add(-sat.Config.Console.ProjectInvitationExpiration))
err := service.RespondToProjectInvitation(ctx, proj.ID, console.ProjectInvitationAccept)
require.True(t, console.ErrProjectInviteInvalid.Has(err))

View File

@ -169,7 +169,7 @@ model project_invitation (
field created_at timestamp ( autoinsert, updatable )
)
create project_invitation ( )
create project_invitation ( replace )
read one (
select project_invitation

View File

@ -12869,7 +12869,7 @@ func (obj *pgxImpl) Create_ProjectMember(ctx context.Context,
}
func (obj *pgxImpl) Create_ProjectInvitation(ctx context.Context,
func (obj *pgxImpl) Replace_ProjectInvitation(ctx context.Context,
project_invitation_project_id ProjectInvitation_ProjectId_Field,
project_invitation_email ProjectInvitation_Email_Field,
optional ProjectInvitation_Create_Fields) (
@ -12882,7 +12882,7 @@ func (obj *pgxImpl) Create_ProjectInvitation(ctx context.Context,
__inviter_id_val := optional.InviterId.value()
__created_at_val := __now
var __embed_stmt = __sqlbundle_Literal("INSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at")
var __embed_stmt = __sqlbundle_Literal("INSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) ON CONFLICT ( project_id, email ) DO UPDATE SET project_id = EXCLUDED.project_id, email = EXCLUDED.email, inviter_id = EXCLUDED.inviter_id, created_at = EXCLUDED.created_at RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at")
var __values []interface{}
__values = append(__values, __project_id_val, __email_val, __inviter_id_val, __created_at_val)
@ -20876,7 +20876,7 @@ func (obj *pgxcockroachImpl) Create_ProjectMember(ctx context.Context,
}
func (obj *pgxcockroachImpl) Create_ProjectInvitation(ctx context.Context,
func (obj *pgxcockroachImpl) Replace_ProjectInvitation(ctx context.Context,
project_invitation_project_id ProjectInvitation_ProjectId_Field,
project_invitation_email ProjectInvitation_Email_Field,
optional ProjectInvitation_Create_Fields) (
@ -20889,7 +20889,7 @@ func (obj *pgxcockroachImpl) Create_ProjectInvitation(ctx context.Context,
__inviter_id_val := optional.InviterId.value()
__created_at_val := __now
var __embed_stmt = __sqlbundle_Literal("INSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at")
var __embed_stmt = __sqlbundle_Literal("UPSERT INTO project_invitations ( project_id, email, inviter_id, created_at ) VALUES ( ?, ?, ?, ? ) RETURNING project_invitations.project_id, project_invitations.email, project_invitations.inviter_id, project_invitations.created_at")
var __values []interface{}
__values = append(__values, __project_id_val, __email_val, __inviter_id_val, __created_at_val)
@ -28506,19 +28506,6 @@ func (rx *Rx) Create_Project(ctx context.Context,
}
func (rx *Rx) Create_ProjectInvitation(ctx context.Context,
project_invitation_project_id ProjectInvitation_ProjectId_Field,
project_invitation_email ProjectInvitation_Email_Field,
optional ProjectInvitation_Create_Fields) (
project_invitation *ProjectInvitation, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Create_ProjectInvitation(ctx, project_invitation_project_id, project_invitation_email, optional)
}
func (rx *Rx) Create_ProjectMember(ctx context.Context,
project_member_member_id ProjectMember_MemberId_Field,
project_member_project_id ProjectMember_ProjectId_Field) (
@ -29707,6 +29694,19 @@ func (rx *Rx) Replace_AccountFreezeEvent(ctx context.Context,
}
func (rx *Rx) Replace_ProjectInvitation(ctx context.Context,
project_invitation_project_id ProjectInvitation_ProjectId_Field,
project_invitation_email ProjectInvitation_Email_Field,
optional ProjectInvitation_Create_Fields) (
project_invitation *ProjectInvitation, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Replace_ProjectInvitation(ctx, project_invitation_project_id, project_invitation_email, optional)
}
func (rx *Rx) UpdateNoReturn_AccountingTimestamps_By_Name(ctx context.Context,
accounting_timestamps_name AccountingTimestamps_Name_Field,
update AccountingTimestamps_Update_Fields) (
@ -30273,12 +30273,6 @@ type Methods interface {
optional Project_Create_Fields) (
project *Project, err error)
Create_ProjectInvitation(ctx context.Context,
project_invitation_project_id ProjectInvitation_ProjectId_Field,
project_invitation_email ProjectInvitation_Email_Field,
optional ProjectInvitation_Create_Fields) (
project_invitation *ProjectInvitation, err error)
Create_ProjectMember(ctx context.Context,
project_member_member_id ProjectMember_MemberId_Field,
project_member_project_id ProjectMember_ProjectId_Field) (
@ -30808,6 +30802,12 @@ type Methods interface {
optional AccountFreezeEvent_Create_Fields) (
account_freeze_event *AccountFreezeEvent, err error)
Replace_ProjectInvitation(ctx context.Context,
project_invitation_project_id ProjectInvitation_ProjectId_Field,
project_invitation_email ProjectInvitation_Email_Field,
optional ProjectInvitation_Create_Fields) (
project_invitation *ProjectInvitation, err error)
UpdateNoReturn_AccountingTimestamps_By_Name(ctx context.Context,
accounting_timestamps_name AccountingTimestamps_Name_Field,
update AccountingTimestamps_Update_Fields) (

View File

@ -22,8 +22,8 @@ type projectInvitations struct {
db *satelliteDB
}
// Insert inserts a project member invitation into the database.
func (invites *projectInvitations) Insert(ctx context.Context, invite *console.ProjectInvitation) (_ *console.ProjectInvitation, err error) {
// Upsert updates a project member invitation if it exists and inserts it otherwise.
func (invites *projectInvitations) Upsert(ctx context.Context, invite *console.ProjectInvitation) (_ *console.ProjectInvitation, err error) {
defer mon.Task()(&ctx)(&err)
if invite == nil {
@ -36,7 +36,7 @@ func (invites *projectInvitations) Insert(ctx context.Context, invite *console.P
createFields.InviterId = dbx.ProjectInvitation_InviterId(id)
}
dbxInvite, err := invites.db.Create_ProjectInvitation(ctx,
dbxInvite, err := invites.db.Replace_ProjectInvitation(ctx,
dbx.ProjectInvitation_ProjectId(invite.ProjectID[:]),
dbx.ProjectInvitation_Email(normalizeEmail(invite.Email)),
createFields,
@ -87,30 +87,6 @@ func (invites *projectInvitations) GetByEmail(ctx context.Context, email string)
return projectInvitationSliceFromDBX(dbxInvites)
}
// Update updates the project member invitation specified by the given project ID and email address.
func (invites *projectInvitations) Update(ctx context.Context, projectID uuid.UUID, email string, request console.UpdateProjectInvitationRequest) (_ *console.ProjectInvitation, err error) {
defer mon.Task()(&ctx)(&err)
update := dbx.ProjectInvitation_Update_Fields{}
if request.CreatedAt != nil {
update.CreatedAt = dbx.ProjectInvitation_CreatedAt(*request.CreatedAt)
}
if request.InviterID != nil {
update.InviterId = dbx.ProjectInvitation_InviterId((*request.InviterID)[:])
}
dbxInvite, err := invites.db.Update_ProjectInvitation_By_ProjectId_And_Email(ctx,
dbx.ProjectInvitation_ProjectId(projectID[:]),
dbx.ProjectInvitation_Email(normalizeEmail(email)),
update,
)
if err != nil {
return nil, err
}
return projectInvitationFromDBX(dbxInvite)
}
// Delete removes a project member invitation from the database.
func (invites *projectInvitations) Delete(ctx context.Context, projectID uuid.UUID, email string) (err error) {
defer mon.Task()(&ctx)(&err)

View File

@ -50,7 +50,7 @@ func TestProjectInvitations(t *testing.T) {
if !t.Run("insert invitations", func(t *testing.T) {
// Expect failure because no user with inviterID exists.
_, err = invitesDB.Insert(ctx, invite)
_, err = invitesDB.Upsert(ctx, invite)
require.Error(t, err)
_, err = db.Console().Users().Insert(ctx, &console.User{
@ -59,19 +59,15 @@ func TestProjectInvitations(t *testing.T) {
})
require.NoError(t, err)
invite, err = invitesDB.Insert(ctx, invite)
invite, err = invitesDB.Upsert(ctx, invite)
require.NoError(t, err)
require.WithinDuration(t, time.Now(), invite.CreatedAt, time.Minute)
require.Equal(t, projID, invite.ProjectID)
require.Equal(t, strings.ToUpper(email), invite.Email)
// Duplicate invitations should be rejected.
_, err = invitesDB.Insert(ctx, invite)
require.Error(t, err)
inviteSameEmail, err = invitesDB.Insert(ctx, inviteSameEmail)
inviteSameEmail, err = invitesDB.Upsert(ctx, inviteSameEmail)
require.NoError(t, err)
inviteSameProject, err = invitesDB.Insert(ctx, inviteSameProject)
inviteSameProject, err = invitesDB.Upsert(ctx, inviteSameProject)
require.NoError(t, err)
}) {
// None of the following subtests will pass if invitation insertion failed.
@ -126,22 +122,19 @@ func TestProjectInvitations(t *testing.T) {
t.Run("update invitation", func(t *testing.T) {
ctx := testcontext.New(t)
req := console.UpdateProjectInvitationRequest{}
newCreatedAt := invite.CreatedAt.Add(time.Hour)
req.CreatedAt = &newCreatedAt
newInvite, err := invitesDB.Update(ctx, projID, email, req)
require.NoError(t, err)
require.Equal(t, newCreatedAt, newInvite.CreatedAt)
inviter, err := db.Console().Users().Insert(ctx, &console.User{
ID: testrand.UUID(),
PasswordHash: testrand.Bytes(8),
})
require.NoError(t, err)
req.InviterID = &inviter.ID
newInvite, err = invitesDB.Update(ctx, projID, email, req)
invite.InviterID = &inviter.ID
oldCreatedAt := invite.CreatedAt
invite, err = invitesDB.Upsert(ctx, invite)
require.NoError(t, err)
require.Equal(t, inviter.ID, *newInvite.InviterID)
require.Equal(t, inviter.ID, *invite.InviterID)
require.True(t, invite.CreatedAt.After(oldCreatedAt))
})
t.Run("delete invitation", func(t *testing.T) {
@ -187,20 +180,22 @@ func TestDeleteBefore(t *testing.T) {
_, err := db.Console().Projects().Insert(ctx, &console.Project{ID: projID})
require.NoError(t, err)
invite, err := invitesDB.Insert(ctx, &console.ProjectInvitation{ProjectID: projID})
invite, err := invitesDB.Upsert(ctx, &console.ProjectInvitation{ProjectID: projID})
require.NoError(t, err)
return invite
}
newInvite := createInvite()
oldInvite := createInvite()
oldCreatedAt := expiration.Add(-time.Second)
oldInvite, err := invitesDB.Update(ctx, oldInvite.ProjectID, oldInvite.Email, console.UpdateProjectInvitationRequest{
CreatedAt: &oldCreatedAt,
})
result, err := db.Testing().RawDB().ExecContext(ctx,
"UPDATE project_invitations SET created_at = $1 WHERE project_id = $2",
expiration.Add(-time.Second), oldInvite.ProjectID,
)
require.NoError(t, err)
count, err := result.RowsAffected()
require.NoError(t, err)
require.EqualValues(t, 1, count)
require.NoError(t, invitesDB.DeleteBefore(ctx, expiration, 0, 1))

View File

@ -36,7 +36,7 @@ func TestGetPagedWithInvitationsByProjectID(t *testing.T) {
_, err = db.Console().ProjectMembers().Insert(ctx, memberUser.ID, projectID)
require.NoError(t, err)
_, err = db.Console().ProjectInvitations().Insert(ctx, &console.ProjectInvitation{
_, err = db.Console().ProjectInvitations().Upsert(ctx, &console.ProjectInvitation{
ProjectID: projectID,
Email: "bob@mail.test",
})