diff --git a/satellite/console/projectmembers.go b/satellite/console/projectmembers.go index 24b2bd540..17d58f7aa 100644 --- a/satellite/console/projectmembers.go +++ b/satellite/console/projectmembers.go @@ -16,8 +16,11 @@ import ( type ProjectMembers interface { // GetByMemberID is a method for querying project members from the database by memberID. GetByMemberID(ctx context.Context, memberID uuid.UUID) ([]ProjectMember, error) - // GetPagedByProjectID is a method for querying project members from the database by projectID and cursor + // GetPagedByProjectID is a method for querying project members from the database by projectID and cursor. + // TODO: Remove once all uses have been replaced by GetPagedWithInvitationsByProjectID. GetPagedByProjectID(ctx context.Context, projectID uuid.UUID, cursor ProjectMembersCursor) (*ProjectMembersPage, error) + // GetPagedWithInvitationsByProjectID is a method for querying project members and invitations from the database by projectID and cursor. + GetPagedWithInvitationsByProjectID(ctx context.Context, projectID uuid.UUID, cursor ProjectMembersCursor) (*ProjectMembersPage, error) // Insert is a method for inserting project member into the database. Insert(ctx context.Context, memberID, projectID uuid.UUID) (*ProjectMember, error) // Delete is a method for deleting project member by memberID and projectID from the database. @@ -43,9 +46,10 @@ type ProjectMembersCursor struct { OrderDirection OrderDirection } -// ProjectMembersPage represent project members page result. +// ProjectMembersPage represents a page of project members and invitations. type ProjectMembersPage struct { - ProjectMembers []ProjectMember + ProjectMembers []ProjectMember + ProjectInvitations []ProjectInvitation Search string Limit uint diff --git a/satellite/satellitedb/projectmembers.go b/satellite/satellitedb/projectmembers.go index c0641d5fa..b75d2c8e8 100644 --- a/satellite/satellitedb/projectmembers.go +++ b/satellite/satellitedb/projectmembers.go @@ -6,6 +6,7 @@ package satellitedb import ( "context" "strings" + "time" "github.com/zeebo/errs" @@ -35,6 +36,7 @@ func (pm *projectMembers) GetByMemberID(ctx context.Context, memberID uuid.UUID) } // GetByProjectID is a method for querying project members from the database by projectID, offset and limit. +// TODO: Remove once all uses have been replaced by GetPagedWithInvitationsByProjectID. func (pm *projectMembers) GetPagedByProjectID(ctx context.Context, projectID uuid.UUID, cursor console.ProjectMembersCursor) (_ *console.ProjectMembersPage, err error) { defer mon.Task()(&ctx)(&err) @@ -58,12 +60,12 @@ func (pm *projectMembers) GetPagedByProjectID(ctx context.Context, projectID uui countQuery := pm.db.Rebind(` SELECT COUNT(*) - FROM project_members pm + FROM project_members pm INNER JOIN users u ON pm.member_id = u.id WHERE pm.project_id = ? - AND ( u.email LIKE ? OR + AND ( u.email LIKE ? OR u.full_name LIKE ? OR - u.short_name LIKE ? + u.short_name LIKE ? )`) countRow := pm.db.QueryRowContext(ctx, @@ -93,7 +95,7 @@ func (pm *projectMembers) GetPagedByProjectID(ctx context.Context, projectID uui u.full_name LIKE ? OR u.short_name LIKE ? ) ORDER BY ` + sanitizedOrderColumnName(cursor.Order) + ` - ` + sanitizeOrderDirectionName(page.OrderDirection) + ` + ` + sanitizeOrderDirectionName(page.OrderDirection) + ` LIMIT ? OFFSET ?`) rows, err := pm.db.QueryContext(ctx, @@ -141,6 +143,155 @@ func (pm *projectMembers) GetPagedByProjectID(ctx context.Context, projectID uui return page, err } +// GetPagedWithInvitationsByProjectID is a method for querying project members and invitations from the database by projectID, offset and limit. +func (pm *projectMembers) GetPagedWithInvitationsByProjectID(ctx context.Context, projectID uuid.UUID, cursor console.ProjectMembersCursor) (_ *console.ProjectMembersPage, err error) { + defer mon.Task()(&ctx)(&err) + + search := "%" + strings.ReplaceAll(cursor.Search, " ", "%") + "%" + + if cursor.Limit > 50 { + cursor.Limit = 50 + } + + if cursor.Limit == 0 { + return nil, errs.New("limit cannot be 0") + } + + if cursor.Page == 0 { + return nil, errs.New("page cannot be 0") + } + + page := &console.ProjectMembersPage{ + Search: cursor.Search, + Limit: cursor.Limit, + Offset: uint64((cursor.Page - 1) * cursor.Limit), + Order: cursor.Order, + OrderDirection: cursor.OrderDirection, + } + + countQuery := ` + SELECT ( + SELECT COUNT(*) + FROM project_members pm + INNER JOIN users u ON pm.member_id = u.id + WHERE pm.project_id = $1 + AND ( + u.email ILIKE $2 OR + u.full_name ILIKE $2 OR + u.short_name ILIKE $2 + ) + ) + ( + SELECT COUNT(*) + FROM project_invitations + WHERE project_id = $1 + AND email ILIKE $2 + )` + + countRow := pm.db.QueryRowContext(ctx, + countQuery, + projectID[:], + search) + + err = countRow.Scan(&page.TotalCount) + if err != nil { + return nil, err + } + if page.TotalCount == 0 { + return page, nil + } + if page.Offset > page.TotalCount-1 { + return nil, errs.New("page is out of range") + } + + membersQuery := ` + SELECT member_id, project_id, created_at, email, inviter_id FROM ( + ( + SELECT pm.member_id, pm.project_id, pm.created_at, u.email, u.full_name, NULL as inviter_id + FROM project_members pm + INNER JOIN users u ON pm.member_id = u.id + WHERE pm.project_id = $1 + AND ( + u.email ILIKE $2 OR + u.full_name ILIKE $2 OR + u.short_name ILIKE $2 + ) + ) UNION ALL ( + SELECT NULL as member_id, project_id, created_at, LOWER(email) as email, LOWER(SPLIT_PART(email, '@', 1)) as full_name, inviter_id + FROM project_invitations pi + WHERE project_id = $1 + AND email ILIKE $2 + ) + ) results + ` + projectMembersSortClause(cursor.Order, page.OrderDirection) + ` + LIMIT $3 OFFSET $4` + + rows, err := pm.db.QueryContext(ctx, + membersQuery, + projectID[:], + search, + page.Limit, + page.Offset, + ) + if err != nil { + return nil, err + } + defer func() { + err = errs.Combine(err, rows.Close()) + }() + + for rows.Next() { + var ( + memberID NullUUID + projectID uuid.UUID + createdAt time.Time + email string + inviterID NullUUID + ) + + err = rows.Scan( + &memberID, + &projectID, + &createdAt, + &email, + &inviterID, + ) + if err != nil { + return nil, err + } + + if memberID.Valid { + page.ProjectMembers = append(page.ProjectMembers, console.ProjectMember{ + MemberID: memberID.UUID, + ProjectID: projectID, + CreatedAt: createdAt, + }) + } else { + invite := console.ProjectInvitation{ + ProjectID: projectID, + Email: email, + CreatedAt: createdAt, + } + if inviterID.Valid { + invite.InviterID = &inviterID.UUID + } + page.ProjectInvitations = append(page.ProjectInvitations, invite) + } + + } + if err := rows.Err(); err != nil { + return nil, err + } + + page.PageCount = uint(page.TotalCount / uint64(cursor.Limit)) + if page.TotalCount%uint64(cursor.Limit) != 0 { + page.PageCount++ + } + + page.CurrentPage = cursor.Page + + return page, err +} + // Insert is a method for inserting project member into the database. func (pm *projectMembers) Insert(ctx context.Context, memberID, projectID uuid.UUID) (_ *console.ProjectMember, err error) { defer mon.Task()(&ctx)(&err) @@ -210,6 +361,22 @@ func sanitizeOrderDirectionName(pmo console.OrderDirection) string { return "ASC" } +// projectMembersSortClause returns what ORDER BY clause should be used when sorting project member results. +func projectMembersSortClause(order console.ProjectMemberOrder, direction console.OrderDirection) string { + dirStr := "ASC" + if direction == console.Descending { + dirStr = "DESC" + } + + switch order { + case console.Email: + return "ORDER BY LOWER(email) " + dirStr + case console.Created: + return "ORDER BY created_at " + dirStr + ", LOWER(email)" + } + return "ORDER BY LOWER(full_name) " + dirStr + ", LOWER(email)" +} + // projectMembersFromDbxSlice is used for creating []ProjectMember entities from autogenerated []*dbx.ProjectMember struct. func projectMembersFromDbxSlice(ctx context.Context, projectMembersDbx []*dbx.ProjectMember) (_ []console.ProjectMember, err error) { defer mon.Task()(&ctx)(&err) diff --git a/satellite/satellitedb/projectmembers_test.go b/satellite/satellitedb/projectmembers_test.go index 1bd3c2a77..de86624ab 100644 --- a/satellite/satellitedb/projectmembers_test.go +++ b/satellite/satellitedb/projectmembers_test.go @@ -1,29 +1,183 @@ // Copyright (C) 2019 Storj Labs, Inc. // See LICENSE for copying information. -package satellitedb +package satellitedb_test import ( + "fmt" "testing" + "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "storj.io/common/testcontext" + "storj.io/common/testrand" + "storj.io/common/uuid" + "storj.io/storj/satellite" "storj.io/storj/satellite/console" + "storj.io/storj/satellite/satellitedb/satellitedbtest" ) -func TestSanitizedOrderColumnName(t *testing.T) { - testCases := [...]struct { - orderNumber int8 - orderColumn string - }{ - 0: {0, "u.full_name"}, - 1: {1, "u.full_name"}, - 2: {2, "u.email"}, - 3: {3, "pm.created_at"}, - 4: {4, "u.full_name"}, - } +func TestGetPagedWithInvitationsByProjectID(t *testing.T) { + satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { + membersDB := db.Console().ProjectMembers() - for _, tc := range testCases { - assert.Equal(t, tc.orderColumn, sanitizedOrderColumnName(console.ProjectMemberOrder(tc.orderNumber))) - } + projectID := testrand.UUID() + _, err := db.Console().Projects().Insert(ctx, &console.Project{ID: projectID}) + require.NoError(t, err) + + memberUser, err := db.Console().Users().Insert(ctx, &console.User{ + FullName: "Alice", + Email: "alice@mail.test", + ID: testrand.UUID(), + PasswordHash: testrand.Bytes(8), + }) + require.NoError(t, err) + _, err = db.Console().ProjectMembers().Insert(ctx, memberUser.ID, projectID) + require.NoError(t, err) + + _, err = db.Console().ProjectInvitations().Insert(ctx, &console.ProjectInvitation{ + ProjectID: projectID, + Email: "bob@mail.test", + }) + require.NoError(t, err) + + t.Run("paging", func(t *testing.T) { + ctx := testcontext.New(t) + + for _, tt := range []struct { + limit uint + page uint + expectedCount int + }{ + {limit: 2, page: 1, expectedCount: 2}, + {limit: 1, page: 1, expectedCount: 1}, + {limit: 1, page: 2, expectedCount: 1}, + } { + cursor := console.ProjectMembersCursor{Limit: tt.limit, Page: tt.page} + page, err := membersDB.GetPagedWithInvitationsByProjectID(ctx, projectID, cursor) + require.NoError(t, err) + require.Equal(t, tt.expectedCount, len(page.ProjectInvitations)+len(page.ProjectMembers), + fmt.Sprintf("error occurred with limit %d, page %d", tt.limit, tt.page)) + } + + _, err = membersDB.GetPagedWithInvitationsByProjectID(ctx, projectID, console.ProjectMembersCursor{Limit: 1, Page: 3}) + require.Error(t, err) + }) + + t.Run("search", func(t *testing.T) { + ctx := testcontext.New(t) + + for _, tt := range []struct { + search string + expectMembers bool + expectInvites bool + }{ + {search: "aLiCe", expectMembers: true}, + {search: "@ test", expectMembers: true, expectInvites: true}, + {search: "bad"}, + } { + errMsg := "unexpected result for search '" + tt.search + "'" + + cursor := console.ProjectMembersCursor{Search: tt.search, Limit: 2, Page: 1} + page, err := membersDB.GetPagedWithInvitationsByProjectID(ctx, projectID, cursor) + require.NoError(t, err, errMsg) + + if tt.expectMembers { + require.NotEmpty(t, page.ProjectMembers, errMsg) + } else { + require.Empty(t, page.ProjectMembers, errMsg) + } + + if tt.expectInvites { + require.NotEmpty(t, page.ProjectInvitations, errMsg) + } else { + require.Empty(t, page.ProjectInvitations, errMsg) + } + } + }) + + t.Run("ordering", func(t *testing.T) { + ctx := testcontext.New(t) + + projectID := testrand.UUID() + _, err := db.Console().Projects().Insert(ctx, &console.Project{ID: projectID}) + require.NoError(t, err) + + var memberIDs []uuid.UUID + for i := 0; i < 3; i++ { + id := uuid.UUID{} + id[len(id)-1] = byte(i + 1) + memberIDs = append(memberIDs, id) + + user := console.User{ + FullName: fmt.Sprintf("%d", i), + Email: fmt.Sprintf("%d@mail.test", (i+2)%3), + ID: id, + PasswordHash: testrand.Bytes(8), + } + + _, err := db.Console().Users().Insert(ctx, &user) + require.NoError(t, err) + + _, err = db.Console().ProjectMembers().Insert(ctx, user.ID, projectID) + require.NoError(t, err) + + result, err := db.Testing().RawDB().ExecContext(ctx, + "UPDATE project_members SET created_at = $1 WHERE member_id = $2", + time.Time{}.Add(time.Duration((i+1)%3)*time.Hour), id, + ) + require.NoError(t, err) + + count, err := result.RowsAffected() + require.NoError(t, err) + require.EqualValues(t, 1, count) + } + + for _, tt := range []struct { + order console.ProjectMemberOrder + memberIDs []uuid.UUID + }{ + { + order: console.Name, + memberIDs: []uuid.UUID{memberIDs[0], memberIDs[1], memberIDs[2]}, + }, { + order: console.Email, + memberIDs: []uuid.UUID{memberIDs[1], memberIDs[2], memberIDs[0]}, + }, { + order: console.Created, + memberIDs: []uuid.UUID{memberIDs[2], memberIDs[0], memberIDs[1]}, + }, + } { + errMsg := func(cursor console.ProjectMembersCursor) string { + return fmt.Sprintf("unexpected result when ordering by %s, %s", + []string{"name", "email", "creation date"}[cursor.Order-1], + []string{"ascending", "descending"}[cursor.OrderDirection-1]) + } + + getIDsFromDB := func(cursor console.ProjectMembersCursor) (ids []uuid.UUID) { + page, err := membersDB.GetPagedWithInvitationsByProjectID(ctx, projectID, cursor) + require.NoError(t, err, errMsg(cursor)) + for _, member := range page.ProjectMembers { + ids = append(ids, member.MemberID) + } + return ids + } + + cursor := console.ProjectMembersCursor{ + Limit: uint(len(tt.memberIDs)), + Page: 1, Order: tt.order, + OrderDirection: console.Ascending, + } + require.Equal(t, tt.memberIDs, getIDsFromDB(cursor), errMsg(cursor)) + + cursor.OrderDirection = console.Descending + var reverseMemberIDs []uuid.UUID + for i := len(tt.memberIDs) - 1; i >= 0; i-- { + reverseMemberIDs = append(reverseMemberIDs, tt.memberIDs[i]) + } + require.Equal(t, reverseMemberIDs, getIDsFromDB(cursor), errMsg(cursor)) + } + }) + }) } diff --git a/satellite/satellitedb/util.go b/satellite/satellitedb/util.go index 49d664e54..54d305dcc 100644 --- a/satellite/satellitedb/util.go +++ b/satellite/satellitedb/util.go @@ -6,6 +6,7 @@ package satellitedb import ( "github.com/zeebo/errs" + "storj.io/common/uuid" "storj.io/private/tagsql" ) @@ -60,3 +61,26 @@ func convertSliceWithErrors[In, Out any](xs []In, fn func(In) (Out, error)) ([]O } return rs, errs } + +// NullUUID represents a nullable uuid.UUID that can be used as an SQL scan destination. +type NullUUID struct { + UUID uuid.UUID + Valid bool +} + +// Scan implements the sql.Scanner interface. +// It scans a value from the database into the NullUUID. +func (nu *NullUUID) Scan(value interface{}) error { + if value == nil { + nu.UUID = uuid.UUID{} + nu.Valid = false + return nil + } + valueBytes, ok := value.([]byte) + if !ok { + return errs.New("invalid UUID type") + } + copy(nu.UUID[:], valueBytes) + nu.Valid = true + return nil +}