Satellitedb refactoring (#647)

This commit is contained in:
Yehor Butko 2018-11-14 12:45:49 +00:00 committed by GitHub
parent c442205b3a
commit 8990fea63c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 193 additions and 186 deletions

View File

@ -14,24 +14,23 @@ import (
type Projects interface {
// GetAll is a method for querying all projects from the database.
GetAll(ctx context.Context) ([]Project, error)
// GetByUserID is a method for querying project from the database by userID.
GetByUserID(ctx context.Context, userID uuid.UUID) (*Project, error)
// GetByOwnerID is a method for querying projects from the database by ownerID.
GetByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]Project, error)
// Get is a method for querying project from the database by id.
Get(ctx context.Context, id uuid.UUID) (*Project, error)
// Insert is a method for inserting project into the database.
Insert(ctx context.Context, user *Project) (*Project, error)
Insert(ctx context.Context, project *Project) (*Project, error)
// Delete is a method for deleting project by Id from the database.
Delete(ctx context.Context, id uuid.UUID) error
// Update is a method for updating project entity.
Update(ctx context.Context, user *Project) error
Update(ctx context.Context, project *Project) error
}
// Project is a database object that describes Project entity
type Project struct {
ID uuid.UUID
// FK on Users table. ID of project creator.
// TODO: Should it be named OwnerID?
UserID uuid.UUID
OwnerID *uuid.UUID
Name string
Description string

View File

@ -20,7 +20,6 @@ type companies struct {
// Get is a method for querying company from the database by id
func (companies *companies) Get(ctx context.Context, id uuid.UUID) (*satellite.Company, error) {
company, err := companies.db.Get_Company_By_Id(ctx, dbx.Company_Id(id[:]))
if err != nil {
return nil, err
@ -31,7 +30,6 @@ func (companies *companies) Get(ctx context.Context, id uuid.UUID) (*satellite.C
// Get is a method for querying company from the database by user id
func (companies *companies) GetByUserID(ctx context.Context, userID uuid.UUID) (*satellite.Company, error) {
company, err := companies.db.Get_Company_By_UserId(ctx, dbx.Company_UserId(userID[:]))
if err != nil {
return nil, err
@ -42,7 +40,6 @@ func (companies *companies) GetByUserID(ctx context.Context, userID uuid.UUID) (
// Insert is a method for inserting company into the database
func (companies *companies) Insert(ctx context.Context, company *satellite.Company) (*satellite.Company, error) {
companyID, err := uuid.New()
if err != nil {
return nil, err

View File

@ -14,7 +14,6 @@ import (
)
func TestCompanyRepository(t *testing.T) {
//testing constants
const (
// for user
@ -64,7 +63,6 @@ func TestCompanyRepository(t *testing.T) {
var user *satellite.User
t.Run("Can't insert company without user", func(t *testing.T) {
company := &satellite.Company{
Name: companyName,
Address: address,
@ -82,7 +80,6 @@ func TestCompanyRepository(t *testing.T) {
})
t.Run("Insert company successfully", func(t *testing.T) {
user, err = users.Insert(ctx, &satellite.User{
FirstName: userName,
LastName: lastName,
@ -197,7 +194,6 @@ func TestCompanyRepository(t *testing.T) {
}
func TestCompanyFromDbx(t *testing.T) {
t.Run("can't create dbo from nil dbx model", func(t *testing.T) {
company, err := companyFromDBX(nil)

View File

@ -4,6 +4,7 @@
package satellitedb
import (
"storj.io/storj/internal/migrate"
"storj.io/storj/pkg/satellite"
"storj.io/storj/pkg/satellite/satellitedb/dbx"
@ -46,13 +47,7 @@ func (db *Database) Projects() satellite.Projects {
// CreateTables is a method for creating all tables for satellitedb
func (db *Database) CreateTables() error {
//TODO: this code will be returned in the new commit
//return migrate.Create("satellitedb", db.db)
//TODO: this code should be removed in the new commit
_, err := db.db.Exec(db.db.Schema())
return err
return migrate.Create("satellitedb", db.db)
}
// Close is used to close db connection

View File

@ -59,7 +59,7 @@ model project (
key id
field id blob
field user_id user.id cascade
field owner_id user.id setnull ( nullable, updatable )
field name text ( updatable )
field description text ( updatable )
@ -73,9 +73,9 @@ read one (
select project
where project.id = ?
)
read one (
read all (
select project
where project.user_id = ?
where project.owner_id = ?
)
create project ( )
update project ( where project.id = ? )

View File

@ -290,7 +290,7 @@ CREATE TABLE companies (
);
CREATE TABLE projects (
id BLOB NOT NULL,
user_id BLOB NOT NULL REFERENCES users( id ) ON DELETE CASCADE,
owner_id BLOB REFERENCES users( id ) ON DELETE SET NULL,
name TEXT NOT NULL,
description TEXT NOT NULL,
is_agreed_with_terms INTEGER NOT NULL,
@ -672,7 +672,7 @@ func (Company_CreatedAt_Field) _Column() string { return "created_at" }
type Project struct {
Id []byte
UserId []byte
OwnerId []byte
Name string
Description string
IsAgreedWithTerms bool
@ -681,7 +681,12 @@ type Project struct {
func (Project) _Table() string { return "projects" }
type Project_Create_Fields struct {
OwnerId Project_OwnerId_Field
}
type Project_Update_Fields struct {
OwnerId Project_OwnerId_Field
Name Project_Name_Field
Description Project_Description_Field
IsAgreedWithTerms Project_IsAgreedWithTerms_Field
@ -705,23 +710,36 @@ func (f Project_Id_Field) value() interface{} {
func (Project_Id_Field) _Column() string { return "id" }
type Project_UserId_Field struct {
type Project_OwnerId_Field struct {
_set bool
_value []byte
}
func Project_UserId(v []byte) Project_UserId_Field {
return Project_UserId_Field{_set: true, _value: v}
func Project_OwnerId(v []byte) Project_OwnerId_Field {
return Project_OwnerId_Field{_set: true, _value: v}
}
func (f Project_UserId_Field) value() interface{} {
func Project_OwnerId_Raw(v []byte) Project_OwnerId_Field {
if v == nil {
return Project_OwnerId_Null()
}
return Project_OwnerId(v)
}
func Project_OwnerId_Null() Project_OwnerId_Field {
return Project_OwnerId_Field{_set: true}
}
func (f Project_OwnerId_Field) isnull() bool { return !f._set || f._value == nil }
func (f Project_OwnerId_Field) value() interface{} {
if !f._set {
return nil
}
return f._value
}
func (Project_UserId_Field) _Column() string { return "user_id" }
func (Project_OwnerId_Field) _Column() string { return "owner_id" }
type Project_Name_Field struct {
_set bool
@ -1037,26 +1055,26 @@ func (obj *sqlite3Impl) Create_Company(ctx context.Context,
func (obj *sqlite3Impl) Create_Project(ctx context.Context,
project_id Project_Id_Field,
project_user_id Project_UserId_Field,
project_name Project_Name_Field,
project_description Project_Description_Field,
project_is_agreed_with_terms Project_IsAgreedWithTerms_Field) (
project_is_agreed_with_terms Project_IsAgreedWithTerms_Field,
optional Project_Create_Fields) (
project *Project, err error) {
__now := obj.db.Hooks.Now().UTC()
__id_val := project_id.value()
__user_id_val := project_user_id.value()
__owner_id_val := optional.OwnerId.value()
__name_val := project_name.value()
__description_val := project_description.value()
__is_agreed_with_terms_val := project_is_agreed_with_terms.value()
__created_at_val := __now
var __embed_stmt = __sqlbundle_Literal("INSERT INTO projects ( id, user_id, name, description, is_agreed_with_terms, created_at ) VALUES ( ?, ?, ?, ?, ?, ? )")
var __embed_stmt = __sqlbundle_Literal("INSERT INTO projects ( id, owner_id, name, description, is_agreed_with_terms, created_at ) VALUES ( ?, ?, ?, ?, ?, ? )")
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __id_val, __user_id_val, __name_val, __description_val, __is_agreed_with_terms_val, __created_at_val)
obj.logStmt(__stmt, __id_val, __owner_id_val, __name_val, __description_val, __is_agreed_with_terms_val, __created_at_val)
__res, err := obj.driver.Exec(__stmt, __id_val, __user_id_val, __name_val, __description_val, __is_agreed_with_terms_val, __created_at_val)
__res, err := obj.driver.Exec(__stmt, __id_val, __owner_id_val, __name_val, __description_val, __is_agreed_with_terms_val, __created_at_val)
if err != nil {
return nil, obj.makeErr(err)
}
@ -1178,7 +1196,7 @@ func (obj *sqlite3Impl) Get_Company_By_Id(ctx context.Context,
func (obj *sqlite3Impl) All_Project(ctx context.Context) (
rows []*Project, err error) {
var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.user_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects")
var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects")
var __values []interface{}
__values = append(__values)
@ -1194,7 +1212,7 @@ func (obj *sqlite3Impl) All_Project(ctx context.Context) (
for __rows.Next() {
project := &Project{}
err = __rows.Scan(&project.Id, &project.UserId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt)
err = __rows.Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt)
if err != nil {
return nil, obj.makeErr(err)
}
@ -1211,7 +1229,7 @@ func (obj *sqlite3Impl) Get_Project_By_Id(ctx context.Context,
project_id Project_Id_Field) (
project *Project, err error) {
var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.user_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE projects.id = ?")
var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE projects.id = ?")
var __values []interface{}
__values = append(__values, project_id.value())
@ -1220,7 +1238,7 @@ func (obj *sqlite3Impl) Get_Project_By_Id(ctx context.Context,
obj.logStmt(__stmt, __values...)
project = &Project{}
err = obj.driver.QueryRow(__stmt, __values...).Scan(&project.Id, &project.UserId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt)
err = obj.driver.QueryRow(__stmt, __values...).Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt)
if err != nil {
return nil, obj.makeErr(err)
}
@ -1228,14 +1246,21 @@ func (obj *sqlite3Impl) Get_Project_By_Id(ctx context.Context,
}
func (obj *sqlite3Impl) Get_Project_By_UserId(ctx context.Context,
project_user_id Project_UserId_Field) (
project *Project, err error) {
func (obj *sqlite3Impl) All_Project_By_OwnerId(ctx context.Context,
project_owner_id Project_OwnerId_Field) (
rows []*Project, err error) {
var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.user_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE projects.user_id = ? LIMIT 2")
var __cond_0 = &__sqlbundle_Condition{Left: "projects.owner_id", Equal: true, Right: "?", Null: true}
var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE "), __cond_0}}
var __values []interface{}
__values = append(__values, project_user_id.value())
__values = append(__values)
if !project_owner_id.isnull() {
__cond_0.Null = false
__values = append(__values, project_owner_id.value())
}
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
@ -1246,28 +1271,18 @@ func (obj *sqlite3Impl) Get_Project_By_UserId(ctx context.Context,
}
defer __rows.Close()
if !__rows.Next() {
if err := __rows.Err(); err != nil {
return nil, obj.makeErr(err)
}
return nil, makeErr(sql.ErrNoRows)
}
project = &Project{}
err = __rows.Scan(&project.Id, &project.UserId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt)
for __rows.Next() {
project := &Project{}
err = __rows.Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt)
if err != nil {
return nil, obj.makeErr(err)
}
if __rows.Next() {
return nil, tooManyRows("Project_By_UserId")
rows = append(rows, project)
}
if err := __rows.Err(); err != nil {
return nil, obj.makeErr(err)
}
return project, nil
return rows, nil
}
@ -1423,6 +1438,11 @@ func (obj *sqlite3Impl) Update_Project_By_Id(ctx context.Context,
var __values []interface{}
var __args []interface{}
if update.OwnerId._set {
__values = append(__values, update.OwnerId.value())
__sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("owner_id = ?"))
}
if update.Name._set {
__values = append(__values, update.Name.value())
__sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("name = ?"))
@ -1456,12 +1476,12 @@ func (obj *sqlite3Impl) Update_Project_By_Id(ctx context.Context,
return nil, obj.makeErr(err)
}
var __embed_stmt_get = __sqlbundle_Literal("SELECT projects.id, projects.user_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE projects.id = ?")
var __embed_stmt_get = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE projects.id = ?")
var __stmt_get = __sqlbundle_Render(obj.dialect, __embed_stmt_get)
obj.logStmt("(IMPLIED) "+__stmt_get, __args...)
err = obj.driver.QueryRow(__stmt_get, __args...).Scan(&project.Id, &project.UserId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt)
err = obj.driver.QueryRow(__stmt_get, __args...).Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt)
if err == sql.ErrNoRows {
return nil, nil
}
@ -1589,13 +1609,13 @@ func (obj *sqlite3Impl) getLastProject(ctx context.Context,
pk int64) (
project *Project, err error) {
var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.user_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE _rowid_ = ?")
var __embed_stmt = __sqlbundle_Literal("SELECT projects.id, projects.owner_id, projects.name, projects.description, projects.is_agreed_with_terms, projects.created_at FROM projects WHERE _rowid_ = ?")
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, pk)
project = &Project{}
err = obj.driver.QueryRow(__stmt, pk).Scan(&project.Id, &project.UserId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt)
err = obj.driver.QueryRow(__stmt, pk).Scan(&project.Id, &project.OwnerId, &project.Name, &project.Description, &project.IsAgreedWithTerms, &project.CreatedAt)
if err != nil {
return nil, obj.makeErr(err)
}
@ -1707,6 +1727,16 @@ func (rx *Rx) All_Project(ctx context.Context) (
return tx.All_Project(ctx)
}
func (rx *Rx) All_Project_By_OwnerId(ctx context.Context,
project_owner_id Project_OwnerId_Field) (
rows []*Project, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.All_Project_By_OwnerId(ctx, project_owner_id)
}
func (rx *Rx) Create_Company(ctx context.Context,
company_id Company_Id_Field,
company_user_id Company_UserId_Field,
@ -1727,16 +1757,16 @@ func (rx *Rx) Create_Company(ctx context.Context,
func (rx *Rx) Create_Project(ctx context.Context,
project_id Project_Id_Field,
project_user_id Project_UserId_Field,
project_name Project_Name_Field,
project_description Project_Description_Field,
project_is_agreed_with_terms Project_IsAgreedWithTerms_Field) (
project_is_agreed_with_terms Project_IsAgreedWithTerms_Field,
optional Project_Create_Fields) (
project *Project, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Create_Project(ctx, project_id, project_user_id, project_name, project_description, project_is_agreed_with_terms)
return tx.Create_Project(ctx, project_id, project_name, project_description, project_is_agreed_with_terms, optional)
}
@ -1815,16 +1845,6 @@ func (rx *Rx) Get_Project_By_Id(ctx context.Context,
return tx.Get_Project_By_Id(ctx, project_id)
}
func (rx *Rx) Get_Project_By_UserId(ctx context.Context,
project_user_id Project_UserId_Field) (
project *Project, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Get_Project_By_UserId(ctx, project_user_id)
}
func (rx *Rx) Get_User_By_Email_And_PasswordHash(ctx context.Context,
user_email User_Email_Field,
user_password_hash User_PasswordHash_Field) (
@ -1883,6 +1903,10 @@ type Methods interface {
All_Project(ctx context.Context) (
rows []*Project, err error)
All_Project_By_OwnerId(ctx context.Context,
project_owner_id Project_OwnerId_Field) (
rows []*Project, err error)
Create_Company(ctx context.Context,
company_id Company_Id_Field,
company_user_id Company_UserId_Field,
@ -1896,10 +1920,10 @@ type Methods interface {
Create_Project(ctx context.Context,
project_id Project_Id_Field,
project_user_id Project_UserId_Field,
project_name Project_Name_Field,
project_description Project_Description_Field,
project_is_agreed_with_terms Project_IsAgreedWithTerms_Field) (
project_is_agreed_with_terms Project_IsAgreedWithTerms_Field,
optional Project_Create_Fields) (
project *Project, err error)
Create_User(ctx context.Context,
@ -1934,10 +1958,6 @@ type Methods interface {
project_id Project_Id_Field) (
project *Project, err error)
Get_Project_By_UserId(ctx context.Context,
project_user_id Project_UserId_Field) (
project *Project, err error)
Get_User_By_Email_And_PasswordHash(ctx context.Context,
user_email User_Email_Field,
user_password_hash User_PasswordHash_Field) (

View File

@ -24,7 +24,7 @@ CREATE TABLE companies (
);
CREATE TABLE projects (
id BLOB NOT NULL,
user_id BLOB NOT NULL REFERENCES users( id ) ON DELETE CASCADE,
owner_id BLOB REFERENCES users( id ) ON DELETE SET NULL,
name TEXT NOT NULL,
description TEXT NOT NULL,
is_agreed_with_terms INTEGER NOT NULL,

View File

@ -20,44 +20,26 @@ type projects struct {
// GetAll is a method for querying all projects from the database.
func (projects *projects) GetAll(ctx context.Context) ([]satellite.Project, error) {
projectsDbxSlice, err := projects.db.All_Project(ctx)
projectsDbx, err := projects.db.All_Project(ctx)
if err != nil {
return nil, err
}
projectsCount := len(projectsDbxSlice)
var projectDboSlice []satellite.Project
var errors []error
// Generating []dbo from []dbx and collecting all errors
for i := 0; i < projectsCount; i++ {
projectDbo, err := projectFromDBX(projectsDbxSlice[i])
if err != nil {
errors = append(errors, err)
}
projectDboSlice = append(projectDboSlice, *projectDbo)
}
return projectDboSlice, utils.CombineErrors(errors...)
return projectsFromDbxSlice(projectsDbx)
}
// GetByUserID is a method for querying project from the database by user id
func (projects *projects) GetByUserID(ctx context.Context, userID uuid.UUID) (*satellite.Project, error) {
project, err := projects.db.Get_Project_By_UserId(ctx, dbx.Project_UserId(userID[:]))
// GetByOwnerID is a method for querying projects from the database by ownerID.
func (projects *projects) GetByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]satellite.Project, error) {
projectsDbx, err := projects.db.All_Project_By_OwnerId(ctx, dbx.Project_OwnerId(ownerID[:]))
if err != nil {
return nil, err
}
return projectFromDBX(project)
return projectsFromDbxSlice(projectsDbx)
}
// Get is a method for querying project from the database by id.
func (projects *projects) Get(ctx context.Context, id uuid.UUID) (*satellite.Project, error) {
project, err := projects.db.Get_Project_By_Id(ctx, dbx.Project_Id(id[:]))
if err != nil {
return nil, err
@ -68,18 +50,27 @@ func (projects *projects) Get(ctx context.Context, id uuid.UUID) (*satellite.Pro
// Insert is a method for inserting project into the database.
func (projects *projects) Insert(ctx context.Context, project *satellite.Project) (*satellite.Project, error) {
projectID, err := uuid.New()
if err != nil {
return nil, err
}
var ownerID dbx.Project_OwnerId_Field
if project.OwnerID != nil {
ownerID = dbx.Project_OwnerId(project.OwnerID[:])
} else {
ownerID = dbx.Project_OwnerId(nil)
}
createdProject, err := projects.db.Create_Project(ctx,
dbx.Project_Id(projectID[:]),
dbx.Project_UserId(project.UserID[:]),
dbx.Project_Name(project.Name),
dbx.Project_Description(project.Description),
dbx.Project_IsAgreedWithTerms(project.IsAgreedWithTerms))
dbx.Project_IsAgreedWithTerms(project.IsAgreedWithTerms),
dbx.Project_Create_Fields{
OwnerId: ownerID,
})
if err != nil {
return nil, err
@ -97,7 +88,6 @@ func (projects *projects) Delete(ctx context.Context, id uuid.UUID) error {
// Update is a method for updating user entity
func (projects *projects) Update(ctx context.Context, project *satellite.Project) error {
_, err := projects.db.Update_Project_By_Id(ctx,
dbx.Project_Id(project.ID[:]),
dbx.Project_Update_Fields{
@ -120,19 +110,43 @@ func projectFromDBX(project *dbx.Project) (*satellite.Project, error) {
return nil, err
}
userID, err := bytesToUUID(project.UserId)
if err != nil {
return nil, err
}
u := &satellite.Project{
ID: id,
UserID: userID,
Name: project.Name,
Description: project.Description,
IsAgreedWithTerms: project.IsAgreedWithTerms,
CreatedAt: project.CreatedAt,
}
if project.OwnerId == nil {
u.OwnerID = nil
} else {
ownerID, err := bytesToUUID(project.OwnerId)
if err != nil {
return nil, err
}
u.OwnerID = &ownerID
}
return u, nil
}
// projectsFromDbxSlice is used for creating []Project entities from autogenerated []*dbx.Project struct
func projectsFromDbxSlice(projectsDbx []*dbx.Project) ([]satellite.Project, error) {
var projects []satellite.Project
var errors []error
// Generating []dbo from []dbx and collecting all errors
for _, projectDbx := range projectsDbx {
project, err := projectFromDBX(projectDbx)
if err != nil {
errors = append(errors, err)
continue
}
projects = append(projects, *project)
}
return projects, utils.CombineErrors(errors...)
}

View File

@ -15,7 +15,6 @@ import (
)
func TestProjectsRepository(t *testing.T) {
//testing constants
const (
// for user
@ -25,11 +24,11 @@ func TestProjectsRepository(t *testing.T) {
userName = "name"
// for project
name = "Storj"
name = "Project"
description = "some description"
// updated project values
newName = "Dropbox"
newName = "NewProject"
newDescription = "some new description"
)
@ -54,26 +53,25 @@ func TestProjectsRepository(t *testing.T) {
users := db.Users()
projects := db.Projects()
var user *satellite.User
t.Run("Can't insert project without user", func(t *testing.T) {
var owner *satellite.User
t.Run("Can insert project without owner", func(t *testing.T) {
project := &satellite.Project{
OwnerID: nil,
Name: name,
Description: description,
IsAgreedWithTerms: false,
}
createdCompany, err := projects.Insert(ctx, project)
createdProject, err := projects.Insert(ctx, project)
assert.Nil(t, createdCompany)
assert.NotNil(t, err)
assert.Error(t, err)
assert.NotNil(t, createdProject)
assert.Nil(t, err)
assert.NoError(t, err)
})
t.Run("Insert project successfully", func(t *testing.T) {
user, err = users.Insert(ctx, &satellite.User{
owner, err = users.Insert(ctx, &satellite.User{
FirstName: userName,
LastName: lastName,
Email: email,
@ -81,10 +79,10 @@ func TestProjectsRepository(t *testing.T) {
})
assert.NoError(t, err)
assert.NotNil(t, user)
assert.NotNil(t, owner)
project := &satellite.Project{
UserID: user.ID,
OwnerID: &owner.ID,
Name: name,
Description: description,
@ -99,38 +97,40 @@ func TestProjectsRepository(t *testing.T) {
})
t.Run("Get project success", func(t *testing.T) {
projectByUserID, err := projects.GetByUserID(ctx, user.ID)
projectsByOwnerID, err := projects.GetByOwnerID(ctx, owner.ID)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(projectsByOwnerID), 1)
assert.Equal(t, projectsByOwnerID[0].OwnerID, &owner.ID)
assert.Equal(t, projectsByOwnerID[0].Name, name)
assert.Equal(t, projectsByOwnerID[0].Description, description)
assert.Equal(t, projectsByOwnerID[0].IsAgreedWithTerms, false)
projectByID, err := projects.Get(ctx, projectsByOwnerID[0].ID)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, projectByUserID.UserID, user.ID)
assert.Equal(t, projectByUserID.Name, name)
assert.Equal(t, projectByUserID.Description, description)
assert.Equal(t, projectByUserID.IsAgreedWithTerms, false)
projectByID, err := projects.Get(ctx, projectByUserID.ID)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, projectByID.ID, projectByUserID.ID)
assert.Equal(t, projectByID.UserID, user.ID)
assert.Equal(t, projectByID.ID, projectsByOwnerID[0].ID)
assert.Equal(t, projectByID.OwnerID, &owner.ID)
assert.Equal(t, projectByID.Name, name)
assert.Equal(t, projectByID.Description, description)
assert.Equal(t, projectByID.IsAgreedWithTerms, false)
})
t.Run("Update project success", func(t *testing.T) {
oldProject, err := projects.GetByUserID(ctx, user.ID)
oldProjects, err := projects.GetByOwnerID(ctx, owner.ID)
assert.NoError(t, err)
assert.NotNil(t, oldProject)
assert.NotNil(t, oldProjects)
assert.Equal(t, len(oldProjects), 1)
// creating new company with updated values
// creating new project with updated values
newProject := &satellite.Project{
ID: oldProject.ID,
UserID: user.ID,
ID: oldProjects[0].ID,
OwnerID: &owner.ID,
Name: newName,
Description: newDescription,
IsAgreedWithTerms: true,
@ -142,29 +142,29 @@ func TestProjectsRepository(t *testing.T) {
assert.NoError(t, err)
// fetching updated project from db
newProject, err = projects.Get(ctx, oldProject.ID)
newProject, err = projects.Get(ctx, oldProjects[0].ID)
assert.NoError(t, err)
assert.Equal(t, newProject.ID, oldProject.ID)
assert.Equal(t, newProject.UserID, user.ID)
assert.Equal(t, newProject.ID, oldProjects[0].ID)
assert.Equal(t, newProject.OwnerID, &owner.ID)
assert.Equal(t, newProject.Name, newName)
assert.Equal(t, newProject.Description, newDescription)
assert.Equal(t, newProject.IsAgreedWithTerms, true)
})
t.Run("Delete project success", func(t *testing.T) {
oldProject, err := projects.GetByUserID(ctx, user.ID)
oldProjects, err := projects.GetByOwnerID(ctx, owner.ID)
assert.NoError(t, err)
assert.NotNil(t, oldProject)
assert.NotNil(t, oldProjects)
err = projects.Delete(ctx, oldProject.ID)
err = projects.Delete(ctx, oldProjects[0].ID)
assert.Nil(t, err)
assert.NoError(t, err)
_, err = projects.Get(ctx, oldProject.ID)
_, err = projects.Get(ctx, oldProjects[0].ID)
assert.NotNil(t, err)
assert.Error(t, err)
@ -175,10 +175,10 @@ func TestProjectsRepository(t *testing.T) {
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(allProjects), 0)
assert.Equal(t, len(allProjects), 1)
newProject := &satellite.Project{
UserID: user.ID,
OwnerID: &owner.ID,
Description: description,
Name: name,
IsAgreedWithTerms: true,
@ -193,10 +193,10 @@ func TestProjectsRepository(t *testing.T) {
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(allProjects), 1)
assert.Equal(t, len(allProjects), 2)
newProject2 := &satellite.Project{
UserID: user.ID,
OwnerID: &owner.ID,
Description: description,
Name: name,
IsAgreedWithTerms: true,
@ -211,16 +211,15 @@ func TestProjectsRepository(t *testing.T) {
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(allProjects), 2)
assert.Equal(t, len(allProjects), 3)
})
}
func TestProjectFromDbx(t *testing.T) {
t.Run("can't create dbo from nil dbx model", func(t *testing.T) {
user, err := projectFromDBX(nil)
project, err := projectFromDBX(nil)
assert.Nil(t, user)
assert.Nil(t, project)
assert.NotNil(t, err)
assert.Error(t, err)
})
@ -237,15 +236,14 @@ func TestProjectFromDbx(t *testing.T) {
assert.Error(t, err)
})
t.Run("can't create dbo from dbx model with invalid UserID", func(t *testing.T) {
t.Run("can't create dbo from dbx model with invalid OwnerID", func(t *testing.T) {
projectID, err := uuid.New()
assert.NoError(t, err)
assert.Nil(t, err)
dbxProject := dbx.Project{
Id: projectID[:],
UserId: []byte("qweqwe"),
OwnerId: []byte("qweqwe"),
}
project, err := projectFromDBX(&dbxProject)

View File

@ -20,7 +20,6 @@ type users struct {
// Get is a method for querying user from the database by id
func (users *users) Get(ctx context.Context, id uuid.UUID) (*satellite.User, error) {
user, err := users.db.Get_User_By_Id(ctx, dbx.User_Id(id[:]))
if err != nil {
return nil, err
@ -31,7 +30,6 @@ func (users *users) Get(ctx context.Context, id uuid.UUID) (*satellite.User, err
// GetByCredentials is a method for querying user by credentials from the database.
func (users *users) GetByCredentials(ctx context.Context, password []byte, email string) (*satellite.User, error) {
userEmail := dbx.User_Email(email)
userPassword := dbx.User_PasswordHash(password)
@ -46,7 +44,6 @@ func (users *users) GetByCredentials(ctx context.Context, password []byte, email
// Insert is a method for inserting user into the database
func (users *users) Insert(ctx context.Context, user *satellite.User) (*satellite.User, error) {
userID, err := uuid.New()
if err != nil {
return nil, err
@ -75,7 +72,6 @@ func (users *users) Delete(ctx context.Context, id uuid.UUID) error {
// Update is a method for updating user entity
func (users *users) Update(ctx context.Context, user *satellite.User) error {
_, err := users.db.Update_User_By_Id(ctx,
dbx.User_Id(user.ID[:]),
dbx.User_Update_Fields{
@ -89,7 +85,6 @@ func (users *users) Update(ctx context.Context, user *satellite.User) error {
}
// userFromDBX is used for creating User entity from autogenerated dbx.User struct
// TODO: move error strings to better place
func userFromDBX(user *dbx.User) (*satellite.User, error) {
if user == nil {
return nil, errs.New("user parameter is nil")

View File

@ -17,7 +17,6 @@ import (
)
func TestUserRepository(t *testing.T) {
//testing constants
const (
lastName = "lastName"
@ -50,7 +49,6 @@ func TestUserRepository(t *testing.T) {
repository := db.Users()
t.Run("User insertion success", func(t *testing.T) {
id, err := uuid.New()
if err != nil {
@ -73,7 +71,6 @@ func TestUserRepository(t *testing.T) {
})
t.Run("Can't insert user with same email twice", func(t *testing.T) {
user := &satellite.User{
FirstName: name,
LastName: lastName,
@ -96,7 +93,7 @@ func TestUserRepository(t *testing.T) {
assert.Nil(t, err)
assert.NoError(t, err)
userByID, err := repository.GetByCredentials(ctx, []byte(passValid), email)
userByID, err := repository.Get(ctx, userByCreds.ID)
assert.Equal(t, userByID.FirstName, name)
assert.Equal(t, userByID.LastName, lastName)
@ -159,7 +156,6 @@ func TestUserRepository(t *testing.T) {
}
func TestUserFromDbx(t *testing.T) {
t.Run("can't create dbo from nil dbx model", func(t *testing.T) {
user, err := userFromDBX(nil)

View File

@ -11,9 +11,7 @@ import (
)
func TestBytesToUUID(t *testing.T) {
t.Run("Invalid input", func(t *testing.T) {
str := "not UUID string"
bytes := []byte(str)
@ -24,7 +22,6 @@ func TestBytesToUUID(t *testing.T) {
})
t.Run("Valid input", func(t *testing.T) {
id, err := uuid.New()
assert.NoError(t, err)