diff --git a/pkg/satellite/database.go b/pkg/satellite/database.go index 8ac18a9ce..8f5879f46 100644 --- a/pkg/satellite/database.go +++ b/pkg/satellite/database.go @@ -5,12 +5,14 @@ package satellite // DB contains access to different satellite databases type DB interface { - // Users is getter for Users repository + // Users is a getter for Users repository Users() Users - // Companies is getter for Companies repository + // Companies is a getter for Companies repository Companies() Companies - // Projects is getter for Projects repository + // Projects is a getter for Projects repository Projects() Projects + // ProjectMembers is a getter for ProjectMembers repository + ProjectMembers() ProjectMembers // CreateTables is a method for creating all tables for satellitedb CreateTables() error diff --git a/pkg/satellite/projectmembers.go b/pkg/satellite/projectmembers.go new file mode 100644 index 000000000..745d77b27 --- /dev/null +++ b/pkg/satellite/projectmembers.go @@ -0,0 +1,42 @@ +// Copyright (C) 2018 Storj Labs, Inc. +// See LICENSE for copying information. + +package satellite + +import ( + "context" + "time" + + "github.com/skyrings/skyring-common/tools/uuid" +) + +// ProjectMembers exposes methods to manage ProjectMembers table in database. +// TODO: some methods will be removed, some - added +type ProjectMembers interface { + // GetAll is a method for querying all project members from the database. + GetAll(ctx context.Context) ([]ProjectMember, error) + // GetByMemberID is a method for querying project member from the database by memberID. + GetByMemberID(ctx context.Context, memberID uuid.UUID) (*ProjectMember, error) + // GetByProjectID is a method for querying project members from the database by projectID. + GetByProjectID(ctx context.Context, projectID uuid.UUID) ([]ProjectMember, error) + // Get is a method for querying project member from the database by id. + Get(ctx context.Context, id uuid.UUID) (*ProjectMember, error) + // Insert is a method for inserting project member into the database. + Insert(ctx context.Context, memberID, projectID uuid.UUID) (*ProjectMember, error) + // Delete is a method for deleting project member by Id from the database. + Delete(ctx context.Context, id uuid.UUID) error + // Update is a method for updating project member entity. + Update(ctx context.Context, projectMember *ProjectMember) error +} + +// ProjectMember is a database object that describes ProjectMember entity. +type ProjectMember struct { + ID uuid.UUID + + // FK on Users table. + MemberID uuid.UUID + // FK on Projects table. + ProjectID uuid.UUID + + CreatedAt time.Time +} diff --git a/pkg/satellite/projects.go b/pkg/satellite/projects.go index 60edc7482..d6438ab0b 100644 --- a/pkg/satellite/projects.go +++ b/pkg/satellite/projects.go @@ -34,8 +34,8 @@ type Project struct { Name string Description string - // Indicates if user accepted terms and conditions during project creation. - IsAgreedWithTerms bool + // stores last accepted version of terms of use. + TermsAccepted int CreatedAt time.Time } diff --git a/pkg/satellite/satellitedb/companies.go b/pkg/satellite/satellitedb/companies.go index d27bd69d0..42126413a 100644 --- a/pkg/satellite/satellitedb/companies.go +++ b/pkg/satellite/satellitedb/companies.go @@ -96,7 +96,7 @@ func companyFromDBX(company *dbx.Company) (*satellite.Company, error) { return nil, err } - comp := &satellite.Company{ + return &satellite.Company{ ID: id, UserID: userID, Name: company.Name, @@ -106,9 +106,7 @@ func companyFromDBX(company *dbx.Company) (*satellite.Company, error) { State: company.State, PostalCode: company.PostalCode, CreatedAt: company.CreatedAt, - } - - return comp, nil + }, nil } // getCompanyUpdateFields is used to generate company update fields diff --git a/pkg/satellite/satellitedb/companies_test.go b/pkg/satellite/satellitedb/companies_test.go index 6eca0a1df..82ae5eaf2 100644 --- a/pkg/satellite/satellitedb/companies_test.go +++ b/pkg/satellite/satellitedb/companies_test.go @@ -46,14 +46,14 @@ func TestCompanyRepository(t *testing.T) { // to test with real db3 file use this connection string - "../db/accountdb.db3" db, err := New("sqlite3", "file::memory:?mode=memory&cache=shared") if err != nil { - assert.NoError(t, err) + t.Fatal(err) } defer ctx.Check(db.Close) // creating tables err = db.CreateTables() if err != nil { - assert.NoError(t, err) + t.Fatal(err) } // repositories diff --git a/pkg/satellite/satellitedb/db.go b/pkg/satellite/satellitedb/db.go index 8ae348084..411de4160 100644 --- a/pkg/satellite/satellitedb/db.go +++ b/pkg/satellite/satellitedb/db.go @@ -30,21 +30,26 @@ func New(driver, source string) (satellite.DB, error) { return database, nil } -// Users is getter for Users repository +// Users is getter a for Users repository func (db *Database) Users() satellite.Users { return &users{db.db} } -// Companies is getter for Companies repository +// Companies is a getter for Companies repository func (db *Database) Companies() satellite.Companies { return &companies{db.db} } -// Projects is getter for Projects repository +// Projects is a getter for Projects repository func (db *Database) Projects() satellite.Projects { return &projects{db.db} } +// ProjectMembers is a getter for ProjectMembers repository +func (db *Database) ProjectMembers() satellite.ProjectMembers { + return &projectMembers{db.db} +} + // CreateTables is a method for creating all tables for satellitedb func (db *Database) CreateTables() error { return migrate.Create("satellitedb", db.db) diff --git a/pkg/satellite/satellitedb/dbx/satellitedb.dbx b/pkg/satellite/satellitedb/dbx/satellitedb.dbx index 392f41744..b9665a1d0 100644 --- a/pkg/satellite/satellitedb/dbx/satellitedb.dbx +++ b/pkg/satellite/satellitedb/dbx/satellitedb.dbx @@ -58,14 +58,15 @@ delete company ( where company.id = ? ) model project ( key id - field id blob - field owner_id user.id setnull ( nullable, updatable ) + field id blob + field owner_id user.id setnull ( nullable, updatable ) - field name text ( updatable ) - field description text ( updatable ) - field is_agreed_with_terms bool ( updatable ) + field name text ( updatable ) + field description text ( updatable ) + // stores last accepted version of terms of use + field terms_accepted int ( updatable ) - field created_at timestamp ( autoinsert ) + field created_at timestamp ( autoinsert ) ) read all ( select project) @@ -79,4 +80,31 @@ read all ( ) create project ( ) update project ( where project.id = ? ) -delete project ( where project.id = ? ) \ No newline at end of file +delete project ( where project.id = ? ) + +model project_member ( + key id + + field id blob + field member_id user.id cascade + field project_id project.id cascade ( updatable ) + + field created_at timestamp ( autoinsert ) +) + +read all ( select project_member) +read all ( + select project_member + where project_member.project_id = ? +) +read one ( + select project_member + where project_member.member_id = ? +) +read one ( + select project_member + where project_member.id = ? +) +create project_member ( ) +update project_member ( where project_member.id = ? ) +delete project_member ( where project_member.id = ? ) \ No newline at end of file diff --git a/pkg/satellite/satellitedb/dbx/satellitedb.dbx.go b/pkg/satellite/satellitedb/dbx/satellitedb.dbx.go index 8e011bc01..887742f2a 100644 --- a/pkg/satellite/satellitedb/dbx/satellitedb.dbx.go +++ b/pkg/satellite/satellitedb/dbx/satellitedb.dbx.go @@ -293,7 +293,14 @@ CREATE TABLE projects ( owner_id BLOB REFERENCES users( id ) ON DELETE SET NULL, name TEXT NOT NULL, description TEXT NOT NULL, - is_agreed_with_terms INTEGER NOT NULL, + terms_accepted INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL, + PRIMARY KEY ( id ) +); +CREATE TABLE project_members ( + id BLOB NOT NULL, + member_id BLOB NOT NULL REFERENCES users( id ) ON DELETE CASCADE, + project_id BLOB NOT NULL REFERENCES projects( id ) ON DELETE CASCADE, created_at TIMESTAMP NOT NULL, PRIMARY KEY ( id ) );` @@ -671,12 +678,12 @@ func (f Company_CreatedAt_Field) value() interface{} { func (Company_CreatedAt_Field) _Column() string { return "created_at" } type Project struct { - Id []byte - OwnerId []byte - Name string - Description string - IsAgreedWithTerms bool - CreatedAt time.Time + Id []byte + OwnerId []byte + Name string + Description string + TermsAccepted int + CreatedAt time.Time } func (Project) _Table() string { return "projects" } @@ -686,10 +693,10 @@ type Project_Create_Fields struct { } type Project_Update_Fields struct { - OwnerId Project_OwnerId_Field - Name Project_Name_Field - Description Project_Description_Field - IsAgreedWithTerms Project_IsAgreedWithTerms_Field + OwnerId Project_OwnerId_Field + Name Project_Name_Field + Description Project_Description_Field + TermsAccepted Project_TermsAccepted_Field } type Project_Id_Field struct { @@ -777,23 +784,23 @@ func (f Project_Description_Field) value() interface{} { func (Project_Description_Field) _Column() string { return "description" } -type Project_IsAgreedWithTerms_Field struct { +type Project_TermsAccepted_Field struct { _set bool - _value bool + _value int } -func Project_IsAgreedWithTerms(v bool) Project_IsAgreedWithTerms_Field { - return Project_IsAgreedWithTerms_Field{_set: true, _value: v} +func Project_TermsAccepted(v int) Project_TermsAccepted_Field { + return Project_TermsAccepted_Field{_set: true, _value: v} } -func (f Project_IsAgreedWithTerms_Field) value() interface{} { +func (f Project_TermsAccepted_Field) value() interface{} { if !f._set { return nil } return f._value } -func (Project_IsAgreedWithTerms_Field) _Column() string { return "is_agreed_with_terms" } +func (Project_TermsAccepted_Field) _Column() string { return "terms_accepted" } type Project_CreatedAt_Field struct { _set bool @@ -813,6 +820,91 @@ func (f Project_CreatedAt_Field) value() interface{} { func (Project_CreatedAt_Field) _Column() string { return "created_at" } +type ProjectMember struct { + Id []byte + MemberId []byte + ProjectId []byte + CreatedAt time.Time +} + +func (ProjectMember) _Table() string { return "project_members" } + +type ProjectMember_Update_Fields struct { + ProjectId ProjectMember_ProjectId_Field +} + +type ProjectMember_Id_Field struct { + _set bool + _value []byte +} + +func ProjectMember_Id(v []byte) ProjectMember_Id_Field { + return ProjectMember_Id_Field{_set: true, _value: v} +} + +func (f ProjectMember_Id_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (ProjectMember_Id_Field) _Column() string { return "id" } + +type ProjectMember_MemberId_Field struct { + _set bool + _value []byte +} + +func ProjectMember_MemberId(v []byte) ProjectMember_MemberId_Field { + return ProjectMember_MemberId_Field{_set: true, _value: v} +} + +func (f ProjectMember_MemberId_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (ProjectMember_MemberId_Field) _Column() string { return "member_id" } + +type ProjectMember_ProjectId_Field struct { + _set bool + _value []byte +} + +func ProjectMember_ProjectId(v []byte) ProjectMember_ProjectId_Field { + return ProjectMember_ProjectId_Field{_set: true, _value: v} +} + +func (f ProjectMember_ProjectId_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (ProjectMember_ProjectId_Field) _Column() string { return "project_id" } + +type ProjectMember_CreatedAt_Field struct { + _set bool + _value time.Time +} + +func ProjectMember_CreatedAt(v time.Time) ProjectMember_CreatedAt_Field { + return ProjectMember_CreatedAt_Field{_set: true, _value: v} +} + +func (f ProjectMember_CreatedAt_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (ProjectMember_CreatedAt_Field) _Column() string { return "created_at" } + func toUTC(t time.Time) time.Time { return t.UTC() } @@ -1057,7 +1149,7 @@ func (obj *sqlite3Impl) Create_Project(ctx context.Context, project_id Project_Id_Field, project_name Project_Name_Field, project_description Project_Description_Field, - project_is_agreed_with_terms Project_IsAgreedWithTerms_Field, + project_terms_accepted Project_TermsAccepted_Field, optional Project_Create_Fields) ( project *Project, err error) { @@ -1066,15 +1158,15 @@ func (obj *sqlite3Impl) Create_Project(ctx context.Context, __owner_id_val := optional.OwnerId.value() __name_val := project_name.value() __description_val := project_description.value() - __is_agreed_with_terms_val := project_is_agreed_with_terms.value() + __terms_accepted_val := project_terms_accepted.value() __created_at_val := __now - var __embed_stmt = __sqlbundle_Literal("INSERT INTO projects ( id, owner_id, name, description, is_agreed_with_terms, created_at ) VALUES ( ?, ?, ?, ?, ?, ? )") + var __embed_stmt = __sqlbundle_Literal("INSERT INTO projects ( id, owner_id, name, description, terms_accepted, created_at ) VALUES ( ?, ?, ?, ?, ?, ? )") var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) - obj.logStmt(__stmt, __id_val, __owner_id_val, __name_val, __description_val, __is_agreed_with_terms_val, __created_at_val) + obj.logStmt(__stmt, __id_val, __owner_id_val, __name_val, __description_val, __terms_accepted_val, __created_at_val) - __res, err := obj.driver.Exec(__stmt, __id_val, __owner_id_val, __name_val, __description_val, __is_agreed_with_terms_val, __created_at_val) + __res, err := obj.driver.Exec(__stmt, __id_val, __owner_id_val, __name_val, __description_val, __terms_accepted_val, __created_at_val) if err != nil { return nil, obj.makeErr(err) } @@ -1086,6 +1178,35 @@ func (obj *sqlite3Impl) Create_Project(ctx context.Context, } +func (obj *sqlite3Impl) Create_ProjectMember(ctx context.Context, + project_member_id ProjectMember_Id_Field, + project_member_member_id ProjectMember_MemberId_Field, + project_member_project_id ProjectMember_ProjectId_Field) ( + project_member *ProjectMember, err error) { + + __now := obj.db.Hooks.Now().UTC() + __id_val := project_member_id.value() + __member_id_val := project_member_member_id.value() + __project_id_val := project_member_project_id.value() + __created_at_val := __now + + var __embed_stmt = __sqlbundle_Literal("INSERT INTO project_members ( id, member_id, project_id, created_at ) VALUES ( ?, ?, ?, ? )") + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __id_val, __member_id_val, __project_id_val, __created_at_val) + + __res, err := obj.driver.Exec(__stmt, __id_val, __member_id_val, __project_id_val, __created_at_val) + if err != nil { + return nil, obj.makeErr(err) + } + __pk, err := __res.LastInsertId() + if err != nil { + return nil, obj.makeErr(err) + } + return obj.getLastProjectMember(ctx, __pk) + +} + func (obj *sqlite3Impl) Get_User_By_Email_And_PasswordHash(ctx context.Context, user_email User_Email_Field, user_password_hash User_PasswordHash_Field) ( @@ -1196,7 +1317,7 @@ func (obj *sqlite3Impl) Get_Company_By_Id(ctx context.Context, func (obj *sqlite3Impl) All_Project(ctx context.Context) ( rows []*Project, err error) { - var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects") + var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.terms_accepted, projects.created_at FROM projects") var __values []interface{} __values = append(__values) @@ -1212,7 +1333,7 @@ func (obj *sqlite3Impl) All_Project(ctx context.Context) ( for __rows.Next() { project := &Project{} - err = __rows.Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt) + err = __rows.Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.TermsAccepted, &project.CreatedAt) if err != nil { return nil, obj.makeErr(err) } @@ -1229,7 +1350,7 @@ func (obj *sqlite3Impl) Get_Project_By_Id(ctx context.Context, project_id Project_Id_Field) ( project *Project, err error) { - var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE projects.id = ?") + var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.terms_accepted, projects.created_at FROM projects WHERE projects.id = ?") var __values []interface{} __values = append(__values, project_id.value()) @@ -1238,7 +1359,7 @@ func (obj *sqlite3Impl) Get_Project_By_Id(ctx context.Context, obj.logStmt(__stmt, __values...) project = &Project{} - err = obj.driver.QueryRow(__stmt, __values...).Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt) + err = obj.driver.QueryRow(__stmt, __values...).Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.TermsAccepted, &project.CreatedAt) if err != nil { return nil, obj.makeErr(err) } @@ -1252,7 +1373,7 @@ func (obj *sqlite3Impl) All_Project_By_OwnerId(ctx context.Context, var __cond_0 = &__sqlbundle_Condition{Left: "projects.owner_id", Equal: true, Right: "?", Null: true} - var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE "), __cond_0}} + var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.terms_accepted, projects.created_at FROM projects WHERE "), __cond_0}} var __values []interface{} __values = append(__values) @@ -1273,7 +1394,7 @@ func (obj *sqlite3Impl) All_Project_By_OwnerId(ctx context.Context, for __rows.Next() { project := &Project{} - err = __rows.Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt) + err = __rows.Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.TermsAccepted, &project.CreatedAt) if err != nil { return nil, obj.makeErr(err) } @@ -1286,6 +1407,135 @@ func (obj *sqlite3Impl) All_Project_By_OwnerId(ctx context.Context, } +func (obj *sqlite3Impl) All_ProjectMember(ctx context.Context) ( + rows []*ProjectMember, err error) { + + var __embed_stmt = __sqlbundle_Literal("SELECT project_members.id, project_members.member_id, project_members.project_id, project_members.created_at FROM project_members") + + var __values []interface{} + __values = append(__values) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + __rows, err := obj.driver.Query(__stmt, __values...) + if err != nil { + return nil, obj.makeErr(err) + } + defer __rows.Close() + + for __rows.Next() { + project_member := &ProjectMember{} + err = __rows.Scan(&project_member.Id, &project_member.MemberId, &project_member.ProjectId, &project_member.CreatedAt) + if err != nil { + return nil, obj.makeErr(err) + } + rows = append(rows, project_member) + } + if err := __rows.Err(); err != nil { + return nil, obj.makeErr(err) + } + return rows, nil + +} + +func (obj *sqlite3Impl) All_ProjectMember_By_ProjectId(ctx context.Context, + project_member_project_id ProjectMember_ProjectId_Field) ( + rows []*ProjectMember, err error) { + + var __embed_stmt = __sqlbundle_Literal("SELECT project_members.id, project_members.member_id, project_members.project_id, project_members.created_at FROM project_members WHERE project_members.project_id = ?") + + var __values []interface{} + __values = append(__values, project_member_project_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + __rows, err := obj.driver.Query(__stmt, __values...) + if err != nil { + return nil, obj.makeErr(err) + } + defer __rows.Close() + + for __rows.Next() { + project_member := &ProjectMember{} + err = __rows.Scan(&project_member.Id, &project_member.MemberId, &project_member.ProjectId, &project_member.CreatedAt) + if err != nil { + return nil, obj.makeErr(err) + } + rows = append(rows, project_member) + } + if err := __rows.Err(); err != nil { + return nil, obj.makeErr(err) + } + return rows, nil + +} + +func (obj *sqlite3Impl) Get_ProjectMember_By_MemberId(ctx context.Context, + project_member_member_id ProjectMember_MemberId_Field) ( + project_member *ProjectMember, err error) { + + var __embed_stmt = __sqlbundle_Literal("SELECT project_members.id, project_members.member_id, project_members.project_id, project_members.created_at FROM project_members WHERE project_members.member_id = ? LIMIT 2") + + var __values []interface{} + __values = append(__values, project_member_member_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + __rows, err := obj.driver.Query(__stmt, __values...) + if err != nil { + return nil, obj.makeErr(err) + } + defer __rows.Close() + + if !__rows.Next() { + if err := __rows.Err(); err != nil { + return nil, obj.makeErr(err) + } + return nil, makeErr(sql.ErrNoRows) + } + + project_member = &ProjectMember{} + err = __rows.Scan(&project_member.Id, &project_member.MemberId, &project_member.ProjectId, &project_member.CreatedAt) + if err != nil { + return nil, obj.makeErr(err) + } + + if __rows.Next() { + return nil, tooManyRows("ProjectMember_By_MemberId") + } + + if err := __rows.Err(); err != nil { + return nil, obj.makeErr(err) + } + + return project_member, nil + +} + +func (obj *sqlite3Impl) Get_ProjectMember_By_Id(ctx context.Context, + project_member_id ProjectMember_Id_Field) ( + project_member *ProjectMember, err error) { + + var __embed_stmt = __sqlbundle_Literal("SELECT project_members.id, project_members.member_id, project_members.project_id, project_members.created_at FROM project_members WHERE project_members.id = ?") + + var __values []interface{} + __values = append(__values, project_member_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + project_member = &ProjectMember{} + err = obj.driver.QueryRow(__stmt, __values...).Scan(&project_member.Id, &project_member.MemberId, &project_member.ProjectId, &project_member.CreatedAt) + if err != nil { + return nil, obj.makeErr(err) + } + return project_member, nil + +} + func (obj *sqlite3Impl) Update_User_By_Id(ctx context.Context, user_id User_Id_Field, update User_Update_Fields) ( @@ -1453,9 +1703,9 @@ func (obj *sqlite3Impl) Update_Project_By_Id(ctx context.Context, __sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("description = ?")) } - if update.IsAgreedWithTerms._set { - __values = append(__values, update.IsAgreedWithTerms.value()) - __sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("is_agreed_with_terms = ?")) + if update.TermsAccepted._set { + __values = append(__values, update.TermsAccepted.value()) + __sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("terms_accepted = ?")) } if len(__sets_sql.SQLs) == 0 { @@ -1476,12 +1726,12 @@ func (obj *sqlite3Impl) Update_Project_By_Id(ctx context.Context, return nil, obj.makeErr(err) } - var __embed_stmt_get = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE projects.id = ?") + var __embed_stmt_get = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.terms_accepted, projects.created_at FROM projects WHERE projects.id = ?") var __stmt_get = __sqlbundle_Render(obj.dialect, __embed_stmt_get) obj.logStmt("(IMPLIED) "+__stmt_get, __args...) - err = obj.driver.QueryRow(__stmt_get, __args...).Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt) + err = obj.driver.QueryRow(__stmt_get, __args...).Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.TermsAccepted, &project.CreatedAt) if err == sql.ErrNoRows { return nil, nil } @@ -1491,6 +1741,56 @@ func (obj *sqlite3Impl) Update_Project_By_Id(ctx context.Context, return project, nil } +func (obj *sqlite3Impl) Update_ProjectMember_By_Id(ctx context.Context, + project_member_id ProjectMember_Id_Field, + update ProjectMember_Update_Fields) ( + project_member *ProjectMember, err error) { + var __sets = &__sqlbundle_Hole{} + + var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("UPDATE project_members SET "), __sets, __sqlbundle_Literal(" WHERE project_members.id = ?")}} + + __sets_sql := __sqlbundle_Literals{Join: ", "} + var __values []interface{} + var __args []interface{} + + if update.ProjectId._set { + __values = append(__values, update.ProjectId.value()) + __sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("project_id = ?")) + } + + if len(__sets_sql.SQLs) == 0 { + return nil, emptyUpdate() + } + + __args = append(__args, project_member_id.value()) + + __values = append(__values, __args...) + __sets.SQL = __sets_sql + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + project_member = &ProjectMember{} + _, err = obj.driver.Exec(__stmt, __values...) + if err != nil { + return nil, obj.makeErr(err) + } + + var __embed_stmt_get = __sqlbundle_Literal("SELECT project_members.id, project_members.member_id, project_members.project_id, project_members.created_at FROM project_members WHERE project_members.id = ?") + + var __stmt_get = __sqlbundle_Render(obj.dialect, __embed_stmt_get) + obj.logStmt("(IMPLIED) "+__stmt_get, __args...) + + err = obj.driver.QueryRow(__stmt_get, __args...).Scan(&project_member.Id, &project_member.MemberId, &project_member.ProjectId, &project_member.CreatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, obj.makeErr(err) + } + return project_member, nil +} + func (obj *sqlite3Impl) Delete_User_By_Id(ctx context.Context, user_id User_Id_Field) ( deleted bool, err error) { @@ -1569,6 +1869,32 @@ func (obj *sqlite3Impl) Delete_Project_By_Id(ctx context.Context, } +func (obj *sqlite3Impl) Delete_ProjectMember_By_Id(ctx context.Context, + project_member_id ProjectMember_Id_Field) ( + deleted bool, err error) { + + var __embed_stmt = __sqlbundle_Literal("DELETE FROM project_members WHERE project_members.id = ?") + + var __values []interface{} + __values = append(__values, project_member_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + __res, err := obj.driver.Exec(__stmt, __values...) + if err != nil { + return false, obj.makeErr(err) + } + + __count, err := __res.RowsAffected() + if err != nil { + return false, obj.makeErr(err) + } + + return __count > 0, nil + +} + func (obj *sqlite3Impl) getLastUser(ctx context.Context, pk int64) ( user *User, err error) { @@ -1609,13 +1935,13 @@ func (obj *sqlite3Impl) getLastProject(ctx context.Context, pk int64) ( project *Project, err error) { - var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE _rowid_ = ?") + var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.terms_accepted, projects.created_at FROM projects WHERE _rowid_ = ?") var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) obj.logStmt(__stmt, pk) project = &Project{} - err = obj.driver.QueryRow(__stmt, pk).Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt) + err = obj.driver.QueryRow(__stmt, pk).Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.TermsAccepted, &project.CreatedAt) if err != nil { return nil, obj.makeErr(err) } @@ -1623,6 +1949,24 @@ func (obj *sqlite3Impl) getLastProject(ctx context.Context, } +func (obj *sqlite3Impl) getLastProjectMember(ctx context.Context, + pk int64) ( + project_member *ProjectMember, err error) { + + var __embed_stmt = __sqlbundle_Literal("SELECT project_members.id, project_members.member_id, project_members.project_id, project_members.created_at FROM project_members WHERE _rowid_ = ?") + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, pk) + + project_member = &ProjectMember{} + err = obj.driver.QueryRow(__stmt, pk).Scan(&project_member.Id, &project_member.MemberId, &project_member.ProjectId, &project_member.CreatedAt) + if err != nil { + return nil, obj.makeErr(err) + } + return project_member, nil + +} + func (impl sqlite3Impl) isConstraintError(err error) ( constraint string, ok bool) { if e, ok := err.(sqlite3.Error); ok { @@ -1641,6 +1985,16 @@ func (impl sqlite3Impl) isConstraintError(err error) ( func (obj *sqlite3Impl) deleteAll(ctx context.Context) (count int64, err error) { var __res sql.Result var __count int64 + __res, err = obj.driver.Exec("DELETE FROM project_members;") + if err != nil { + return 0, obj.makeErr(err) + } + + __count, err = __res.RowsAffected() + if err != nil { + return 0, obj.makeErr(err) + } + count += __count __res, err = obj.driver.Exec("DELETE FROM projects;") if err != nil { return 0, obj.makeErr(err) @@ -1727,6 +2081,25 @@ func (rx *Rx) All_Project(ctx context.Context) ( return tx.All_Project(ctx) } +func (rx *Rx) All_ProjectMember(ctx context.Context) ( + rows []*ProjectMember, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.All_ProjectMember(ctx) +} + +func (rx *Rx) All_ProjectMember_By_ProjectId(ctx context.Context, + project_member_project_id ProjectMember_ProjectId_Field) ( + rows []*ProjectMember, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.All_ProjectMember_By_ProjectId(ctx, project_member_project_id) +} + func (rx *Rx) All_Project_By_OwnerId(ctx context.Context, project_owner_id Project_OwnerId_Field) ( rows []*Project, err error) { @@ -1759,14 +2132,27 @@ func (rx *Rx) Create_Project(ctx context.Context, project_id Project_Id_Field, project_name Project_Name_Field, project_description Project_Description_Field, - project_is_agreed_with_terms Project_IsAgreedWithTerms_Field, + project_terms_accepted Project_TermsAccepted_Field, optional Project_Create_Fields) ( project *Project, err error) { var tx *Tx if tx, err = rx.getTx(ctx); err != nil { return } - return tx.Create_Project(ctx, project_id, project_name, project_description, project_is_agreed_with_terms, optional) + return tx.Create_Project(ctx, project_id, project_name, project_description, project_terms_accepted, optional) + +} + +func (rx *Rx) Create_ProjectMember(ctx context.Context, + project_member_id ProjectMember_Id_Field, + project_member_member_id ProjectMember_MemberId_Field, + project_member_project_id ProjectMember_ProjectId_Field) ( + project_member *ProjectMember, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Create_ProjectMember(ctx, project_member_id, project_member_member_id, project_member_project_id) } @@ -1795,6 +2181,16 @@ func (rx *Rx) Delete_Company_By_Id(ctx context.Context, return tx.Delete_Company_By_Id(ctx, company_id) } +func (rx *Rx) Delete_ProjectMember_By_Id(ctx context.Context, + project_member_id ProjectMember_Id_Field) ( + deleted bool, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Delete_ProjectMember_By_Id(ctx, project_member_id) +} + func (rx *Rx) Delete_Project_By_Id(ctx context.Context, project_id Project_Id_Field) ( deleted bool, err error) { @@ -1835,6 +2231,26 @@ func (rx *Rx) Get_Company_By_UserId(ctx context.Context, return tx.Get_Company_By_UserId(ctx, company_user_id) } +func (rx *Rx) Get_ProjectMember_By_Id(ctx context.Context, + project_member_id ProjectMember_Id_Field) ( + project_member *ProjectMember, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Get_ProjectMember_By_Id(ctx, project_member_id) +} + +func (rx *Rx) Get_ProjectMember_By_MemberId(ctx context.Context, + project_member_member_id ProjectMember_MemberId_Field) ( + project_member *ProjectMember, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Get_ProjectMember_By_MemberId(ctx, project_member_member_id) +} + func (rx *Rx) Get_Project_By_Id(ctx context.Context, project_id Project_Id_Field) ( project *Project, err error) { @@ -1877,6 +2293,17 @@ func (rx *Rx) Update_Company_By_Id(ctx context.Context, return tx.Update_Company_By_Id(ctx, company_id, update) } +func (rx *Rx) Update_ProjectMember_By_Id(ctx context.Context, + project_member_id ProjectMember_Id_Field, + update ProjectMember_Update_Fields) ( + project_member *ProjectMember, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Update_ProjectMember_By_Id(ctx, project_member_id, update) +} + func (rx *Rx) Update_Project_By_Id(ctx context.Context, project_id Project_Id_Field, update Project_Update_Fields) ( @@ -1903,6 +2330,13 @@ type Methods interface { All_Project(ctx context.Context) ( rows []*Project, err error) + All_ProjectMember(ctx context.Context) ( + rows []*ProjectMember, err error) + + All_ProjectMember_By_ProjectId(ctx context.Context, + project_member_project_id ProjectMember_ProjectId_Field) ( + rows []*ProjectMember, err error) + All_Project_By_OwnerId(ctx context.Context, project_owner_id Project_OwnerId_Field) ( rows []*Project, err error) @@ -1922,10 +2356,16 @@ type Methods interface { project_id Project_Id_Field, project_name Project_Name_Field, project_description Project_Description_Field, - project_is_agreed_with_terms Project_IsAgreedWithTerms_Field, + project_terms_accepted Project_TermsAccepted_Field, optional Project_Create_Fields) ( project *Project, err error) + Create_ProjectMember(ctx context.Context, + project_member_id ProjectMember_Id_Field, + project_member_member_id ProjectMember_MemberId_Field, + project_member_project_id ProjectMember_ProjectId_Field) ( + project_member *ProjectMember, err error) + Create_User(ctx context.Context, user_id User_Id_Field, user_first_name User_FirstName_Field, @@ -1938,6 +2378,10 @@ type Methods interface { company_id Company_Id_Field) ( deleted bool, err error) + Delete_ProjectMember_By_Id(ctx context.Context, + project_member_id ProjectMember_Id_Field) ( + deleted bool, err error) + Delete_Project_By_Id(ctx context.Context, project_id Project_Id_Field) ( deleted bool, err error) @@ -1954,6 +2398,14 @@ type Methods interface { company_user_id Company_UserId_Field) ( company *Company, err error) + Get_ProjectMember_By_Id(ctx context.Context, + project_member_id ProjectMember_Id_Field) ( + project_member *ProjectMember, err error) + + Get_ProjectMember_By_MemberId(ctx context.Context, + project_member_member_id ProjectMember_MemberId_Field) ( + project_member *ProjectMember, err error) + Get_Project_By_Id(ctx context.Context, project_id Project_Id_Field) ( project *Project, err error) @@ -1972,6 +2424,11 @@ type Methods interface { update Company_Update_Fields) ( company *Company, err error) + Update_ProjectMember_By_Id(ctx context.Context, + project_member_id ProjectMember_Id_Field, + update ProjectMember_Update_Fields) ( + project_member *ProjectMember, err error) + Update_Project_By_Id(ctx context.Context, project_id Project_Id_Field, update Project_Update_Fields) ( diff --git a/pkg/satellite/satellitedb/dbx/satellitedb.dbx.sqlite3.sql b/pkg/satellite/satellitedb/dbx/satellitedb.dbx.sqlite3.sql index 27ae9b1e1..45bb343b5 100644 --- a/pkg/satellite/satellitedb/dbx/satellitedb.dbx.sqlite3.sql +++ b/pkg/satellite/satellitedb/dbx/satellitedb.dbx.sqlite3.sql @@ -27,7 +27,14 @@ CREATE TABLE projects ( owner_id BLOB REFERENCES users( id ) ON DELETE SET NULL, name TEXT NOT NULL, description TEXT NOT NULL, - is_agreed_with_terms INTEGER NOT NULL, + terms_accepted INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL, + PRIMARY KEY ( id ) +); +CREATE TABLE project_members ( + id BLOB NOT NULL, + member_id BLOB NOT NULL REFERENCES users( id ) ON DELETE CASCADE, + project_id BLOB NOT NULL REFERENCES projects( id ) ON DELETE CASCADE, created_at TIMESTAMP NOT NULL, PRIMARY KEY ( id ) ); diff --git a/pkg/satellite/satellitedb/projectmembers.go b/pkg/satellite/satellitedb/projectmembers.go new file mode 100644 index 000000000..037c749f3 --- /dev/null +++ b/pkg/satellite/satellitedb/projectmembers.go @@ -0,0 +1,143 @@ +// Copyright (C) 2018 Storj Labs, Inc. +// See LICENSE for copying information. + +package satellitedb + +import ( + "context" + + "github.com/skyrings/skyring-common/tools/uuid" + "github.com/zeebo/errs" + "storj.io/storj/pkg/satellite" + "storj.io/storj/pkg/satellite/satellitedb/dbx" + "storj.io/storj/pkg/utils" +) + +// ProjectMembers exposes methods to manage ProjectMembers table in database. +type projectMembers struct { + db *dbx.DB +} + +// GetAll is a method for querying all project members from the database. +func (pm *projectMembers) GetAll(ctx context.Context) ([]satellite.ProjectMember, error) { + projectMembersDbx, err := pm.db.All_ProjectMember(ctx) + if err != nil { + return nil, err + } + + return projectMembersFromDbxSlice(projectMembersDbx) +} + +// GetByMemberID is a method for querying project member from the database by memberID. +func (pm *projectMembers) GetByMemberID(ctx context.Context, memberID uuid.UUID) (*satellite.ProjectMember, error) { + projectMemberDbx, err := pm.db.Get_ProjectMember_By_MemberId(ctx, dbx.ProjectMember_MemberId(memberID[:])) + if err != nil { + return nil, err + } + + return projectMemberFromDBX(projectMemberDbx) +} + +// GetByProjectID is a method for querying project members from the database by projectID. +func (pm *projectMembers) GetByProjectID(ctx context.Context, projectID uuid.UUID) ([]satellite.ProjectMember, error) { + projectMembersDbx, err := pm.db.All_ProjectMember_By_ProjectId(ctx, dbx.ProjectMember_ProjectId(projectID[:])) + if err != nil { + return nil, err + } + + return projectMembersFromDbxSlice(projectMembersDbx) +} + +// Get is a method for querying project member from the database by id. +func (pm *projectMembers) Get(ctx context.Context, id uuid.UUID) (*satellite.ProjectMember, error) { + projectMember, err := pm.db.Get_ProjectMember_By_Id(ctx, dbx.ProjectMember_Id(id[:])) + if err != nil { + return nil, err + } + + return projectMemberFromDBX(projectMember) +} + +// Insert is a method for inserting project member into the database. +func (pm *projectMembers) Insert(ctx context.Context, memberID, projectID uuid.UUID) (*satellite.ProjectMember, error) { + id, err := uuid.New() + if err != nil { + return nil, err + } + + createdProjectMember, err := pm.db.Create_ProjectMember(ctx, + dbx.ProjectMember_Id(id[:]), + dbx.ProjectMember_MemberId(memberID[:]), + dbx.ProjectMember_ProjectId(projectID[:])) + if err != nil { + return nil, err + } + + return projectMemberFromDBX(createdProjectMember) +} + +// Delete is a method for deleting project member by Id from the database. +func (pm *projectMembers) Delete(ctx context.Context, id uuid.UUID) error { + _, err := pm.db.Delete_ProjectMember_By_Id(ctx, dbx.ProjectMember_Id(id[:])) + + return err +} + +// Update is a method for updating project member entity. +func (pm *projectMembers) Update(ctx context.Context, projectMember *satellite.ProjectMember) error { + _, err := pm.db.Update_ProjectMember_By_Id(ctx, + dbx.ProjectMember_Id(projectMember.ID[:]), + dbx.ProjectMember_Update_Fields{ + ProjectId: dbx.ProjectMember_ProjectId(projectMember.ProjectID[:]), + }) + + return err +} + +// projectMemberFromDBX is used for creating ProjectMember entity from autogenerated dbx.ProjectMember struct +func projectMemberFromDBX(projectMember *dbx.ProjectMember) (*satellite.ProjectMember, error) { + if projectMember == nil { + return nil, errs.New("projectMember parameter is nil") + } + + id, err := bytesToUUID(projectMember.Id) + if err != nil { + return nil, err + } + + memberID, err := bytesToUUID(projectMember.MemberId) + if err != nil { + return nil, err + } + + projectID, err := bytesToUUID(projectMember.ProjectId) + if err != nil { + return nil, err + } + + return &satellite.ProjectMember{ + ID: id, + MemberID: memberID, + ProjectID: projectID, + CreatedAt: projectMember.CreatedAt, + }, nil +} + +// projectMembersFromDbxSlice is used for creating []ProjectMember entities from autogenerated []*dbx.ProjectMember struct +func projectMembersFromDbxSlice(projectMembersDbx []*dbx.ProjectMember) ([]satellite.ProjectMember, error) { + var projectMembers []satellite.ProjectMember + var errors []error + + // Generating []dbo from []dbx and collecting all errors + for _, projectMemberDbx := range projectMembersDbx { + projectMember, err := projectMemberFromDBX(projectMemberDbx) + if err != nil { + errors = append(errors, err) + continue + } + + projectMembers = append(projectMembers, *projectMember) + } + + return projectMembers, utils.CombineErrors(errors...) +} diff --git a/pkg/satellite/satellitedb/projectmembers_test.go b/pkg/satellite/satellitedb/projectmembers_test.go new file mode 100644 index 000000000..f9913989c --- /dev/null +++ b/pkg/satellite/satellitedb/projectmembers_test.go @@ -0,0 +1,239 @@ +// Copyright (C) 2018 Storj Labs, Inc. +// See LICENSE for copying information. + +package satellitedb + +import ( + "context" + "testing" + + "github.com/skyrings/skyring-common/tools/uuid" + "github.com/stretchr/testify/assert" + "storj.io/storj/internal/testcontext" + "storj.io/storj/pkg/satellite" +) + +func TestProjectMembersRepository(t *testing.T) { + ctx := testcontext.New(t) + defer ctx.Cleanup() + + // creating in-memory db and opening connection + db, err := New("sqlite3", "file::memory:?mode=memory&cache=shared") + if err != nil { + t.Fatal(err) + } + defer ctx.Check(db.Close) + + // creating tables + err = db.CreateTables() + if err != nil { + t.Fatal(err) + } + + // repositories + users := db.Users() + projects := db.Projects() + projectMembers := db.ProjectMembers() + + createdUsers, createdProjects := prepareUsersAndProjects(ctx, t, users, projects) + + t.Run("Can't insert projectMember without memberID", func(t *testing.T) { + unexistingUserID, err := uuid.New() + assert.NoError(t, err) + + projMember, err := projectMembers.Insert(ctx, *unexistingUserID, createdProjects[0].ID) + assert.Nil(t, projMember) + assert.NotNil(t, err) + assert.Error(t, err) + }) + + t.Run("Can't insert projectMember without projectID", func(t *testing.T) { + unexistingProjectID, err := uuid.New() + assert.NoError(t, err) + + projMember, err := projectMembers.Insert(ctx, createdUsers[0].ID, *unexistingProjectID) + assert.Nil(t, projMember) + assert.NotNil(t, err) + assert.Error(t, err) + }) + + t.Run("Insert success", func(t *testing.T) { + projMember1, err := projectMembers.Insert(ctx, createdUsers[0].ID, createdProjects[0].ID) + assert.NotNil(t, projMember1) + assert.Nil(t, err) + assert.NoError(t, err) + + projMember2, err := projectMembers.Insert(ctx, createdUsers[1].ID, createdProjects[0].ID) + assert.NotNil(t, projMember2) + assert.Nil(t, err) + assert.NoError(t, err) + + projMember3, err := projectMembers.Insert(ctx, createdUsers[2].ID, createdProjects[1].ID) + assert.NotNil(t, projMember3) + assert.Nil(t, err) + assert.NoError(t, err) + }) + + t.Run("Get member by memberID success", func(t *testing.T) { + originalMember1 := createdUsers[0] + selectedMember1, err := projectMembers.GetByMemberID(ctx, originalMember1.ID) + + assert.NotNil(t, selectedMember1) + assert.Nil(t, err) + assert.NoError(t, err) + assert.Equal(t, originalMember1.ID, selectedMember1.MemberID) + + originalMember2 := createdUsers[1] + selectedMember2, err := projectMembers.GetByMemberID(ctx, originalMember2.ID) + + assert.NotNil(t, selectedMember2) + assert.Nil(t, err) + assert.NoError(t, err) + assert.Equal(t, originalMember2.ID, selectedMember2.MemberID) + }) + + t.Run("Get member by projectID success", func(t *testing.T) { + originalProject1 := createdProjects[0] + projectMembers1, err := projectMembers.GetByProjectID(ctx, originalProject1.ID) + + assert.NotNil(t, projectMembers1) + assert.Equal(t, 2, len(projectMembers1)) + assert.Nil(t, err) + assert.NoError(t, err) + assert.Equal(t, projectMembers1[0].MemberID, createdUsers[0].ID) + assert.Equal(t, projectMembers1[1].MemberID, createdUsers[1].ID) + + originalProject2 := createdProjects[1] + projectMembers2, err := projectMembers.GetByProjectID(ctx, originalProject2.ID) + + assert.NotNil(t, projectMembers2) + assert.Equal(t, 1, len(projectMembers2)) + assert.Nil(t, err) + assert.NoError(t, err) + assert.Equal(t, projectMembers2[0].MemberID, createdUsers[2].ID) + }) + + t.Run("Get all and get by id success", func(t *testing.T) { + allProjMembers, err := projectMembers.GetAll(ctx) + assert.NotNil(t, allProjMembers) + assert.Equal(t, 3, len(allProjMembers)) + assert.Nil(t, err) + assert.NoError(t, err) + + projMember1, err := projectMembers.Get(ctx, allProjMembers[0].ID) + assert.NotNil(t, projMember1) + assert.Nil(t, err) + assert.NoError(t, err) + + projMember2, err := projectMembers.Get(ctx, allProjMembers[1].ID) + assert.NotNil(t, projMember2) + assert.Nil(t, err) + assert.NoError(t, err) + + projMember3, err := projectMembers.Get(ctx, allProjMembers[2].ID) + assert.NotNil(t, projMember3) + assert.Nil(t, err) + assert.NoError(t, err) + }) + + t.Run("Update success", func(t *testing.T) { + // fetching member of project #2 + members, err := projectMembers.GetByProjectID(ctx, createdProjects[1].ID) + assert.NotNil(t, members) + assert.Equal(t, 1, len(members)) + assert.Nil(t, err) + assert.NoError(t, err) + + // set its proj id to proj1 id + projMemberToUpdate := members[0] + projMemberToUpdate.ProjectID = createdProjects[0].ID + + err = projectMembers.Update(ctx, &projMemberToUpdate) + assert.Nil(t, err) + assert.NoError(t, err) + + // checking that proj 2 has 0 members + members, err = projectMembers.GetByProjectID(ctx, createdProjects[1].ID) + assert.Equal(t, 0, len(members)) + assert.Nil(t, members) + assert.Nil(t, err) + assert.NoError(t, err) + + // checking that proj 1 has 3 members after update + members, err = projectMembers.GetByProjectID(ctx, createdProjects[0].ID) + assert.NotNil(t, members) + assert.Equal(t, 3, len(members)) + assert.Nil(t, err) + assert.NoError(t, err) + }) + + t.Run("Delete success", func(t *testing.T) { + members, err := projectMembers.GetByProjectID(ctx, createdProjects[0].ID) + assert.NotNil(t, members) + assert.Equal(t, 3, len(members)) + assert.Nil(t, err) + assert.NoError(t, err) + + err = projectMembers.Delete(ctx, members[2].ID) + assert.Nil(t, err) + assert.NoError(t, err) + + members, err = projectMembers.GetByProjectID(ctx, createdProjects[0].ID) + assert.NotNil(t, members) + assert.Equal(t, 2, len(members)) + assert.Nil(t, err) + assert.NoError(t, err) + }) +} + +func prepareUsersAndProjects(ctx context.Context, t *testing.T, users satellite.Users, projects satellite.Projects) ([]*satellite.User, []*satellite.Project) { + usersList := []*satellite.User{{ + Email: "email1@ukr.net", + PasswordHash: []byte("some_readable_hash"), + LastName: "LastName", + FirstName: "FirstName", + }, { + Email: "email2@ukr.net", + PasswordHash: []byte("some_readable_hash"), + LastName: "LastName", + FirstName: "FirstName", + }, { + Email: "email3@ukr.net", + PasswordHash: []byte("some_readable_hash"), + LastName: "LastName", + FirstName: "FirstName", + }, + } + + var err error + for i, user := range usersList { + usersList[i], err = users.Insert(ctx, user) + if err != nil { + t.Fatal(err) + } + } + + projectList := []*satellite.Project{ + { + Name: "projName1", + TermsAccepted: 1, + Description: "Test project 1", + OwnerID: &usersList[0].ID, + }, + { + Name: "projName2", + TermsAccepted: 1, + Description: "Test project 1", + OwnerID: &usersList[1].ID, + }, + } + + for i, project := range projectList { + projectList[i], err = projects.Insert(ctx, project) + if err != nil { + t.Fatal(err) + } + } + + return usersList, projectList +} diff --git a/pkg/satellite/satellitedb/projects.go b/pkg/satellite/satellitedb/projects.go index 175896215..a629e7779 100644 --- a/pkg/satellite/satellitedb/projects.go +++ b/pkg/satellite/satellitedb/projects.go @@ -67,7 +67,7 @@ func (projects *projects) Insert(ctx context.Context, project *satellite.Project dbx.Project_Id(projectID[:]), dbx.Project_Name(project.Name), dbx.Project_Description(project.Description), - dbx.Project_IsAgreedWithTerms(project.IsAgreedWithTerms), + dbx.Project_TermsAccepted(project.TermsAccepted), dbx.Project_Create_Fields{ OwnerId: ownerID, }) @@ -91,9 +91,9 @@ func (projects *projects) Update(ctx context.Context, project *satellite.Project _, err := projects.db.Update_Project_By_Id(ctx, dbx.Project_Id(project.ID[:]), dbx.Project_Update_Fields{ - Name: dbx.Project_Name(project.Name), - Description: dbx.Project_Description(project.Description), - IsAgreedWithTerms: dbx.Project_IsAgreedWithTerms(project.IsAgreedWithTerms), + Name: dbx.Project_Name(project.Name), + Description: dbx.Project_Description(project.Description), + TermsAccepted: dbx.Project_TermsAccepted(project.TermsAccepted), }) return err @@ -111,11 +111,11 @@ func projectFromDBX(project *dbx.Project) (*satellite.Project, error) { } u := &satellite.Project{ - ID: id, - Name: project.Name, - Description: project.Description, - IsAgreedWithTerms: project.IsAgreedWithTerms, - CreatedAt: project.CreatedAt, + ID: id, + Name: project.Name, + Description: project.Description, + TermsAccepted: project.TermsAccepted, + CreatedAt: project.CreatedAt, } if project.OwnerId == nil { diff --git a/pkg/satellite/satellitedb/projects_test.go b/pkg/satellite/satellitedb/projects_test.go index 3cf3f2d4d..860e96102 100644 --- a/pkg/satellite/satellitedb/projects_test.go +++ b/pkg/satellite/satellitedb/projects_test.go @@ -39,14 +39,14 @@ func TestProjectsRepository(t *testing.T) { // to test with real db3 file use this connection string - "../db/accountdb.db3" db, err := New("sqlite3", "file::memory:?mode=memory&cache=shared") if err != nil { - assert.NoError(t, err) + t.Fatal(err) } defer ctx.Check(db.Close) // creating tables err = db.CreateTables() if err != nil { - assert.NoError(t, err) + t.Fatal(err) } // repositories @@ -57,10 +57,10 @@ func TestProjectsRepository(t *testing.T) { t.Run("Can insert project without owner", func(t *testing.T) { project := &satellite.Project{ - OwnerID: nil, - Name: name, - Description: description, - IsAgreedWithTerms: false, + OwnerID: nil, + Name: name, + Description: description, + TermsAccepted: 1, } createdProject, err := projects.Insert(ctx, project) @@ -84,9 +84,9 @@ func TestProjectsRepository(t *testing.T) { project := &satellite.Project{ OwnerID: &owner.ID, - Name: name, - Description: description, - IsAgreedWithTerms: false, + Name: name, + Description: description, + TermsAccepted: 1, } createdProject, err := projects.Insert(ctx, project) @@ -106,7 +106,7 @@ func TestProjectsRepository(t *testing.T) { assert.Equal(t, projectsByOwnerID[0].OwnerID, &owner.ID) assert.Equal(t, projectsByOwnerID[0].Name, name) assert.Equal(t, projectsByOwnerID[0].Description, description) - assert.Equal(t, projectsByOwnerID[0].IsAgreedWithTerms, false) + assert.Equal(t, projectsByOwnerID[0].TermsAccepted, 1) projectByID, err := projects.Get(ctx, projectsByOwnerID[0].ID) @@ -117,7 +117,7 @@ func TestProjectsRepository(t *testing.T) { assert.Equal(t, projectByID.OwnerID, &owner.ID) assert.Equal(t, projectByID.Name, name) assert.Equal(t, projectByID.Description, description) - assert.Equal(t, projectByID.IsAgreedWithTerms, false) + assert.Equal(t, projectByID.TermsAccepted, 1) }) t.Run("Update project success", func(t *testing.T) { @@ -129,11 +129,11 @@ func TestProjectsRepository(t *testing.T) { // creating new project with updated values newProject := &satellite.Project{ - ID: oldProjects[0].ID, - OwnerID: &owner.ID, - Name: newName, - Description: newDescription, - IsAgreedWithTerms: true, + ID: oldProjects[0].ID, + OwnerID: &owner.ID, + Name: newName, + Description: newDescription, + TermsAccepted: 1, } err = projects.Update(ctx, newProject) @@ -150,7 +150,7 @@ func TestProjectsRepository(t *testing.T) { assert.Equal(t, newProject.OwnerID, &owner.ID) assert.Equal(t, newProject.Name, newName) assert.Equal(t, newProject.Description, newDescription) - assert.Equal(t, newProject.IsAgreedWithTerms, true) + assert.Equal(t, newProject.TermsAccepted, 1) }) t.Run("Delete project success", func(t *testing.T) { @@ -178,10 +178,10 @@ func TestProjectsRepository(t *testing.T) { assert.Equal(t, len(allProjects), 1) newProject := &satellite.Project{ - OwnerID: &owner.ID, - Description: description, - Name: name, - IsAgreedWithTerms: true, + OwnerID: &owner.ID, + Description: description, + Name: name, + TermsAccepted: 1, } _, err = projects.Insert(ctx, newProject) @@ -196,10 +196,10 @@ func TestProjectsRepository(t *testing.T) { assert.Equal(t, len(allProjects), 2) newProject2 := &satellite.Project{ - OwnerID: &owner.ID, - Description: description, - Name: name, - IsAgreedWithTerms: true, + OwnerID: &owner.ID, + Description: description, + Name: name, + TermsAccepted: 1, } _, err = projects.Insert(ctx, newProject2) diff --git a/pkg/satellite/satellitedb/users.go b/pkg/satellite/satellitedb/users.go index b1d0775d0..162245759 100644 --- a/pkg/satellite/satellitedb/users.go +++ b/pkg/satellite/satellitedb/users.go @@ -95,14 +95,12 @@ func userFromDBX(user *dbx.User) (*satellite.User, error) { return nil, err } - u := &satellite.User{ + return &satellite.User{ ID: id, FirstName: user.FirstName, LastName: user.LastName, Email: user.Email, PasswordHash: user.PasswordHash, CreatedAt: user.CreatedAt, - } - - return u, nil + }, nil } diff --git a/pkg/satellite/satellitedb/users_test.go b/pkg/satellite/satellitedb/users_test.go index 2f3157019..d63ded9dc 100644 --- a/pkg/satellite/satellitedb/users_test.go +++ b/pkg/satellite/satellitedb/users_test.go @@ -36,14 +36,14 @@ func TestUserRepository(t *testing.T) { // to test with real db3 file use this connection string - "../db/accountdb.db3" db, err := New("sqlite3", "file::memory:?mode=memory&cache=shared") if err != nil { - assert.NoError(t, err) + t.Fatal(err) } defer ctx.Check(db.Close) // creating tables err = db.CreateTables() if err != nil { - assert.NoError(t, err) + t.Fatal(err) } repository := db.Users()