diff --git a/pkg/satellite/companies.go b/pkg/satellite/companies.go new file mode 100644 index 000000000..4ef899aef --- /dev/null +++ b/pkg/satellite/companies.go @@ -0,0 +1,40 @@ +// Copyright (C) 2018 Storj Labs, Inc. +// See LICENSE for copying information. + +package satellite + +import ( + "context" + "time" + + "github.com/skyrings/skyring-common/tools/uuid" +) + +// Companies exposes methods to manage Company table in database. +type Companies interface { + // Get is a method for querying company from the database by id + Get(ctx context.Context, id uuid.UUID) (*Company, error) + // Get is a method for querying company from the database by user id + GetByUserID(ctx context.Context, userID uuid.UUID) (*Company, error) + // Insert is a method for inserting company into the database + Insert(ctx context.Context, company *Company) (*Company, error) + // Delete is a method for deleting company by Id from the database. + Delete(ctx context.Context, id uuid.UUID) error + // Update is a method for updating company entity + Update(ctx context.Context, company *Company) error +} + +// Company is a database object that describes Company entity +type Company struct { + ID uuid.UUID + UserID uuid.UUID + + Name string + Address string + Country string + City string + State string + PostalCode string + + CreatedAt time.Time +} diff --git a/pkg/satellite/database.go b/pkg/satellite/database.go index 24110429f..d50455ecc 100644 --- a/pkg/satellite/database.go +++ b/pkg/satellite/database.go @@ -7,6 +7,8 @@ package satellite type DB interface { // Users is getter for Users repository Users() Users + // Companies is getter for Companies repository + Companies() Companies // CreateTables is a method for creating all tables for satellitedb CreateTables() error diff --git a/pkg/satellite/satellitedb/companies.go b/pkg/satellite/satellitedb/companies.go new file mode 100644 index 000000000..a48c3a449 --- /dev/null +++ b/pkg/satellite/satellitedb/companies.go @@ -0,0 +1,129 @@ +// Copyright (C) 2018 Storj Labs, Inc. +// See LICENSE for copying information. + +package satellitedb + +import ( + "context" + + "github.com/zeebo/errs" + + "github.com/skyrings/skyring-common/tools/uuid" + "storj.io/storj/pkg/satellite" + "storj.io/storj/pkg/satellite/satellitedb/dbx" +) + +// implementation of Companies interface repository using spacemonkeygo/dbx orm +type companies struct { + db *dbx.DB +} + +// Get is a method for querying company from the database by id +func (companies *companies) Get(ctx context.Context, id uuid.UUID) (*satellite.Company, error) { + + company, err := companies.db.Get_Company_By_Id(ctx, dbx.Company_Id([]byte(id.String()))) + + if err != nil { + return nil, err + } + + return companyFromDBX(company) +} + +// Get is a method for querying company from the database by user id +func (companies *companies) GetByUserID(ctx context.Context, userID uuid.UUID) (*satellite.Company, error) { + + company, err := companies.db.Get_Company_By_UserId(ctx, dbx.Company_UserId([]byte(userID.String()))) + + if err != nil { + return nil, err + } + + return companyFromDBX(company) +} + +// Insert is a method for inserting company into the database +func (companies *companies) Insert(ctx context.Context, company *satellite.Company) (*satellite.Company, error) { + + companyID, err := uuid.New() + if err != nil { + return nil, err + } + + createdCompany, err := companies.db.Create_Company( + ctx, + dbx.Company_Id([]byte(companyID.String())), + dbx.Company_UserId([]byte(company.UserID.String())), + dbx.Company_Name(company.Name), + dbx.Company_Address(company.Address), + dbx.Company_Country(company.Country), + dbx.Company_City(company.City), + dbx.Company_State(company.State), + dbx.Company_PostalCode(company.PostalCode)) + + if err != nil { + return nil, err + } + + return companyFromDBX(createdCompany) +} + +// Delete is a method for deleting company by Id from the database. +func (companies *companies) Delete(ctx context.Context, id uuid.UUID) error { + _, err := companies.db.Delete_Company_By_Id(ctx, dbx.Company_Id([]byte(id.String()))) + + return err +} + +// Update is a method for updating company entity +func (companies *companies) Update(ctx context.Context, company *satellite.Company) error { + _, err := companies.db.Update_Company_By_Id( + ctx, + dbx.Company_Id([]byte(company.ID.String())), + *getCompanyUpdateFields(company)) + + return err +} + +// companyFromDBX is used for creating Company entity from autogenerated dbx.Company struct +func companyFromDBX(company *dbx.Company) (*satellite.Company, error) { + if company == nil { + return nil, errs.New("company parameter is nil") + } + + id, err := uuid.Parse(string(company.Id)) + if err != nil { + return nil, errs.New("Id in not valid UUID string") + } + + userID, err := uuid.Parse(string(company.UserId)) + if err != nil { + return nil, errs.New("UserID in not valid UUID string") + } + + comp := &satellite.Company{ + ID: *id, + UserID: *userID, + Name: company.Name, + Address: company.Address, + Country: company.Country, + City: company.City, + State: company.State, + PostalCode: company.PostalCode, + CreatedAt: company.CreatedAt, + } + + return comp, nil +} + +// getCompanyUpdateFields is used to generate company update fields +func getCompanyUpdateFields(company *satellite.Company) *dbx.Company_Update_Fields { + return &dbx.Company_Update_Fields{ + Name: dbx.Company_Name(company.Name), + Address: dbx.Company_Address(company.Address), + Country: dbx.Company_Country(company.Country), + City: dbx.Company_City(company.City), + State: dbx.Company_State(company.State), + PostalCode: dbx.Company_PostalCode(company.PostalCode), + } +} diff --git a/pkg/satellite/satellitedb/companies_test.go b/pkg/satellite/satellitedb/companies_test.go new file mode 100644 index 000000000..a4a0aa0a0 --- /dev/null +++ b/pkg/satellite/satellitedb/companies_test.go @@ -0,0 +1,238 @@ +// Copyright (C) 2018 Storj Labs, Inc. +// See LICENSE for copying information. + +package satellitedb + +import ( + "testing" + + "github.com/skyrings/skyring-common/tools/uuid" + "github.com/stretchr/testify/assert" + "storj.io/storj/internal/testcontext" + "storj.io/storj/pkg/satellite" + "storj.io/storj/pkg/satellite/satellitedb/dbx" +) + +func TestCompanyRepository(t *testing.T) { + + //testing constants + const ( + // for user + lastName = "lastName" + email = "email@ukr.net" + pass = "123456" + userName = "name" + + // for company + companyName = "Storj" + address = "somewhere" + country = "USA" + city = "Atlanta" + state = "Georgia" + postalCode = "02183" + + // updated company values + newCompanyName = "Storage" + newAddress = "where" + newCountry = "Usa" + newCity = "Otlanta" + newState = "Jeorgia" + newPostalCode = "02184" + ) + + ctx := testcontext.New(t) + defer ctx.Cleanup() + + // creating in-memory db and opening connection + // to test with real db3 file use this connection string - "../db/accountdb.db3" + db, err := New("sqlite3", "file::memory:?mode=memory&cache=shared") + if err != nil { + assert.NoError(t, err) + } + defer ctx.Check(db.Close) + + // creating tables + err = db.CreateTables() + if err != nil { + assert.NoError(t, err) + } + + // repositories + users := db.Users() + companies := db.Companies() + + var user *satellite.User + + t.Run("Can't insert company without user", func(t *testing.T) { + + company := &satellite.Company{ + Name: companyName, + Address: address, + Country: country, + City: city, + State: state, + PostalCode: postalCode, + } + + createdCompany, err := companies.Insert(ctx, company) + + assert.Nil(t, createdCompany) + assert.NotNil(t, err) + assert.Error(t, err) + }) + + t.Run("Insert company successfully", func(t *testing.T) { + + user, err = users.Insert(ctx, &satellite.User{ + FirstName: userName, + LastName: lastName, + Email: email, + PasswordHash: []byte(pass), + }) + + assert.NoError(t, err) + assert.NotNil(t, user) + + company := &satellite.Company{ + UserID: user.ID, + + Name: companyName, + Address: address, + Country: country, + City: city, + State: state, + PostalCode: postalCode, + } + + createdCompany, err := companies.Insert(ctx, company) + + assert.NotNil(t, createdCompany) + assert.Nil(t, err) + assert.NoError(t, err) + }) + + t.Run("Get company success", func(t *testing.T) { + companyByUserID, err := companies.GetByUserID(ctx, user.ID) + + assert.Nil(t, err) + assert.NoError(t, err) + + assert.Equal(t, companyByUserID.UserID, user.ID) + assert.Equal(t, companyByUserID.Name, companyName) + assert.Equal(t, companyByUserID.Address, address) + assert.Equal(t, companyByUserID.Country, country) + assert.Equal(t, companyByUserID.City, city) + assert.Equal(t, companyByUserID.State, state) + assert.Equal(t, companyByUserID.PostalCode, postalCode) + + companyByID, err := companies.Get(ctx, companyByUserID.ID) + + assert.Nil(t, err) + assert.NoError(t, err) + + assert.Equal(t, companyByID.ID, companyByUserID.ID) + assert.Equal(t, companyByID.UserID, user.ID) + assert.Equal(t, companyByID.Name, companyName) + assert.Equal(t, companyByID.Address, address) + assert.Equal(t, companyByID.Country, country) + assert.Equal(t, companyByID.City, city) + assert.Equal(t, companyByID.State, state) + assert.Equal(t, companyByID.PostalCode, postalCode) + }) + + t.Run("Update company success", func(t *testing.T) { + oldCompany, err := companies.GetByUserID(ctx, user.ID) + + assert.NoError(t, err) + assert.NotNil(t, oldCompany) + + // creating new company with updated values + newCompany := &satellite.Company{ + ID: oldCompany.ID, + UserID: user.ID, + Name: newCompanyName, + Address: newAddress, + Country: newCountry, + City: newCity, + State: newState, + PostalCode: newPostalCode, + } + + err = companies.Update(ctx, newCompany) + + assert.Nil(t, err) + assert.NoError(t, err) + + // fetching updated company from db + newCompany, err = companies.Get(ctx, oldCompany.ID) + + assert.NoError(t, err) + + assert.Equal(t, newCompany.ID, oldCompany.ID) + assert.Equal(t, newCompany.UserID, user.ID) + assert.Equal(t, newCompany.Name, newCompanyName) + assert.Equal(t, newCompany.Address, newAddress) + assert.Equal(t, newCompany.Country, newCountry) + assert.Equal(t, newCompany.City, newCity) + assert.Equal(t, newCompany.State, newState) + assert.Equal(t, newCompany.PostalCode, newPostalCode) + }) + + t.Run("Delete company success", func(t *testing.T) { + oldCompany, err := companies.GetByUserID(ctx, user.ID) + + assert.NoError(t, err) + assert.NotNil(t, oldCompany) + + err = companies.Delete(ctx, oldCompany.ID) + + assert.Nil(t, err) + assert.NoError(t, err) + + _, err = companies.Get(ctx, oldCompany.ID) + + assert.NotNil(t, err) + assert.Error(t, err) + }) +} + +func TestCompanyFromDbx(t *testing.T) { + + t.Run("can't create dbo from nil dbx model", func(t *testing.T) { + user, err := companyFromDBX(nil) + + assert.Nil(t, user) + assert.NotNil(t, err) + assert.Error(t, err) + }) + + t.Run("can't create dbo from dbx model with invalid ID", func(t *testing.T) { + dbxCompany := dbx.Company{ + Id: []byte("qweqwe"), + } + + user, err := companyFromDBX(&dbxCompany) + + assert.Nil(t, user) + assert.NotNil(t, err) + assert.Error(t, err) + }) + + t.Run("can't create dbo from dbx model with invalid UserID", func(t *testing.T) { + + companyID, err := uuid.New() + assert.NoError(t, err) + assert.Nil(t, err) + + dbxCompany := dbx.Company{ + Id: []byte(companyID.String()), + UserId: []byte("qweqwe"), + } + + user, err := companyFromDBX(&dbxCompany) + + assert.Nil(t, user) + assert.NotNil(t, err) + assert.Error(t, err) + }) +} diff --git a/pkg/satellite/satellitedb/db.go b/pkg/satellite/satellitedb/db.go index b2027ec1b..ac55a9beb 100644 --- a/pkg/satellite/satellitedb/db.go +++ b/pkg/satellite/satellitedb/db.go @@ -34,6 +34,11 @@ func (db *Database) Users() satellite.Users { return &users{db.db} } +// Companies is getter for Companies repository +func (db *Database) Companies() satellite.Companies { + return &companies{db.db} +} + // CreateTables is a method for creating all tables for satellitedb func (db *Database) CreateTables() error { _, err := db.db.Exec(db.db.Schema()) diff --git a/pkg/satellite/satellitedb/dbx/accountdb.dbx b/pkg/satellite/satellitedb/dbx/accountdb.dbx deleted file mode 100644 index 7b57b1fc5..000000000 --- a/pkg/satellite/satellitedb/dbx/accountdb.dbx +++ /dev/null @@ -1,35 +0,0 @@ -// dbx.v1 golang satellitedb.dbx . - -model user ( - key id - unique email - - field id text - field first_name text ( updatable ) - field last_name text ( updatable ) - field email text ( updatable ) - field password_hash blob ( updatable ) - - field created_at timestamp ( autoinsert ) -) - -read one ( - select user - where user.email = ? - where user.password_hash = ? -) -read one ( - select user - where user.id = ? -) -create user ( ) -update user ( where user.id = ? ) -delete user ( where user.id = ? ) - -//TODO: this entity will be used and updated in the next commit -model company ( - key id - - field id blob - field userId user.id cascade ( updatable ) -) \ No newline at end of file diff --git a/pkg/satellite/satellitedb/dbx/satellitedb.dbx b/pkg/satellite/satellitedb/dbx/satellitedb.dbx new file mode 100644 index 000000000..e11bfa3d0 --- /dev/null +++ b/pkg/satellite/satellitedb/dbx/satellitedb.dbx @@ -0,0 +1,56 @@ +// dbx.v1 golang satellitedb.dbx . + +model user ( + key id + unique email + + field id blob + field first_name text ( updatable ) + field last_name text ( updatable ) + field email text ( updatable ) + field password_hash blob ( updatable ) + + field created_at timestamp ( autoinsert ) +) + +read one ( + select user + where user.email = ? + where user.password_hash = ? +) +read one ( + select user + where user.id = ? +) +create user ( ) +update user ( where user.id = ? ) +delete user ( where user.id = ? ) + + +model company ( + key id + + field id blob + field user_id user.id cascade + + field name text ( updatable ) + field address text ( updatable ) + field country text ( updatable ) + field city text ( updatable ) + field state text ( updatable ) + field postal_code text ( updatable ) + + field created_at timestamp ( autoinsert ) +) + +read one ( + select company + where company.user_id = ? +) +read one ( + select company + where company.id = ? +) +create company ( ) +update company ( where company.id = ? ) +delete company ( where company.id = ? ) \ No newline at end of file diff --git a/pkg/satellite/satellitedb/dbx/accountdb.dbx.go b/pkg/satellite/satellitedb/dbx/satellitedb.dbx.go similarity index 63% rename from pkg/satellite/satellitedb/dbx/accountdb.dbx.go rename to pkg/satellite/satellitedb/dbx/satellitedb.dbx.go index 6a28ac403..0277a2159 100644 --- a/pkg/satellite/satellitedb/dbx/accountdb.dbx.go +++ b/pkg/satellite/satellitedb/dbx/satellitedb.dbx.go @@ -267,7 +267,7 @@ func newsqlite3(db *DB) *sqlite3DB { func (obj *sqlite3DB) Schema() string { return `CREATE TABLE users ( - id TEXT NOT NULL, + id BLOB NOT NULL, first_name TEXT NOT NULL, last_name TEXT NOT NULL, email TEXT NOT NULL, @@ -278,7 +278,14 @@ func (obj *sqlite3DB) Schema() string { ); CREATE TABLE companies ( id BLOB NOT NULL, - userId TEXT NOT NULL REFERENCES users( id ) ON DELETE CASCADE, + user_id BLOB NOT NULL REFERENCES users( id ) ON DELETE CASCADE, + name TEXT NOT NULL, + address TEXT NOT NULL, + country TEXT NOT NULL, + city TEXT NOT NULL, + state TEXT NOT NULL, + postal_code TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, PRIMARY KEY ( id ) );` } @@ -344,7 +351,7 @@ nextval: } type User struct { - Id string + Id []byte FirstName string LastName string Email string @@ -363,10 +370,10 @@ type User_Update_Fields struct { type User_Id_Field struct { _set bool - _value string + _value []byte } -func User_Id(v string) User_Id_Field { +func User_Id(v []byte) User_Id_Field { return User_Id_Field{_set: true, _value: v} } @@ -470,14 +477,26 @@ func (f User_CreatedAt_Field) value() interface{} { func (User_CreatedAt_Field) _Column() string { return "created_at" } type Company struct { - Id []byte - UserId string + Id []byte + UserId []byte + Name string + Address string + Country string + City string + State string + PostalCode string + CreatedAt time.Time } func (Company) _Table() string { return "companies" } type Company_Update_Fields struct { - UserId Company_UserId_Field + Name Company_Name_Field + Address Company_Address_Field + Country Company_Country_Field + City Company_City_Field + State Company_State_Field + PostalCode Company_PostalCode_Field } type Company_Id_Field struct { @@ -500,10 +519,10 @@ func (Company_Id_Field) _Column() string { return "id" } type Company_UserId_Field struct { _set bool - _value string + _value []byte } -func Company_UserId(v string) Company_UserId_Field { +func Company_UserId(v []byte) Company_UserId_Field { return Company_UserId_Field{_set: true, _value: v} } @@ -514,7 +533,133 @@ func (f Company_UserId_Field) value() interface{} { return f._value } -func (Company_UserId_Field) _Column() string { return "userId" } +func (Company_UserId_Field) _Column() string { return "user_id" } + +type Company_Name_Field struct { + _set bool + _value string +} + +func Company_Name(v string) Company_Name_Field { + return Company_Name_Field{_set: true, _value: v} +} + +func (f Company_Name_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (Company_Name_Field) _Column() string { return "name" } + +type Company_Address_Field struct { + _set bool + _value string +} + +func Company_Address(v string) Company_Address_Field { + return Company_Address_Field{_set: true, _value: v} +} + +func (f Company_Address_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (Company_Address_Field) _Column() string { return "address" } + +type Company_Country_Field struct { + _set bool + _value string +} + +func Company_Country(v string) Company_Country_Field { + return Company_Country_Field{_set: true, _value: v} +} + +func (f Company_Country_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (Company_Country_Field) _Column() string { return "country" } + +type Company_City_Field struct { + _set bool + _value string +} + +func Company_City(v string) Company_City_Field { + return Company_City_Field{_set: true, _value: v} +} + +func (f Company_City_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (Company_City_Field) _Column() string { return "city" } + +type Company_State_Field struct { + _set bool + _value string +} + +func Company_State(v string) Company_State_Field { + return Company_State_Field{_set: true, _value: v} +} + +func (f Company_State_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (Company_State_Field) _Column() string { return "state" } + +type Company_PostalCode_Field struct { + _set bool + _value string +} + +func Company_PostalCode(v string) Company_PostalCode_Field { + return Company_PostalCode_Field{_set: true, _value: v} +} + +func (f Company_PostalCode_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (Company_PostalCode_Field) _Column() string { return "postal_code" } + +type Company_CreatedAt_Field struct { + _set bool + _value time.Time +} + +func Company_CreatedAt(v time.Time) Company_CreatedAt_Field { + return Company_CreatedAt_Field{_set: true, _value: v} +} + +func (f Company_CreatedAt_Field) value() interface{} { + if !f._set { + return nil + } + return f._value +} + +func (Company_CreatedAt_Field) _Column() string { return "created_at" } func toUTC(t time.Time) time.Time { return t.UTC() @@ -717,6 +862,45 @@ func (obj *sqlite3Impl) Create_User(ctx context.Context, } +func (obj *sqlite3Impl) Create_Company(ctx context.Context, + company_id Company_Id_Field, + company_user_id Company_UserId_Field, + company_name Company_Name_Field, + company_address Company_Address_Field, + company_country Company_Country_Field, + company_city Company_City_Field, + company_state Company_State_Field, + company_postal_code Company_PostalCode_Field) ( + company *Company, err error) { + + __now := obj.db.Hooks.Now().UTC() + __id_val := company_id.value() + __user_id_val := company_user_id.value() + __name_val := company_name.value() + __address_val := company_address.value() + __country_val := company_country.value() + __city_val := company_city.value() + __state_val := company_state.value() + __postal_code_val := company_postal_code.value() + __created_at_val := __now + + var __embed_stmt = __sqlbundle_Literal("INSERT INTO companies ( id, user_id, name, address, country, city, state, postal_code, created_at ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ? )") + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __id_val, __user_id_val, __name_val, __address_val, __country_val, __city_val, __state_val, __postal_code_val, __created_at_val) + + __res, err := obj.driver.Exec(__stmt, __id_val, __user_id_val, __name_val, __address_val, __country_val, __city_val, __state_val, __postal_code_val, __created_at_val) + if err != nil { + return nil, obj.makeErr(err) + } + __pk, err := __res.LastInsertId() + if err != nil { + return nil, obj.makeErr(err) + } + return obj.getLastCompany(ctx, __pk) + +} + func (obj *sqlite3Impl) Get_User_By_Email_And_PasswordHash(ctx context.Context, user_email User_Email_Field, user_password_hash User_PasswordHash_Field) ( @@ -760,6 +944,70 @@ func (obj *sqlite3Impl) Get_User_By_Id(ctx context.Context, } +func (obj *sqlite3Impl) Get_Company_By_UserId(ctx context.Context, + company_user_id Company_UserId_Field) ( + company *Company, err error) { + + var __embed_stmt = __sqlbundle_Literal("SELECT companies.id, companies.user_id, companies.name, companies.address, companies.country, companies.city, companies.state, companies.postal_code, companies.created_at FROM companies WHERE companies.user_id = ? LIMIT 2") + + var __values []interface{} + __values = append(__values, company_user_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + __rows, err := obj.driver.Query(__stmt, __values...) + if err != nil { + return nil, obj.makeErr(err) + } + defer __rows.Close() + + if !__rows.Next() { + if err := __rows.Err(); err != nil { + return nil, obj.makeErr(err) + } + return nil, makeErr(sql.ErrNoRows) + } + + company = &Company{} + err = __rows.Scan(&company.Id, &company.UserId, &company.Name, &company.Address, &company.Country, &company.City, &company.State, &company.PostalCode, &company.CreatedAt) + if err != nil { + return nil, obj.makeErr(err) + } + + if __rows.Next() { + return nil, tooManyRows("Company_By_UserId") + } + + if err := __rows.Err(); err != nil { + return nil, obj.makeErr(err) + } + + return company, nil + +} + +func (obj *sqlite3Impl) Get_Company_By_Id(ctx context.Context, + company_id Company_Id_Field) ( + company *Company, err error) { + + var __embed_stmt = __sqlbundle_Literal("SELECT companies.id, companies.user_id, companies.name, companies.address, companies.country, companies.city, companies.state, companies.postal_code, companies.created_at FROM companies WHERE companies.id = ?") + + var __values []interface{} + __values = append(__values, company_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + company = &Company{} + err = obj.driver.QueryRow(__stmt, __values...).Scan(&company.Id, &company.UserId, &company.Name, &company.Address, &company.Country, &company.City, &company.State, &company.PostalCode, &company.CreatedAt) + if err != nil { + return nil, obj.makeErr(err) + } + return company, nil + +} + func (obj *sqlite3Impl) Update_User_By_Id(ctx context.Context, user_id User_Id_Field, update User_Update_Fields) ( @@ -825,6 +1073,81 @@ func (obj *sqlite3Impl) Update_User_By_Id(ctx context.Context, return user, nil } +func (obj *sqlite3Impl) Update_Company_By_Id(ctx context.Context, + company_id Company_Id_Field, + update Company_Update_Fields) ( + company *Company, err error) { + var __sets = &__sqlbundle_Hole{} + + var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("UPDATE companies SET "), __sets, __sqlbundle_Literal(" WHERE companies.id = ?")}} + + __sets_sql := __sqlbundle_Literals{Join: ", "} + var __values []interface{} + var __args []interface{} + + if update.Name._set { + __values = append(__values, update.Name.value()) + __sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("name = ?")) + } + + if update.Address._set { + __values = append(__values, update.Address.value()) + __sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("address = ?")) + } + + if update.Country._set { + __values = append(__values, update.Country.value()) + __sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("country = ?")) + } + + if update.City._set { + __values = append(__values, update.City.value()) + __sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("city = ?")) + } + + if update.State._set { + __values = append(__values, update.State.value()) + __sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("state = ?")) + } + + if update.PostalCode._set { + __values = append(__values, update.PostalCode.value()) + __sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("postal_code = ?")) + } + + if len(__sets_sql.SQLs) == 0 { + return nil, emptyUpdate() + } + + __args = append(__args, company_id.value()) + + __values = append(__values, __args...) + __sets.SQL = __sets_sql + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + company = &Company{} + _, err = obj.driver.Exec(__stmt, __values...) + if err != nil { + return nil, obj.makeErr(err) + } + + var __embed_stmt_get = __sqlbundle_Literal("SELECT companies.id, companies.user_id, companies.name, companies.address, companies.country, companies.city, companies.state, companies.postal_code, companies.created_at FROM companies WHERE companies.id = ?") + + var __stmt_get = __sqlbundle_Render(obj.dialect, __embed_stmt_get) + obj.logStmt("(IMPLIED) "+__stmt_get, __args...) + + err = obj.driver.QueryRow(__stmt_get, __args...).Scan(&company.Id, &company.UserId, &company.Name, &company.Address, &company.Country, &company.City, &company.State, &company.PostalCode, &company.CreatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, obj.makeErr(err) + } + return company, nil +} + func (obj *sqlite3Impl) Delete_User_By_Id(ctx context.Context, user_id User_Id_Field) ( deleted bool, err error) { @@ -851,6 +1174,32 @@ func (obj *sqlite3Impl) Delete_User_By_Id(ctx context.Context, } +func (obj *sqlite3Impl) Delete_Company_By_Id(ctx context.Context, + company_id Company_Id_Field) ( + deleted bool, err error) { + + var __embed_stmt = __sqlbundle_Literal("DELETE FROM companies WHERE companies.id = ?") + + var __values []interface{} + __values = append(__values, company_id.value()) + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + __res, err := obj.driver.Exec(__stmt, __values...) + if err != nil { + return false, obj.makeErr(err) + } + + __count, err := __res.RowsAffected() + if err != nil { + return false, obj.makeErr(err) + } + + return __count > 0, nil + +} + func (obj *sqlite3Impl) getLastUser(ctx context.Context, pk int64) ( user *User, err error) { @@ -869,6 +1218,24 @@ func (obj *sqlite3Impl) getLastUser(ctx context.Context, } +func (obj *sqlite3Impl) getLastCompany(ctx context.Context, + pk int64) ( + company *Company, err error) { + + var __embed_stmt = __sqlbundle_Literal("SELECT companies.id, companies.user_id, companies.name, companies.address, companies.country, companies.city, companies.state, companies.postal_code, companies.created_at FROM companies WHERE _rowid_ = ?") + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, pk) + + company = &Company{} + err = obj.driver.QueryRow(__stmt, pk).Scan(&company.Id, &company.UserId, &company.Name, &company.Address, &company.Country, &company.City, &company.State, &company.PostalCode, &company.CreatedAt) + if err != nil { + return nil, obj.makeErr(err) + } + return company, nil + +} + func (impl sqlite3Impl) isConstraintError(err error) ( constraint string, ok bool) { if e, ok := err.(sqlite3.Error); ok { @@ -954,6 +1321,24 @@ func (rx *Rx) Rollback() (err error) { return err } +func (rx *Rx) Create_Company(ctx context.Context, + company_id Company_Id_Field, + company_user_id Company_UserId_Field, + company_name Company_Name_Field, + company_address Company_Address_Field, + company_country Company_Country_Field, + company_city Company_City_Field, + company_state Company_State_Field, + company_postal_code Company_PostalCode_Field) ( + company *Company, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Create_Company(ctx, company_id, company_user_id, company_name, company_address, company_country, company_city, company_state, company_postal_code) + +} + func (rx *Rx) Create_User(ctx context.Context, user_id User_Id_Field, user_first_name User_FirstName_Field, @@ -969,6 +1354,16 @@ func (rx *Rx) Create_User(ctx context.Context, } +func (rx *Rx) Delete_Company_By_Id(ctx context.Context, + company_id Company_Id_Field) ( + deleted bool, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Delete_Company_By_Id(ctx, company_id) +} + func (rx *Rx) Delete_User_By_Id(ctx context.Context, user_id User_Id_Field) ( deleted bool, err error) { @@ -979,6 +1374,26 @@ func (rx *Rx) Delete_User_By_Id(ctx context.Context, return tx.Delete_User_By_Id(ctx, user_id) } +func (rx *Rx) Get_Company_By_Id(ctx context.Context, + company_id Company_Id_Field) ( + company *Company, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Get_Company_By_Id(ctx, company_id) +} + +func (rx *Rx) Get_Company_By_UserId(ctx context.Context, + company_user_id Company_UserId_Field) ( + company *Company, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Get_Company_By_UserId(ctx, company_user_id) +} + func (rx *Rx) Get_User_By_Email_And_PasswordHash(ctx context.Context, user_email User_Email_Field, user_password_hash User_PasswordHash_Field) ( @@ -1000,6 +1415,17 @@ func (rx *Rx) Get_User_By_Id(ctx context.Context, return tx.Get_User_By_Id(ctx, user_id) } +func (rx *Rx) Update_Company_By_Id(ctx context.Context, + company_id Company_Id_Field, + update Company_Update_Fields) ( + company *Company, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.Update_Company_By_Id(ctx, company_id, update) +} + func (rx *Rx) Update_User_By_Id(ctx context.Context, user_id User_Id_Field, update User_Update_Fields) ( @@ -1012,6 +1438,17 @@ func (rx *Rx) Update_User_By_Id(ctx context.Context, } type Methods interface { + Create_Company(ctx context.Context, + company_id Company_Id_Field, + company_user_id Company_UserId_Field, + company_name Company_Name_Field, + company_address Company_Address_Field, + company_country Company_Country_Field, + company_city Company_City_Field, + company_state Company_State_Field, + company_postal_code Company_PostalCode_Field) ( + company *Company, err error) + Create_User(ctx context.Context, user_id User_Id_Field, user_first_name User_FirstName_Field, @@ -1020,10 +1457,22 @@ type Methods interface { user_password_hash User_PasswordHash_Field) ( user *User, err error) + Delete_Company_By_Id(ctx context.Context, + company_id Company_Id_Field) ( + deleted bool, err error) + Delete_User_By_Id(ctx context.Context, user_id User_Id_Field) ( deleted bool, err error) + Get_Company_By_Id(ctx context.Context, + company_id Company_Id_Field) ( + company *Company, err error) + + Get_Company_By_UserId(ctx context.Context, + company_user_id Company_UserId_Field) ( + company *Company, err error) + Get_User_By_Email_And_PasswordHash(ctx context.Context, user_email User_Email_Field, user_password_hash User_PasswordHash_Field) ( @@ -1033,6 +1482,11 @@ type Methods interface { user_id User_Id_Field) ( user *User, err error) + Update_Company_By_Id(ctx context.Context, + company_id Company_Id_Field, + update Company_Update_Fields) ( + company *Company, err error) + Update_User_By_Id(ctx context.Context, user_id User_Id_Field, update User_Update_Fields) ( diff --git a/pkg/satellite/satellitedb/dbx/accountdb.dbx.sqlite3.sql b/pkg/satellite/satellitedb/dbx/satellitedb.dbx.sqlite3.sql similarity index 56% rename from pkg/satellite/satellitedb/dbx/accountdb.dbx.sqlite3.sql rename to pkg/satellite/satellitedb/dbx/satellitedb.dbx.sqlite3.sql index e7f386cb7..22ca20f4b 100644 --- a/pkg/satellite/satellitedb/dbx/accountdb.dbx.sqlite3.sql +++ b/pkg/satellite/satellitedb/dbx/satellitedb.dbx.sqlite3.sql @@ -1,7 +1,7 @@ -- AUTOGENERATED BY gopkg.in/spacemonkeygo/dbx.v1 -- DO NOT EDIT CREATE TABLE users ( - id TEXT NOT NULL, + id BLOB NOT NULL, first_name TEXT NOT NULL, last_name TEXT NOT NULL, email TEXT NOT NULL, @@ -12,6 +12,13 @@ CREATE TABLE users ( ); CREATE TABLE companies ( id BLOB NOT NULL, - userId TEXT NOT NULL REFERENCES users( id ) ON DELETE CASCADE, + user_id BLOB NOT NULL REFERENCES users( id ) ON DELETE CASCADE, + name TEXT NOT NULL, + address TEXT NOT NULL, + country TEXT NOT NULL, + city TEXT NOT NULL, + state TEXT NOT NULL, + postal_code TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, PRIMARY KEY ( id ) ); diff --git a/pkg/satellite/satellitedb/users.go b/pkg/satellite/satellitedb/users.go index d8de6079e..3cc6ba969 100644 --- a/pkg/satellite/satellitedb/users.go +++ b/pkg/satellite/satellitedb/users.go @@ -14,7 +14,7 @@ import ( "storj.io/storj/pkg/satellite/satellitedb/dbx" ) -// implementation of User interface repository using spacemonkeygo/dbx orm +// implementation of Users interface repository using spacemonkeygo/dbx orm type users struct { db *dbx.DB } @@ -22,7 +22,7 @@ type users struct { // Get is a method for querying user from the database by id func (users *users) Get(ctx context.Context, id uuid.UUID) (*satellite.User, error) { - userID := dbx.User_Id(id.String()) + userID := dbx.User_Id([]byte(id.String())) user, err := users.db.Get_User_By_Id(ctx, userID) @@ -49,20 +49,30 @@ func (users *users) GetByCredentials(ctx context.Context, password []byte, email } // Insert is a method for inserting user into the database -func (users *users) Insert(ctx context.Context, user *satellite.User) error { - _, err := users.db.Create_User(ctx, - dbx.User_Id(user.ID.String()), +func (users *users) Insert(ctx context.Context, user *satellite.User) (*satellite.User, error) { + + userID, err := uuid.New() + if err != nil { + return nil, err + } + + createdUser, err := users.db.Create_User(ctx, + dbx.User_Id([]byte(userID.String())), dbx.User_FirstName(user.FirstName), dbx.User_LastName(user.LastName), dbx.User_Email(user.Email), dbx.User_PasswordHash(user.PasswordHash)) - return err + if err != nil { + return nil, err + } + + return userFromDBX(createdUser) } // Delete is a method for deleting user by Id from the database. func (users *users) Delete(ctx context.Context, id uuid.UUID) error { - _, err := users.db.Delete_User_By_Id(ctx, dbx.User_Id(id.String())) + _, err := users.db.Delete_User_By_Id(ctx, dbx.User_Id([]byte(id.String()))) return err } @@ -70,7 +80,7 @@ func (users *users) Delete(ctx context.Context, id uuid.UUID) error { // Update is a method for updating user entity func (users *users) Update(ctx context.Context, user *satellite.User) error { _, err := users.db.Update_User_By_Id(ctx, - dbx.User_Id(user.ID.String()), + dbx.User_Id([]byte(user.ID.String())), dbx.User_Update_Fields{ FirstName: dbx.User_FirstName(user.FirstName), LastName: dbx.User_LastName(user.LastName), @@ -88,20 +98,19 @@ func userFromDBX(user *dbx.User) (*satellite.User, error) { return nil, errs.New("user parameter is nil") } - id, err := uuid.Parse(user.Id) - + id, err := uuid.Parse(string(user.Id)) if err != nil { return nil, errs.New("Id in not valid UUID string") } - u := &satellite.User{} - - u.ID = *id - u.FirstName = user.FirstName - u.LastName = user.LastName - u.Email = user.Email - u.PasswordHash = user.PasswordHash - u.CreatedAt = user.CreatedAt + u := &satellite.User{ + ID: *id, + FirstName: user.FirstName, + LastName: user.LastName, + Email: user.Email, + PasswordHash: user.PasswordHash, + CreatedAt: user.CreatedAt, + } return u, nil } diff --git a/pkg/satellite/satellitedb/users_test.go b/pkg/satellite/satellitedb/users_test.go index fc0e17442..eff642c3b 100644 --- a/pkg/satellite/satellitedb/users_test.go +++ b/pkg/satellite/satellitedb/users_test.go @@ -16,7 +16,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestRepository(t *testing.T) { +func TestUserRepository(t *testing.T) { //testing constants const ( @@ -33,6 +33,7 @@ func TestRepository(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() + // creating in-memory db and opens connection // to test with real db3 file use this connection string - "../db/accountdb.db3" db, err := New("sqlite3", "file::memory:?mode=memory&cache=shared") if err != nil { @@ -40,6 +41,7 @@ func TestRepository(t *testing.T) { } defer ctx.Check(db.Close) + // creating tables err = db.CreateTables() if err != nil { assert.NoError(t, err) @@ -64,7 +66,7 @@ func TestRepository(t *testing.T) { CreatedAt: time.Now(), } - err = repository.Insert(ctx, user) + _, err = repository.Insert(ctx, user) assert.Nil(t, err) assert.NoError(t, err) @@ -72,14 +74,7 @@ func TestRepository(t *testing.T) { t.Run("Can't insert user with same email twice", func(t *testing.T) { - id, err := uuid.New() - - if err != nil { - assert.NoError(t, err) - } - user := &satellite.User{ - ID: *id, FirstName: name, LastName: lastName, Email: email, @@ -87,7 +82,7 @@ func TestRepository(t *testing.T) { CreatedAt: time.Now(), } - err = repository.Insert(ctx, user) + _, err = repository.Insert(ctx, user) assert.NotNil(t, err) assert.Error(t, err) @@ -119,9 +114,7 @@ func TestRepository(t *testing.T) { t.Run("Update user success", func(t *testing.T) { oldUser, err := repository.GetByCredentials(ctx, []byte(passValid), email) - if err != nil { - assert.NoError(t, err) - } + assert.NoError(t, err) newUser := &satellite.User{ ID: oldUser.ID, @@ -129,7 +122,6 @@ func TestRepository(t *testing.T) { LastName: newLastName, Email: newEmail, PasswordHash: []byte(newPass), - CreatedAt: oldUser.CreatedAt, } err = repository.Update(ctx, newUser) @@ -139,9 +131,7 @@ func TestRepository(t *testing.T) { newUser, err = repository.Get(ctx, oldUser.ID) - if err != nil { - assert.NoError(t, err) - } + assert.NoError(t, err) assert.Equal(t, newUser.ID, oldUser.ID) assert.Equal(t, newUser.FirstName, newName) @@ -154,9 +144,7 @@ func TestRepository(t *testing.T) { t.Run("Delete user success", func(t *testing.T) { oldUser, err := repository.GetByCredentials(ctx, []byte(newPass), newEmail) - if err != nil { - assert.NoError(t, err) - } + assert.NoError(t, err) err = repository.Delete(ctx, oldUser.ID) @@ -170,7 +158,7 @@ func TestRepository(t *testing.T) { }) } -func TestUserDboFromDbx(t *testing.T) { +func TestUserFromDbx(t *testing.T) { t.Run("can't create dbo from nil dbx model", func(t *testing.T) { user, err := userFromDBX(nil) @@ -180,9 +168,9 @@ func TestUserDboFromDbx(t *testing.T) { assert.Error(t, err) }) - t.Run("can't create dbo from dbx model with invalid Id", func(t *testing.T) { + t.Run("can't create dbo from dbx model with invalid ID", func(t *testing.T) { dbxUser := dbx.User{ - Id: "qweqwe", + Id: []byte("qweqwe"), FirstName: "FirstName", LastName: "LastName", Email: "email@ukr.net", diff --git a/pkg/satellite/users.go b/pkg/satellite/users.go index 44dafbce9..26f9ebeb4 100644 --- a/pkg/satellite/users.go +++ b/pkg/satellite/users.go @@ -17,7 +17,7 @@ type Users interface { // Get is a method for querying user from the database by id Get(ctx context.Context, id uuid.UUID) (*User, error) // Insert is a method for inserting user into the database - Insert(ctx context.Context, user *User) error + Insert(ctx context.Context, user *User) (*User, error) // Delete is a method for deleting user by Id from the database. Delete(ctx context.Context, id uuid.UUID) error // Update is a method for updating user entity