better way to check if err is ErrorCode_NoRows (#1453)

* better way to check if err is ErrorCode_NoRows
This commit is contained in:
paul cannon 2019-03-18 19:15:27 -05:00 committed by Bill Thorp
parent 7961bcbc92
commit cd91a22e0f
3 changed files with 79 additions and 12 deletions

View File

@ -139,11 +139,18 @@ create node ( )
update node ( where node.id = ? )
delete node ( where node.id = ? )
// "Get" query; fails if node not found
read one (
select node
where node.id = ?
)
// "Find" query; returns nil if node not found
read scalar (
select node
where node.id = ?
)
read all (
select node.id
)

View File

@ -3486,6 +3486,30 @@ func (obj *postgresImpl) Get_Node_By_Id(ctx context.Context,
}
func (obj *postgresImpl) Find_Node_By_Id(ctx context.Context,
node_id Node_Id_Field) (
node *Node, err error) {
var __embed_stmt = __sqlbundle_Literal("SELECT nodes.id, nodes.audit_success_count, nodes.total_audit_count, nodes.audit_success_ratio, nodes.uptime_success_count, nodes.total_uptime_count, nodes.uptime_ratio, nodes.created_at, nodes.updated_at, nodes.wallet, nodes.email FROM nodes WHERE nodes.id = ?")
var __values []interface{}
__values = append(__values, node_id.value())
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
node = &Node{}
err = obj.driver.QueryRow(__stmt, __values...).Scan(&node.Id, &node.AuditSuccessCount, &node.TotalAuditCount, &node.AuditSuccessRatio, &node.UptimeSuccessCount, &node.TotalUptimeCount, &node.UptimeRatio, &node.CreatedAt, &node.UpdatedAt, &node.Wallet, &node.Email)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, obj.makeErr(err)
}
return node, nil
}
func (obj *postgresImpl) All_Node_Id(ctx context.Context) (
rows []*Id_Row, err error) {
@ -5641,6 +5665,30 @@ func (obj *sqlite3Impl) Get_Node_By_Id(ctx context.Context,
}
func (obj *sqlite3Impl) Find_Node_By_Id(ctx context.Context,
node_id Node_Id_Field) (
node *Node, err error) {
var __embed_stmt = __sqlbundle_Literal("SELECT nodes.id, nodes.audit_success_count, nodes.total_audit_count, nodes.audit_success_ratio, nodes.uptime_success_count, nodes.total_uptime_count, nodes.uptime_ratio, nodes.created_at, nodes.updated_at, nodes.wallet, nodes.email FROM nodes WHERE nodes.id = ?")
var __values []interface{}
__values = append(__values, node_id.value())
var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)
node = &Node{}
err = obj.driver.QueryRow(__stmt, __values...).Scan(&node.Id, &node.AuditSuccessCount, &node.TotalAuditCount, &node.AuditSuccessRatio, &node.UptimeSuccessCount, &node.TotalUptimeCount, &node.UptimeRatio, &node.CreatedAt, &node.UpdatedAt, &node.Wallet, &node.Email)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, obj.makeErr(err)
}
return node, nil
}
func (obj *sqlite3Impl) All_Node_Id(ctx context.Context) (
rows []*Id_Row, err error) {
@ -7901,6 +7949,16 @@ func (rx *Rx) Find_AccountingTimestamps_Value_By_Name(ctx context.Context,
return tx.Find_AccountingTimestamps_Value_By_Name(ctx, accounting_timestamps_name)
}
func (rx *Rx) Find_Node_By_Id(ctx context.Context,
node_id Node_Id_Field) (
node *Node, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Find_Node_By_Id(ctx, node_id)
}
func (rx *Rx) First_Injuredsegment(ctx context.Context) (
injuredsegment *Injuredsegment, err error) {
var tx *Tx
@ -8381,6 +8439,10 @@ type Methods interface {
accounting_timestamps_name AccountingTimestamps_Name_Field) (
row *Value_Row, err error)
Find_Node_By_Id(ctx context.Context,
node_id Node_Id_Field) (
node *Node, err error)
First_Injuredsegment(ctx context.Context) (
injuredsegment *Injuredsegment, err error)

View File

@ -22,6 +22,9 @@ var (
mon = monkit.Package()
errAuditSuccess = errs.Class("statdb audit success error")
errUptime = errs.Class("statdb uptime error")
// ErrNodeNotFound may be returned when a node is not found in the statdb.
ErrNodeNotFound = errs.New("statdb node not found")
)
// StatDB implements the statdb RPC service
@ -103,10 +106,13 @@ func (s *statDB) Create(ctx context.Context, nodeID storj.NodeID, startingStats
func (s *statDB) Get(ctx context.Context, nodeID storj.NodeID) (stats *statdb.NodeStats, err error) {
defer mon.Task()(&ctx)(&err)
dbNode, err := s.db.Get_Node_By_Id(ctx, dbx.Node_Id(nodeID.Bytes()))
dbNode, err := s.db.Find_Node_By_Id(ctx, dbx.Node_Id(nodeID.Bytes()))
if err != nil {
return nil, Error.Wrap(err)
}
if dbNode == nil {
return nil, ErrNodeNotFound
}
nodeStats := getNodeStats(nodeID, dbNode)
return nodeStats, nil
@ -354,18 +360,10 @@ func (s *statDB) CreateEntryIfNotExists(ctx context.Context, nodeID storj.NodeID
defer mon.Task()(&ctx)(&err)
getStats, err := s.Get(ctx, nodeID)
// TODO: figure out better way to confirm error is type dbx.ErrorCode_NoRows
if err != nil && strings.Contains(err.Error(), "no rows in result set") {
createStats, err := s.Create(ctx, nodeID, nil)
if err != nil {
return nil, err
}
return createStats, nil
if err == ErrNodeNotFound {
return s.Create(ctx, nodeID, nil)
}
if err != nil {
return nil, err
}
return getStats, nil
return getStats, err
}
func updateRatioVars(newStatus bool, successCount, totalCount int64) (int64, int64, float64) {