multinode/db: nodes repository tests added

Change-Id: Ia5172f249c18540683f66ef244c2c6d39aa3da0a
This commit is contained in:
crawter 2020-11-16 14:46:49 +02:00
parent 51fa52e636
commit f311722854
6 changed files with 221 additions and 90 deletions

View File

@ -7,6 +7,8 @@ import (
"context"
"encoding/base64"
"github.com/zeebo/errs"
"storj.io/common/storj"
)
@ -16,25 +18,26 @@ import (
//
// architecture: Database
type Nodes interface {
// Add creates new node in NodesDB.
Add(ctx context.Context, id storj.NodeID, apiSecret []byte, publicAddress string) error
// GetByID return node from NodesDB by its id.
GetByID(ctx context.Context, id storj.NodeID) (Node, error)
// GetAll returns all connected nodes.
GetAll(ctx context.Context) ([]Node, error)
// Add creates new node in NodesDB.
Add(ctx context.Context, id storj.NodeID, apiSecret []byte, publicAddress string) error
// Remove removed node from NodesDB.
Remove(ctx context.Context, id storj.NodeID) error
}
// ErrNoNode is a special error type that indicates about absence of node in NodesDB.
var ErrNoNode = errs.Class("no such node")
// Node is a representation of storeganode, that SNO could add to the Multinode Dashboard.
type Node struct {
ID storj.NodeID
// APISecret is a secret issued by storagenode, that will be main auth mechanism in MND <-> SNO api. is a secret issued by storagenode, that will be main auth mechanism in MND <-> SNO api.
APISecret []byte
PublicAddress string
// Logo is a configurable icon.
Logo []byte
// Tag is configured by used and could be used to group nodes. // TODO: should node have multiple tags?
Tag string // TODO: create enum or type in future.
Name string
}
// APISecretFromBase64 decodes API secret from base 64 string.

View File

@ -0,0 +1,53 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
package console_test
import (
"testing"
"github.com/stretchr/testify/assert"
"storj.io/common/testcontext"
"storj.io/common/testrand"
"storj.io/storj/multinode"
"storj.io/storj/multinode/console"
"storj.io/storj/multinode/multinodedb/multinodedbtest"
)
func TestNodesDB(t *testing.T) {
multinodedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db multinode.DB) {
nodes := db.Nodes()
nodeID := testrand.NodeID()
apiSecret := []byte("secret")
publicAddress := "228.13.38.1:8081"
err := nodes.Add(ctx, nodeID, apiSecret, publicAddress)
assert.NoError(t, err)
node, err := nodes.GetByID(ctx, nodeID)
assert.NoError(t, err)
assert.Equal(t, node.ID.Bytes(), nodeID.Bytes())
assert.Equal(t, node.APISecret, apiSecret)
assert.Equal(t, node.PublicAddress, publicAddress)
allNodes, err := nodes.GetAll(ctx)
assert.NoError(t, err)
assert.Equal(t, len(allNodes), 1)
assert.Equal(t, node.ID.Bytes(), allNodes[0].ID.Bytes())
assert.Equal(t, node.APISecret, allNodes[0].APISecret)
assert.Equal(t, node.PublicAddress, allNodes[0].PublicAddress)
err = nodes.Remove(ctx, nodeID)
assert.NoError(t, err)
_, err = nodes.GetAll(ctx)
assert.Error(t, err)
assert.True(t, console.ErrNoNode.Has(err))
node, err = nodes.GetByID(ctx, nodeID)
assert.Error(t, err)
assert.True(t, console.ErrNoNode.Has(err))
})
}

View File

@ -5,19 +5,21 @@ model node (
field id blob
field name text ( updatable )
field tag text ( updatable )
field public_address text
field api_secret blob
field logo blob ( updatable )
)
create node ( )
delete node ( where node.id = ? )
update node ( where node.id = ? )
read one (
select node
where node.id = ?
)
read all(
select node
)
model member (
key id

View File

@ -280,10 +280,8 @@ func (obj *pgxDB) Schema() string {
CREATE TABLE nodes (
id bytea NOT NULL,
name text NOT NULL,
tag text NOT NULL,
public_address text NOT NULL,
api_secret bytea NOT NULL,
logo bytea NOT NULL,
PRIMARY KEY ( id )
);`
}
@ -462,18 +460,14 @@ func (Member_CreatedAt_Field) _Column() string { return "created_at" }
type Node struct {
Id []byte
Name string
Tag string
PublicAddress string
ApiSecret []byte
Logo []byte
}
func (Node) _Table() string { return "nodes" }
type Node_Update_Fields struct {
Name Node_Name_Field
Tag Node_Tag_Field
Logo Node_Logo_Field
}
type Node_Id_Field struct {
@ -514,25 +508,6 @@ func (f Node_Name_Field) value() interface{} {
func (Node_Name_Field) _Column() string { return "name" }
type Node_Tag_Field struct {
_set bool
_null bool
_value string
}
func Node_Tag(v string) Node_Tag_Field {
return Node_Tag_Field{_set: true, _value: v}
}
func (f Node_Tag_Field) value() interface{} {
if !f._set || f._null {
return nil
}
return f._value
}
func (Node_Tag_Field) _Column() string { return "tag" }
type Node_PublicAddress_Field struct {
_set bool
_null bool
@ -571,25 +546,6 @@ func (f Node_ApiSecret_Field) value() interface{} {
func (Node_ApiSecret_Field) _Column() string { return "api_secret" }
type Node_Logo_Field struct {
_set bool
_null bool
_value []byte
}
func Node_Logo(v []byte) Node_Logo_Field {
return Node_Logo_Field{_set: true, _value: v}
}
func (f Node_Logo_Field) value() interface{} {
if !f._set || f._null {
return nil
}
return f._value
}
func (Node_Logo_Field) _Column() string { return "logo" }
func toUTC(t time.Time) time.Time {
return t.UTC()
}
@ -1013,29 +969,25 @@ func (h *__sqlbundle_Hole) Render() string {
func (obj *pgxImpl) Create_Node(ctx context.Context,
node_id Node_Id_Field,
node_name Node_Name_Field,
node_tag Node_Tag_Field,
node_public_address Node_PublicAddress_Field,
node_api_secret Node_ApiSecret_Field,
node_logo Node_Logo_Field) (
node_api_secret Node_ApiSecret_Field) (
node *Node, err error) {
defer mon.Task()(&ctx)(&err)
__id_val := node_id.value()
__name_val := node_name.value()
__tag_val := node_tag.value()
__public_address_val := node_public_address.value()
__api_secret_val := node_api_secret.value()
__logo_val := node_logo.value()
var __embed_stmt = __sqlbundle_Literal("INSERT INTO nodes ( id, name, tag, public_address, api_secret, logo ) VALUES ( ?, ?, ?, ?, ?, ? ) RETURNING nodes.id, nodes.name, nodes.tag, nodes.public_address, nodes.api_secret, nodes.logo")
var __embed_stmt = __sqlbundle_Literal("INSERT INTO nodes ( id, name, public_address, api_secret ) VALUES ( ?, ?, ?, ? ) RETURNING nodes.id, nodes.name, nodes.public_address, nodes.api_secret")
var __values []interface{}
__values = append(__values, __id_val, __name_val, __tag_val, __public_address_val, __api_secret_val, __logo_val)
__values = append(__values, __id_val, __name_val, __public_address_val, __api_secret_val)
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
node = &Node{}
err = obj.driver.QueryRowContext(ctx, __stmt, __values...).Scan(&node.Id, &node.Name, &node.Tag, &node.PublicAddress, &node.ApiSecret, &node.Logo)
err = obj.driver.QueryRowContext(ctx, __stmt, __values...).Scan(&node.Id, &node.Name, &node.PublicAddress, &node.ApiSecret)
if err != nil {
return nil, obj.makeErr(err)
}
@ -1080,7 +1032,7 @@ func (obj *pgxImpl) Get_Node_By_Id(ctx context.Context,
node *Node, err error) {
defer mon.Task()(&ctx)(&err)
var __embed_stmt = __sqlbundle_Literal("SELECT nodes.id, nodes.name, nodes.tag, nodes.public_address, nodes.api_secret, nodes.logo FROM nodes WHERE nodes.id = ?")
var __embed_stmt = __sqlbundle_Literal("SELECT nodes.id, nodes.name, nodes.public_address, nodes.api_secret FROM nodes WHERE nodes.id = ?")
var __values []interface{}
__values = append(__values, node_id.value())
@ -1089,7 +1041,7 @@ func (obj *pgxImpl) Get_Node_By_Id(ctx context.Context,
obj.logStmt(__stmt, __values...)
node = &Node{}
err = obj.driver.QueryRowContext(ctx, __stmt, __values...).Scan(&node.Id, &node.Name, &node.Tag, &node.PublicAddress, &node.ApiSecret, &node.Logo)
err = obj.driver.QueryRowContext(ctx, __stmt, __values...).Scan(&node.Id, &node.Name, &node.PublicAddress, &node.ApiSecret)
if err != nil {
return (*Node)(nil), obj.makeErr(err)
}
@ -1097,6 +1049,38 @@ func (obj *pgxImpl) Get_Node_By_Id(ctx context.Context,
}
func (obj *pgxImpl) All_Node(ctx context.Context) (
rows []*Node, err error) {
defer mon.Task()(&ctx)(&err)
var __embed_stmt = __sqlbundle_Literal("SELECT nodes.id, nodes.name, nodes.public_address, nodes.api_secret FROM nodes")
var __values []interface{}
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
__rows, err := obj.driver.QueryContext(ctx, __stmt, __values...)
if err != nil {
return nil, obj.makeErr(err)
}
defer __rows.Close()
for __rows.Next() {
node := &Node{}
err = __rows.Scan(&node.Id, &node.Name, &node.PublicAddress, &node.ApiSecret)
if err != nil {
return nil, obj.makeErr(err)
}
rows = append(rows, node)
}
if err := __rows.Err(); err != nil {
return nil, obj.makeErr(err)
}
return rows, nil
}
func (obj *pgxImpl) Get_Member_By_Email(ctx context.Context,
member_email Member_Email_Field) (
member *Member, err error) {
@ -1163,6 +1147,47 @@ func (obj *pgxImpl) Get_Member_By_Id(ctx context.Context,
}
func (obj *pgxImpl) Update_Node_By_Id(ctx context.Context,
node_id Node_Id_Field,
update Node_Update_Fields) (
node *Node, err error) {
defer mon.Task()(&ctx)(&err)
var __sets = &__sqlbundle_Hole{}
var __embed_stmt = __sqlbundle_Literals{Join: "", SQLs: []__sqlbundle_SQL{__sqlbundle_Literal("UPDATE nodes SET "), __sets, __sqlbundle_Literal(" WHERE nodes.id = ? RETURNING nodes.id, nodes.name, nodes.public_address, nodes.api_secret")}}
__sets_sql := __sqlbundle_Literals{Join: ", "}
var __values []interface{}
var __args []interface{}
if update.Name._set {
__values = append(__values, update.Name.value())
__sets_sql.SQLs = append(__sets_sql.SQLs, __sqlbundle_Literal("name = ?"))
}
if len(__sets_sql.SQLs) == 0 {
return nil, emptyUpdate()
}
__args = append(__args, node_id.value())
__values = append(__values, __args...)
__sets.SQL = __sets_sql
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
node = &Node{}
err = obj.driver.QueryRowContext(ctx, __stmt, __values...).Scan(&node.Id, &node.Name, &node.PublicAddress, &node.ApiSecret)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, obj.makeErr(err)
}
return node, nil
}
func (obj *pgxImpl) Update_Member_By_Id(ctx context.Context,
member_id Member_Id_Field,
update Member_Update_Fields) (
@ -1349,6 +1374,15 @@ func (rx *Rx) Rollback() (err error) {
return err
}
func (rx *Rx) All_Node(ctx context.Context) (
rows []*Node, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.All_Node(ctx)
}
func (rx *Rx) Create_Member(ctx context.Context,
member_id Member_Id_Field,
member_email Member_Email_Field,
@ -1366,16 +1400,14 @@ func (rx *Rx) Create_Member(ctx context.Context,
func (rx *Rx) Create_Node(ctx context.Context,
node_id Node_Id_Field,
node_name Node_Name_Field,
node_tag Node_Tag_Field,
node_public_address Node_PublicAddress_Field,
node_api_secret Node_ApiSecret_Field,
node_logo Node_Logo_Field) (
node_api_secret Node_ApiSecret_Field) (
node *Node, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Create_Node(ctx, node_id, node_name, node_tag, node_public_address, node_api_secret, node_logo)
return tx.Create_Node(ctx, node_id, node_name, node_public_address, node_api_secret)
}
@ -1440,7 +1472,21 @@ func (rx *Rx) Update_Member_By_Id(ctx context.Context,
return tx.Update_Member_By_Id(ctx, member_id, update)
}
func (rx *Rx) Update_Node_By_Id(ctx context.Context,
node_id Node_Id_Field,
update Node_Update_Fields) (
node *Node, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Update_Node_By_Id(ctx, node_id, update)
}
type Methods interface {
All_Node(ctx context.Context) (
rows []*Node, err error)
Create_Member(ctx context.Context,
member_id Member_Id_Field,
member_email Member_Email_Field,
@ -1451,10 +1497,8 @@ type Methods interface {
Create_Node(ctx context.Context,
node_id Node_Id_Field,
node_name Node_Name_Field,
node_tag Node_Tag_Field,
node_public_address Node_PublicAddress_Field,
node_api_secret Node_ApiSecret_Field,
node_logo Node_Logo_Field) (
node_api_secret Node_ApiSecret_Field) (
node *Node, err error)
Delete_Member_By_Id(ctx context.Context,
@ -1481,6 +1525,11 @@ type Methods interface {
member_id Member_Id_Field,
update Member_Update_Fields) (
member *Member, err error)
Update_Node_By_Id(ctx context.Context,
node_id Node_Id_Field,
update Node_Update_Fields) (
node *Node, err error)
}
type TxMethods interface {

View File

@ -11,9 +11,7 @@ CREATE TABLE members (
CREATE TABLE nodes (
id bytea NOT NULL,
name text NOT NULL,
tag text NOT NULL,
public_address text NOT NULL,
api_secret bytea NOT NULL,
logo bytea NOT NULL,
PRIMARY KEY ( id )
);

View File

@ -5,6 +5,8 @@ package multinodedb
import (
"context"
"database/sql"
"errors"
"github.com/zeebo/errs"
@ -27,6 +29,47 @@ type nodes struct {
methods dbx.Methods
}
// GetAll returns all connected nodes.
func (n *nodes) GetAll(ctx context.Context) (nodes []console.Node, err error) {
defer mon.Task()(&ctx)(&err)
dbxNodes, err := n.methods.All_Node(ctx)
if err != nil {
return []console.Node{}, NodesDBError.Wrap(err)
}
if len(dbxNodes) == 0 {
return []console.Node{}, console.ErrNoNode.New("no nodes")
}
for _, dbxNode := range dbxNodes {
node, err := fromDBXNode(ctx, dbxNode)
if err != nil {
return []console.Node{}, NodesDBError.Wrap(err)
}
nodes = append(nodes, node)
}
return nodes, NodesDBError.Wrap(err)
}
// GetByID return node from NodesDB by its id.
func (n *nodes) GetByID(ctx context.Context, id storj.NodeID) (_ console.Node, err error) {
defer mon.Task()(&ctx)(&err)
dbxNode, err := n.methods.Get_Node_By_Id(ctx, dbx.Node_Id(id.Bytes()))
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return console.Node{}, console.ErrNoNode.Wrap(err)
}
return console.Node{}, NodesDBError.Wrap(err)
}
node, err := fromDBXNode(ctx, dbxNode)
return node, NodesDBError.Wrap(err)
}
// Add creates new node in NodesDB.
func (n *nodes) Add(ctx context.Context, id storj.NodeID, apiSecret []byte, publicAddress string) (err error) {
defer mon.Task()(&ctx)(&err)
@ -35,10 +78,8 @@ func (n *nodes) Add(ctx context.Context, id storj.NodeID, apiSecret []byte, publ
ctx,
dbx.Node_Id(id.Bytes()),
dbx.Node_Name(""),
dbx.Node_Tag(""),
dbx.Node_PublicAddress(publicAddress),
dbx.Node_ApiSecret(apiSecret),
dbx.Node_Logo(nil),
)
return NodesDBError.Wrap(err)
@ -53,20 +94,6 @@ func (n *nodes) Remove(ctx context.Context, id storj.NodeID) (err error) {
return NodesDBError.Wrap(err)
}
// GetByID return node from NodesDB by its id.
func (n *nodes) GetByID(ctx context.Context, id storj.NodeID) (_ console.Node, err error) {
defer mon.Task()(&ctx)(&err)
dbxNode, err := n.methods.Get_Node_By_Id(ctx, dbx.Node_Id(id.Bytes()))
if err != nil {
return console.Node{}, NodesDBError.Wrap(err)
}
node, err := fromDBXNode(ctx, dbxNode)
return node, NodesDBError.Wrap(err)
}
// fromDBXNode converts dbx.Node to console.Node.
func fromDBXNode(ctx context.Context, node *dbx.Node) (_ console.Node, err error) {
defer mon.Task()(&ctx)(&err)
@ -79,9 +106,8 @@ func fromDBXNode(ctx context.Context, node *dbx.Node) (_ console.Node, err error
result := console.Node{
ID: id,
APISecret: node.ApiSecret,
Name: node.Name,
PublicAddress: node.PublicAddress,
Logo: node.Logo,
Tag: node.Tag,
}
return result, nil