V3-665 Creating Companies repository (#606)

This commit is contained in:
Yehor Butko 2018-11-09 14:05:24 +02:00 committed by GitHub
parent f11f4653e4
commit 7dcbba2541
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 983 additions and 90 deletions

View File

@ -0,0 +1,40 @@
// Copyright (C) 2018 Storj Labs, Inc.
// See LICENSE for copying information.
package satellite
import (
"context"
"time"
"github.com/skyrings/skyring-common/tools/uuid"
)
// Companies exposes methods to manage Company table in database.
type Companies interface {
// Get is a method for querying company from the database by id
Get(ctx context.Context, id uuid.UUID) (*Company, error)
// Get is a method for querying company from the database by user id
GetByUserID(ctx context.Context, userID uuid.UUID) (*Company, error)
// Insert is a method for inserting company into the database
Insert(ctx context.Context, company *Company) (*Company, error)
// Delete is a method for deleting company by Id from the database.
Delete(ctx context.Context, id uuid.UUID) error
// Update is a method for updating company entity
Update(ctx context.Context, company *Company) error
}
// Company is a database object that describes Company entity
type Company struct {
ID uuid.UUID
UserID uuid.UUID
Name string
Address string
Country string
City string
State string
PostalCode string
CreatedAt time.Time
}

View File

@ -7,6 +7,8 @@ package satellite
type DB interface {
// Users is getter for Users repository
Users() Users
// Companies is getter for Companies repository
Companies() Companies
// CreateTables is a method for creating all tables for satellitedb
CreateTables() error

View File

@ -0,0 +1,129 @@
// Copyright (C) 2018 Storj Labs, Inc.
// See LICENSE for copying information.
package satellitedb
import (
"context"
"github.com/zeebo/errs"
"github.com/skyrings/skyring-common/tools/uuid"
"storj.io/storj/pkg/satellite"
"storj.io/storj/pkg/satellite/satellitedb/dbx"
)
// implementation of Companies interface repository using spacemonkeygo/dbx orm
type companies struct {
db *dbx.DB
}
// Get is a method for querying company from the database by id
func (companies *companies) Get(ctx context.Context, id uuid.UUID) (*satellite.Company, error) {
company, err := companies.db.Get_Company_By_Id(ctx, dbx.Company_Id([]byte(id.String())))
if err != nil {
return nil, err
}
return companyFromDBX(company)
}
// Get is a method for querying company from the database by user id
func (companies *companies) GetByUserID(ctx context.Context, userID uuid.UUID) (*satellite.Company, error) {
company, err := companies.db.Get_Company_By_UserId(ctx, dbx.Company_UserId([]byte(userID.String())))
if err != nil {
return nil, err
}
return companyFromDBX(company)
}
// Insert is a method for inserting company into the database
func (companies *companies) Insert(ctx context.Context, company *satellite.Company) (*satellite.Company, error) {
companyID, err := uuid.New()
if err != nil {
return nil, err
}
createdCompany, err := companies.db.Create_Company(
ctx,
dbx.Company_Id([]byte(companyID.String())),
dbx.Company_UserId([]byte(company.UserID.String())),
dbx.Company_Name(company.Name),
dbx.Company_Address(company.Address),
dbx.Company_Country(company.Country),
dbx.Company_City(company.City),
dbx.Company_State(company.State),
dbx.Company_PostalCode(company.PostalCode))
if err != nil {
return nil, err
}
return companyFromDBX(createdCompany)
}
// Delete is a method for deleting company by Id from the database.
func (companies *companies) Delete(ctx context.Context, id uuid.UUID) error {
_, err := companies.db.Delete_Company_By_Id(ctx, dbx.Company_Id([]byte(id.String())))
return err
}
// Update is a method for updating company entity
func (companies *companies) Update(ctx context.Context, company *satellite.Company) error {
_, err := companies.db.Update_Company_By_Id(
ctx,
dbx.Company_Id([]byte(company.ID.String())),
*getCompanyUpdateFields(company))
return err
}
// companyFromDBX is used for creating Company entity from autogenerated dbx.Company struct
func companyFromDBX(company *dbx.Company) (*satellite.Company, error) {
if company == nil {
return nil, errs.New("company parameter is nil")
}
id, err := uuid.Parse(string(company.Id))
if err != nil {
return nil, errs.New("Id in not valid UUID string")
}
userID, err := uuid.Parse(string(company.UserId))
if err != nil {
return nil, errs.New("UserID in not valid UUID string")
}
comp := &satellite.Company{
ID: *id,
UserID: *userID,
Name: company.Name,
Address: company.Address,
Country: company.Country,
City: company.City,
State: company.State,
PostalCode: company.PostalCode,
CreatedAt: company.CreatedAt,
}
return comp, nil
}
// getCompanyUpdateFields is used to generate company update fields
func getCompanyUpdateFields(company *satellite.Company) *dbx.Company_Update_Fields {
return &dbx.Company_Update_Fields{
Name: dbx.Company_Name(company.Name),
Address: dbx.Company_Address(company.Address),
Country: dbx.Company_Country(company.Country),
City: dbx.Company_City(company.City),
State: dbx.Company_State(company.State),
PostalCode: dbx.Company_PostalCode(company.PostalCode),
}
}

View File

@ -0,0 +1,238 @@
// Copyright (C) 2018 Storj Labs, Inc.
// See LICENSE for copying information.
package satellitedb
import (
"testing"
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/stretchr/testify/assert"
"storj.io/storj/internal/testcontext"
"storj.io/storj/pkg/satellite"
"storj.io/storj/pkg/satellite/satellitedb/dbx"
)
func TestCompanyRepository(t *testing.T) {
//testing constants
const (
// for user
lastName = "lastName"
email = "email@ukr.net"
pass = "123456"
userName = "name"
// for company
companyName = "Storj"
address = "somewhere"
country = "USA"
city = "Atlanta"
state = "Georgia"
postalCode = "02183"
// updated company values
newCompanyName = "Storage"
newAddress = "where"
newCountry = "Usa"
newCity = "Otlanta"
newState = "Jeorgia"
newPostalCode = "02184"
)
ctx := testcontext.New(t)
defer ctx.Cleanup()
// creating in-memory db and opening connection
// to test with real db3 file use this connection string - "../db/accountdb.db3"
db, err := New("sqlite3", "file::memory:?mode=memory&cache=shared")
if err != nil {
assert.NoError(t, err)
}
defer ctx.Check(db.Close)
// creating tables
err = db.CreateTables()
if err != nil {
assert.NoError(t, err)
}
// repositories
users := db.Users()
companies := db.Companies()
var user *satellite.User
t.Run("Can't insert company without user", func(t *testing.T) {
company := &satellite.Company{
Name: companyName,
Address: address,
Country: country,
City: city,
State: state,
PostalCode: postalCode,
}
createdCompany, err := companies.Insert(ctx, company)
assert.Nil(t, createdCompany)
assert.NotNil(t, err)
assert.Error(t, err)
})
t.Run("Insert company successfully", func(t *testing.T) {
user, err = users.Insert(ctx, &satellite.User{
FirstName: userName,
LastName: lastName,
Email: email,
PasswordHash: []byte(pass),
})
assert.NoError(t, err)
assert.NotNil(t, user)
company := &satellite.Company{
UserID: user.ID,
Name: companyName,
Address: address,
Country: country,
City: city,
State: state,
PostalCode: postalCode,
}
createdCompany, err := companies.Insert(ctx, company)
assert.NotNil(t, createdCompany)
assert.Nil(t, err)
assert.NoError(t, err)
})
t.Run("Get company success", func(t *testing.T) {
companyByUserID, err := companies.GetByUserID(ctx, user.ID)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, companyByUserID.UserID, user.ID)
assert.Equal(t, companyByUserID.Name, companyName)
assert.Equal(t, companyByUserID.Address, address)
assert.Equal(t, companyByUserID.Country, country)
assert.Equal(t, companyByUserID.City, city)
assert.Equal(t, companyByUserID.State, state)
assert.Equal(t, companyByUserID.PostalCode, postalCode)
companyByID, err := companies.Get(ctx, companyByUserID.ID)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, companyByID.ID, companyByUserID.ID)
assert.Equal(t, companyByID.UserID, user.ID)
assert.Equal(t, companyByID.Name, companyName)
assert.Equal(t, companyByID.Address, address)
assert.Equal(t, companyByID.Country, country)
assert.Equal(t, companyByID.City, city)
assert.Equal(t, companyByID.State, state)
assert.Equal(t, companyByID.PostalCode, postalCode)
})
t.Run("Update company success", func(t *testing.T) {
oldCompany, err := companies.GetByUserID(ctx, user.ID)
assert.NoError(t, err)
assert.NotNil(t, oldCompany)
// creating new company with updated values
newCompany := &satellite.Company{
ID: oldCompany.ID,
UserID: user.ID,
Name: newCompanyName,
Address: newAddress,
Country: newCountry,
City: newCity,
State: newState,
PostalCode: newPostalCode,
}
err = companies.Update(ctx, newCompany)
assert.Nil(t, err)
assert.NoError(t, err)
// fetching updated company from db
newCompany, err = companies.Get(ctx, oldCompany.ID)
assert.NoError(t, err)
assert.Equal(t, newCompany.ID, oldCompany.ID)
assert.Equal(t, newCompany.UserID, user.ID)
assert.Equal(t, newCompany.Name, newCompanyName)
assert.Equal(t, newCompany.Address, newAddress)
assert.Equal(t, newCompany.Country, newCountry)
assert.Equal(t, newCompany.City, newCity)
assert.Equal(t, newCompany.State, newState)
assert.Equal(t, newCompany.PostalCode, newPostalCode)
})
t.Run("Delete company success", func(t *testing.T) {
oldCompany, err := companies.GetByUserID(ctx, user.ID)
assert.NoError(t, err)
assert.NotNil(t, oldCompany)
err = companies.Delete(ctx, oldCompany.ID)
assert.Nil(t, err)
assert.NoError(t, err)
_, err = companies.Get(ctx, oldCompany.ID)
assert.NotNil(t, err)
assert.Error(t, err)
})
}
func TestCompanyFromDbx(t *testing.T) {
t.Run("can't create dbo from nil dbx model", func(t *testing.T) {
user, err := companyFromDBX(nil)
assert.Nil(t, user)
assert.NotNil(t, err)
assert.Error(t, err)
})
t.Run("can't create dbo from dbx model with invalid ID", func(t *testing.T) {
dbxCompany := dbx.Company{
Id: []byte("qweqwe"),
}
user, err := companyFromDBX(&dbxCompany)
assert.Nil(t, user)
assert.NotNil(t, err)
assert.Error(t, err)
})
t.Run("can't create dbo from dbx model with invalid UserID", func(t *testing.T) {
companyID, err := uuid.New()
assert.NoError(t, err)
assert.Nil(t, err)
dbxCompany := dbx.Company{
Id: []byte(companyID.String()),
UserId: []byte("qweqwe"),
}
user, err := companyFromDBX(&dbxCompany)
assert.Nil(t, user)
assert.NotNil(t, err)
assert.Error(t, err)
})
}

View File

@ -34,6 +34,11 @@ func (db *Database) Users() satellite.Users {
return &users{db.db}
}
// Companies is getter for Companies repository
func (db *Database) Companies() satellite.Companies {
return &companies{db.db}
}
// CreateTables is a method for creating all tables for satellitedb
func (db *Database) CreateTables() error {
_, err := db.db.Exec(db.db.Schema())

View File

@ -1,35 +0,0 @@
// dbx.v1 golang satellitedb.dbx .
model user (
key id
unique email
field id text
field first_name text ( updatable )
field last_name text ( updatable )
field email text ( updatable )
field password_hash blob ( updatable )
field created_at timestamp ( autoinsert )
)
read one (
select user
where user.email = ?
where user.password_hash = ?
)
read one (
select user
where user.id = ?
)
create user ( )
update user ( where user.id = ? )
delete user ( where user.id = ? )
//TODO: this entity will be used and updated in the next commit
model company (
key id
field id blob
field userId user.id cascade ( updatable )
)

View File

@ -0,0 +1,56 @@
// dbx.v1 golang satellitedb.dbx .
model user (
key id
unique email
field id blob
field first_name text ( updatable )
field last_name text ( updatable )
field email text ( updatable )
field password_hash blob ( updatable )
field created_at timestamp ( autoinsert )
)
read one (
select user
where user.email = ?
where user.password_hash = ?
)
read one (
select user
where user.id = ?
)
create user ( )
update user ( where user.id = ? )
delete user ( where user.id = ? )
model company (
key id
field id blob
field user_id user.id cascade
field name text ( updatable )
field address text ( updatable )
field country text ( updatable )
field city text ( updatable )
field state text ( updatable )
field postal_code text ( updatable )
field created_at timestamp ( autoinsert )
)
read one (
select company
where company.user_id = ?
)
read one (
select company
where company.id = ?
)
create company ( )
update company ( where company.id = ? )
delete company ( where company.id = ? )

View File

@ -267,7 +267,7 @@ func newsqlite3(db *DB) *sqlite3DB {
func (obj *sqlite3DB) Schema() string {
return `CREATE TABLE users (
id TEXT NOT NULL,
id BLOB NOT NULL,
first_name TEXT NOT NULL,
last_name TEXT NOT NULL,
email TEXT NOT NULL,
@ -278,7 +278,14 @@ func (obj *sqlite3DB) Schema() string {
);
CREATE TABLE companies (
id BLOB NOT NULL,
userId TEXT NOT NULL REFERENCES users( id ) ON DELETE CASCADE,
user_id BLOB NOT NULL REFERENCES users( id ) ON DELETE CASCADE,
name TEXT NOT NULL,
address TEXT NOT NULL,
country TEXT NOT NULL,
city TEXT NOT NULL,
state TEXT NOT NULL,
postal_code TEXT NOT NULL,
created_at TIMESTAMP NOT NULL,
PRIMARY KEY ( id )
);`
}
@ -344,7 +351,7 @@ nextval:
}
type User struct {
Id string
Id []byte
FirstName string
LastName string
Email string
@ -363,10 +370,10 @@ type User_Update_Fields struct {
type User_Id_Field struct {
_set bool
_value string
_value []byte
}
func User_Id(v string) User_Id_Field {
func User_Id(v []byte) User_Id_Field {
return User_Id_Field{_set: true, _value: v}
}
@ -470,14 +477,26 @@ func (f User_CreatedAt_Field) value() interface{} {
func (User_CreatedAt_Field) _Column() string { return "created_at" }
type Company struct {
Id []byte
UserId string
Id []byte
UserId []byte
Name string
Address string
Country string
City string
State string
PostalCode string
CreatedAt time.Time
}
func (Company) _Table() string { return "companies" }
type Company_Update_Fields struct {
UserId Company_UserId_Field
Name Company_Name_Field
Address Company_Address_Field
Country Company_Country_Field
City Company_City_Field
State Company_State_Field
PostalCode Company_PostalCode_Field
}
type Company_Id_Field struct {
@ -500,10 +519,10 @@ func (Company_Id_Field) _Column() string { return "id" }
type Company_UserId_Field struct {
_set bool
_value string
_value []byte
}
func Company_UserId(v string) Company_UserId_Field {
func Company_UserId(v []byte) Company_UserId_Field {
return Company_UserId_Field{_set: true, _value: v}
}
@ -514,7 +533,133 @@ func (f Company_UserId_Field) value() interface{} {
return f._value
}
func (Company_UserId_Field) _Column() string { return "userId" }
func (Company_UserId_Field) _Column() string { return "user_id" }
type Company_Name_Field struct {
_set bool
_value string
}
func Company_Name(v string) Company_Name_Field {
return Company_Name_Field{_set: true, _value: v}
}
func (f Company_Name_Field) value() interface{} {
if !f._set {
return nil
}
return f._value
}
func (Company_Name_Field) _Column() string { return "name" }
type Company_Address_Field struct {
_set bool
_value string
}
func Company_Address(v string) Company_Address_Field {
return Company_Address_Field{_set: true, _value: v}
}
func (f Company_Address_Field) value() interface{} {
if !f._set {
return nil
}
return f._value
}
func (Company_Address_Field) _Column() string { return "address" }
type Company_Country_Field struct {
_set bool
_value string
}
func Company_Country(v string) Company_Country_Field {
return Company_Country_Field{_set: true, _value: v}
}
func (f Company_Country_Field) value() interface{} {
if !f._set {
return nil
}
return f._value
}
func (Company_Country_Field) _Column() string { return "country" }
type Company_City_Field struct {
_set bool
_value string
}
func Company_City(v string) Company_City_Field {
return Company_City_Field{_set: true, _value: v}
}
func (f Company_City_Field) value() interface{} {
if !f._set {
return nil
}
return f._value
}
func (Company_City_Field) _Column() string { return "city" }
type Company_State_Field struct {
_set bool
_value string
}
func Company_State(v string) Company_State_Field {
return Company_State_Field{_set: true, _value: v}
}
func (f Company_State_Field) value() interface{} {
if !f._set {
return nil
}
return f._value
}
func (Company_State_Field) _Column() string { return "state" }
type Company_PostalCode_Field struct {
_set bool
_value string
}
func Company_PostalCode(v string) Company_PostalCode_Field {
return Company_PostalCode_Field{_set: true, _value: v}
}
func (f Company_PostalCode_Field) value() interface{} {
if !f._set {
return nil
}
return f._value
}
func (Company_PostalCode_Field) _Column() string { return "postal_code" }
type Company_CreatedAt_Field struct {
_set bool
_value time.Time
}
func Company_CreatedAt(v time.Time) Company_CreatedAt_Field {
return Company_CreatedAt_Field{_set: true, _value: v}
}
func (f Company_CreatedAt_Field) value() interface{} {
if !f._set {
return nil
}
return f._value
}
func (Company_CreatedAt_Field) _Column() string { return "created_at" }
func toUTC(t time.Time) time.Time {
return t.UTC()
@ -717,6 +862,45 @@ func (obj *sqlite3Impl) Create_User(ctx context.Context,
}
func (obj *sqlite3Impl) Create_Company(ctx context.Context,
company_id Company_Id_Field,
company_user_id Company_UserId_Field,
company_name Company_Name_Field,
company_address Company_Address_Field,
company_country Company_Country_Field,
company_city Company_City_Field,
company_state Company_State_Field,
company_postal_code Company_PostalCode_Field) (
company *Company, err error) {
__now := obj.db.Hooks.Now().UTC()
__id_val := company_id.value()
__user_id_val := company_user_id.value()
__name_val := company_name.value()
__address_val := company_address.value()
__country_val := company_country.value()
__city_val := company_city.value()
__state_val := company_state.value()
__postal_code_val := company_postal_code.value()
__created_at_val := __now
var __embed_stmt = __sqlbundle_Literal("INSERT INTO companies ( id, user_id, name, address, country, city, state, postal_code, created_at ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ? )")
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __id_val, __user_id_val, __name_val, __address_val, __country_val, __city_val, __state_val, __postal_code_val, __created_at_val)
__res, err := obj.driver.Exec(__stmt, __id_val, __user_id_val, __name_val, __address_val, __country_val, __city_val, __state_val, __postal_code_val, __created_at_val)
if err != nil {
return nil, obj.makeErr(err)
}
__pk, err := __res.LastInsertId()
if err != nil {
return nil, obj.makeErr(err)
}
return obj.getLastCompany(ctx, __pk)
}
func (obj *sqlite3Impl) Get_User_By_Email_And_PasswordHash(ctx context.Context,
user_email User_Email_Field,
user_password_hash User_PasswordHash_Field) (
@ -760,6 +944,70 @@ func (obj *sqlite3Impl) Get_User_By_Id(ctx context.Context,
}
func (obj *sqlite3Impl) Get_Company_By_UserId(ctx context.Context,
company_user_id Company_UserId_Field) (
company *Company, err error) {
var __embed_stmt = __sqlbundle_Literal("SELECT companies.id, companies.user_id, companies.name, companies.address, companies.country, companies.city, companies.state, companies.postal_code, companies.created_at FROM companies WHERE companies.user_id = ? LIMIT 2")
var __values []interface{}
__values = append(__values, company_user_id.value())
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
__rows, err := obj.driver.Query(__stmt, __values...)
if err != nil {
return nil, obj.makeErr(err)
}
defer __rows.Close()
if !__rows.Next() {
if err := __rows.Err(); err != nil {
return nil, obj.makeErr(err)
}
return nil, makeErr(sql.ErrNoRows)
}
company = &Company{}
err = __rows.Scan(&company.Id, &company.UserId, &company.Name, &company.Address, &company.Country, &company.City, &company.State, &company.PostalCode, &company.CreatedAt)
if err != nil {
return nil, obj.makeErr(err)
}
if __rows.Next() {
return nil, tooManyRows("Company_By_UserId")
}
if err := __rows.Err(); err != nil {
return nil, obj.makeErr(err)
}
return company, nil
}
func (obj *sqlite3Impl) Get_Company_By_Id(ctx context.Context,
company_id Company_Id_Field) (
company *Company, err error) {
var __embed_stmt = __sqlbundle_Literal("SELECT companies.id, companies.user_id, companies.name, companies.address, companies.country, companies.city, companies.state, companies.postal_code, companies.created_at FROM companies WHERE companies.id = ?")
var __values []interface{}
__values = append(__values, company_id.value())
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
company = &Company{}
err = obj.driver.QueryRow(__stmt, __values...).Scan(&company.Id, &company.UserId, &company.Name, &company.Address, &company.Country, &company.City, &company.State, &company.PostalCode, &company.CreatedAt)
if err != nil {
return nil, obj.makeErr(err)
}
return company, nil
}
func (obj *sqlite3Impl) Update_User_By_Id(ctx context.Context,
user_id User_Id_Field,
update User_Update_Fields) (
@ -825,6 +1073,81 @@ func (obj *sqlite3Impl) Update_User_By_Id(ctx context.Context,
return user, nil
}
func (obj *sqlite3Impl) Update_Company_By_Id(ctx context.Context,
company_id Company_Id_Field,
update Company_Update_Fields) (
company *Company, err error) {
var __sets = &__sqlbundle_Hole{}
var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("UPDATE companies SET "), __sets, __sqlbundle_Literal(" WHERE companies.id = ?")}}
__sets_sql := __sqlbundle_Literals{Join: ", "}
var __values []interface{}
var __args []interface{}
if update.Name._set {
__values = append(__values, update.Name.value())
__sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("name = ?"))
}
if update.Address._set {
__values = append(__values, update.Address.value())
__sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("address = ?"))
}
if update.Country._set {
__values = append(__values, update.Country.value())
__sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("country = ?"))
}
if update.City._set {
__values = append(__values, update.City.value())
__sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("city = ?"))
}
if update.State._set {
__values = append(__values, update.State.value())
__sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("state = ?"))
}
if update.PostalCode._set {
__values = append(__values, update.PostalCode.value())
__sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("postal_code = ?"))
}
if len(__sets_sql.SQLs) == 0 {
return nil, emptyUpdate()
}
__args = append(__args, company_id.value())
__values = append(__values, __args...)
__sets.SQL = __sets_sql
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
company = &Company{}
_, err = obj.driver.Exec(__stmt, __values...)
if err != nil {
return nil, obj.makeErr(err)
}
var __embed_stmt_get = __sqlbundle_Literal("SELECT companies.id, companies.user_id, companies.name, companies.address, companies.country, companies.city, companies.state, companies.postal_code, companies.created_at FROM companies WHERE companies.id = ?")
var __stmt_get = __sqlbundle_Render(obj.dialect, __embed_stmt_get)
obj.logStmt("(IMPLIED) "+__stmt_get, __args...)
err = obj.driver.QueryRow(__stmt_get, __args...).Scan(&company.Id, &company.UserId, &company.Name, &company.Address, &company.Country, &company.City, &company.State, &company.PostalCode, &company.CreatedAt)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, obj.makeErr(err)
}
return company, nil
}
func (obj *sqlite3Impl) Delete_User_By_Id(ctx context.Context,
user_id User_Id_Field) (
deleted bool, err error) {
@ -851,6 +1174,32 @@ func (obj *sqlite3Impl) Delete_User_By_Id(ctx context.Context,
}
func (obj *sqlite3Impl) Delete_Company_By_Id(ctx context.Context,
company_id Company_Id_Field) (
deleted bool, err error) {
var __embed_stmt = __sqlbundle_Literal("DELETE FROM companies WHERE companies.id = ?")
var __values []interface{}
__values = append(__values, company_id.value())
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
__res, err := obj.driver.Exec(__stmt, __values...)
if err != nil {
return false, obj.makeErr(err)
}
__count, err := __res.RowsAffected()
if err != nil {
return false, obj.makeErr(err)
}
return __count > 0, nil
}
func (obj *sqlite3Impl) getLastUser(ctx context.Context,
pk int64) (
user *User, err error) {
@ -869,6 +1218,24 @@ func (obj *sqlite3Impl) getLastUser(ctx context.Context,
}
func (obj *sqlite3Impl) getLastCompany(ctx context.Context,
pk int64) (
company *Company, err error) {
var __embed_stmt = __sqlbundle_Literal("SELECT companies.id, companies.user_id, companies.name, companies.address, companies.country, companies.city, companies.state, companies.postal_code, companies.created_at FROM companies WHERE _rowid_ = ?")
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, pk)
company = &Company{}
err = obj.driver.QueryRow(__stmt, pk).Scan(&company.Id, &company.UserId, &company.Name, &company.Address, &company.Country, &company.City, &company.State, &company.PostalCode, &company.CreatedAt)
if err != nil {
return nil, obj.makeErr(err)
}
return company, nil
}
func (impl sqlite3Impl) isConstraintError(err error) (
constraint string, ok bool) {
if e, ok := err.(sqlite3.Error); ok {
@ -954,6 +1321,24 @@ func (rx *Rx) Rollback() (err error) {
return err
}
func (rx *Rx) Create_Company(ctx context.Context,
company_id Company_Id_Field,
company_user_id Company_UserId_Field,
company_name Company_Name_Field,
company_address Company_Address_Field,
company_country Company_Country_Field,
company_city Company_City_Field,
company_state Company_State_Field,
company_postal_code Company_PostalCode_Field) (
company *Company, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Create_Company(ctx, company_id, company_user_id, company_name, company_address, company_country, company_city, company_state, company_postal_code)
}
func (rx *Rx) Create_User(ctx context.Context,
user_id User_Id_Field,
user_first_name User_FirstName_Field,
@ -969,6 +1354,16 @@ func (rx *Rx) Create_User(ctx context.Context,
}
func (rx *Rx) Delete_Company_By_Id(ctx context.Context,
company_id Company_Id_Field) (
deleted bool, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Delete_Company_By_Id(ctx, company_id)
}
func (rx *Rx) Delete_User_By_Id(ctx context.Context,
user_id User_Id_Field) (
deleted bool, err error) {
@ -979,6 +1374,26 @@ func (rx *Rx) Delete_User_By_Id(ctx context.Context,
return tx.Delete_User_By_Id(ctx, user_id)
}
func (rx *Rx) Get_Company_By_Id(ctx context.Context,
company_id Company_Id_Field) (
company *Company, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Get_Company_By_Id(ctx, company_id)
}
func (rx *Rx) Get_Company_By_UserId(ctx context.Context,
company_user_id Company_UserId_Field) (
company *Company, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Get_Company_By_UserId(ctx, company_user_id)
}
func (rx *Rx) Get_User_By_Email_And_PasswordHash(ctx context.Context,
user_email User_Email_Field,
user_password_hash User_PasswordHash_Field) (
@ -1000,6 +1415,17 @@ func (rx *Rx) Get_User_By_Id(ctx context.Context,
return tx.Get_User_By_Id(ctx, user_id)
}
func (rx *Rx) Update_Company_By_Id(ctx context.Context,
company_id Company_Id_Field,
update Company_Update_Fields) (
company *Company, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Update_Company_By_Id(ctx, company_id, update)
}
func (rx *Rx) Update_User_By_Id(ctx context.Context,
user_id User_Id_Field,
update User_Update_Fields) (
@ -1012,6 +1438,17 @@ func (rx *Rx) Update_User_By_Id(ctx context.Context,
}
type Methods interface {
Create_Company(ctx context.Context,
company_id Company_Id_Field,
company_user_id Company_UserId_Field,
company_name Company_Name_Field,
company_address Company_Address_Field,
company_country Company_Country_Field,
company_city Company_City_Field,
company_state Company_State_Field,
company_postal_code Company_PostalCode_Field) (
company *Company, err error)
Create_User(ctx context.Context,
user_id User_Id_Field,
user_first_name User_FirstName_Field,
@ -1020,10 +1457,22 @@ type Methods interface {
user_password_hash User_PasswordHash_Field) (
user *User, err error)
Delete_Company_By_Id(ctx context.Context,
company_id Company_Id_Field) (
deleted bool, err error)
Delete_User_By_Id(ctx context.Context,
user_id User_Id_Field) (
deleted bool, err error)
Get_Company_By_Id(ctx context.Context,
company_id Company_Id_Field) (
company *Company, err error)
Get_Company_By_UserId(ctx context.Context,
company_user_id Company_UserId_Field) (
company *Company, err error)
Get_User_By_Email_And_PasswordHash(ctx context.Context,
user_email User_Email_Field,
user_password_hash User_PasswordHash_Field) (
@ -1033,6 +1482,11 @@ type Methods interface {
user_id User_Id_Field) (
user *User, err error)
Update_Company_By_Id(ctx context.Context,
company_id Company_Id_Field,
update Company_Update_Fields) (
company *Company, err error)
Update_User_By_Id(ctx context.Context,
user_id User_Id_Field,
update User_Update_Fields) (

View File

@ -1,7 +1,7 @@
-- AUTOGENERATED BY gopkg.in/spacemonkeygo/dbx.v1
-- DO NOT EDIT
CREATE TABLE users (
id TEXT NOT NULL,
id BLOB NOT NULL,
first_name TEXT NOT NULL,
last_name TEXT NOT NULL,
email TEXT NOT NULL,
@ -12,6 +12,13 @@ CREATE TABLE users (
);
CREATE TABLE companies (
id BLOB NOT NULL,
userId TEXT NOT NULL REFERENCES users( id ) ON DELETE CASCADE,
user_id BLOB NOT NULL REFERENCES users( id ) ON DELETE CASCADE,
name TEXT NOT NULL,
address TEXT NOT NULL,
country TEXT NOT NULL,
city TEXT NOT NULL,
state TEXT NOT NULL,
postal_code TEXT NOT NULL,
created_at TIMESTAMP NOT NULL,
PRIMARY KEY ( id )
);

View File

@ -14,7 +14,7 @@ import (
"storj.io/storj/pkg/satellite/satellitedb/dbx"
)
// implementation of User interface repository using spacemonkeygo/dbx orm
// implementation of Users interface repository using spacemonkeygo/dbx orm
type users struct {
db *dbx.DB
}
@ -22,7 +22,7 @@ type users struct {
// Get is a method for querying user from the database by id
func (users *users) Get(ctx context.Context, id uuid.UUID) (*satellite.User, error) {
userID := dbx.User_Id(id.String())
userID := dbx.User_Id([]byte(id.String()))
user, err := users.db.Get_User_By_Id(ctx, userID)
@ -49,20 +49,30 @@ func (users *users) GetByCredentials(ctx context.Context, password []byte, email
}
// Insert is a method for inserting user into the database
func (users *users) Insert(ctx context.Context, user *satellite.User) error {
_, err := users.db.Create_User(ctx,
dbx.User_Id(user.ID.String()),
func (users *users) Insert(ctx context.Context, user *satellite.User) (*satellite.User, error) {
userID, err := uuid.New()
if err != nil {
return nil, err
}
createdUser, err := users.db.Create_User(ctx,
dbx.User_Id([]byte(userID.String())),
dbx.User_FirstName(user.FirstName),
dbx.User_LastName(user.LastName),
dbx.User_Email(user.Email),
dbx.User_PasswordHash(user.PasswordHash))
return err
if err != nil {
return nil, err
}
return userFromDBX(createdUser)
}
// Delete is a method for deleting user by Id from the database.
func (users *users) Delete(ctx context.Context, id uuid.UUID) error {
_, err := users.db.Delete_User_By_Id(ctx, dbx.User_Id(id.String()))
_, err := users.db.Delete_User_By_Id(ctx, dbx.User_Id([]byte(id.String())))
return err
}
@ -70,7 +80,7 @@ func (users *users) Delete(ctx context.Context, id uuid.UUID) error {
// Update is a method for updating user entity
func (users *users) Update(ctx context.Context, user *satellite.User) error {
_, err := users.db.Update_User_By_Id(ctx,
dbx.User_Id(user.ID.String()),
dbx.User_Id([]byte(user.ID.String())),
dbx.User_Update_Fields{
FirstName: dbx.User_FirstName(user.FirstName),
LastName: dbx.User_LastName(user.LastName),
@ -88,20 +98,19 @@ func userFromDBX(user *dbx.User) (*satellite.User, error) {
return nil, errs.New("user parameter is nil")
}
id, err := uuid.Parse(user.Id)
id, err := uuid.Parse(string(user.Id))
if err != nil {
return nil, errs.New("Id in not valid UUID string")
}
u := &satellite.User{}
u.ID = *id
u.FirstName = user.FirstName
u.LastName = user.LastName
u.Email = user.Email
u.PasswordHash = user.PasswordHash
u.CreatedAt = user.CreatedAt
u := &satellite.User{
ID: *id,
FirstName: user.FirstName,
LastName: user.LastName,
Email: user.Email,
PasswordHash: user.PasswordHash,
CreatedAt: user.CreatedAt,
}
return u, nil
}

View File

@ -16,7 +16,7 @@ import (
"github.com/stretchr/testify/assert"
)
func TestRepository(t *testing.T) {
func TestUserRepository(t *testing.T) {
//testing constants
const (
@ -33,6 +33,7 @@ func TestRepository(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
// creating in-memory db and opens connection
// to test with real db3 file use this connection string - "../db/accountdb.db3"
db, err := New("sqlite3", "file::memory:?mode=memory&cache=shared")
if err != nil {
@ -40,6 +41,7 @@ func TestRepository(t *testing.T) {
}
defer ctx.Check(db.Close)
// creating tables
err = db.CreateTables()
if err != nil {
assert.NoError(t, err)
@ -64,7 +66,7 @@ func TestRepository(t *testing.T) {
CreatedAt: time.Now(),
}
err = repository.Insert(ctx, user)
_, err = repository.Insert(ctx, user)
assert.Nil(t, err)
assert.NoError(t, err)
@ -72,14 +74,7 @@ func TestRepository(t *testing.T) {
t.Run("Can't insert user with same email twice", func(t *testing.T) {
id, err := uuid.New()
if err != nil {
assert.NoError(t, err)
}
user := &satellite.User{
ID: *id,
FirstName: name,
LastName: lastName,
Email: email,
@ -87,7 +82,7 @@ func TestRepository(t *testing.T) {
CreatedAt: time.Now(),
}
err = repository.Insert(ctx, user)
_, err = repository.Insert(ctx, user)
assert.NotNil(t, err)
assert.Error(t, err)
@ -119,9 +114,7 @@ func TestRepository(t *testing.T) {
t.Run("Update user success", func(t *testing.T) {
oldUser, err := repository.GetByCredentials(ctx, []byte(passValid), email)
if err != nil {
assert.NoError(t, err)
}
assert.NoError(t, err)
newUser := &satellite.User{
ID: oldUser.ID,
@ -129,7 +122,6 @@ func TestRepository(t *testing.T) {
LastName: newLastName,
Email: newEmail,
PasswordHash: []byte(newPass),
CreatedAt: oldUser.CreatedAt,
}
err = repository.Update(ctx, newUser)
@ -139,9 +131,7 @@ func TestRepository(t *testing.T) {
newUser, err = repository.Get(ctx, oldUser.ID)
if err != nil {
assert.NoError(t, err)
}
assert.NoError(t, err)
assert.Equal(t, newUser.ID, oldUser.ID)
assert.Equal(t, newUser.FirstName, newName)
@ -154,9 +144,7 @@ func TestRepository(t *testing.T) {
t.Run("Delete user success", func(t *testing.T) {
oldUser, err := repository.GetByCredentials(ctx, []byte(newPass), newEmail)
if err != nil {
assert.NoError(t, err)
}
assert.NoError(t, err)
err = repository.Delete(ctx, oldUser.ID)
@ -170,7 +158,7 @@ func TestRepository(t *testing.T) {
})
}
func TestUserDboFromDbx(t *testing.T) {
func TestUserFromDbx(t *testing.T) {
t.Run("can't create dbo from nil dbx model", func(t *testing.T) {
user, err := userFromDBX(nil)
@ -180,9 +168,9 @@ func TestUserDboFromDbx(t *testing.T) {
assert.Error(t, err)
})
t.Run("can't create dbo from dbx model with invalid Id", func(t *testing.T) {
t.Run("can't create dbo from dbx model with invalid ID", func(t *testing.T) {
dbxUser := dbx.User{
Id: "qweqwe",
Id: []byte("qweqwe"),
FirstName: "FirstName",
LastName: "LastName",
Email: "email@ukr.net",

View File

@ -17,7 +17,7 @@ type Users interface {
// Get is a method for querying user from the database by id
Get(ctx context.Context, id uuid.UUID) (*User, error)
// Insert is a method for inserting user into the database
Insert(ctx context.Context, user *User) error
Insert(ctx context.Context, user *User) (*User, error)
// Delete is a method for deleting user by Id from the database.
Delete(ctx context.Context, id uuid.UUID) error
// Update is a method for updating user entity