// Source: storj/satellite/satellitedb/bandwidthagreement.go
// (108 lines, 3.6 KiB, Go; Raw / Normal View / History)
//
// Revision: 2019-01-24 20:15:10 +00:00
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package satellitedb
import (
"context"
"fmt"
"time"
"github.com/zeebo/errs"
"storj.io/storj/pkg/bwagreement"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
// bandwidthagreement stores and queries bandwidth agreements (rows of the
// bwagreements table) through a dbx database handle.
type bandwidthagreement struct {
db *dbx.DB
}
func (b *bandwidthagreement) CreateAgreement(ctx context.Context, rba *pb.RenterBandwidthAllocation) (err error) {
expiration := time.Unix(rba.PayerAllocation.ExpirationUnixSec, 0)
_, err = b.db.Create_Bwagreement(
ctx,
dbx.Bwagreement_Serialnum(rba.PayerAllocation.SerialNumber+rba.StorageNodeId.String()),
dbx.Bwagreement_StorageNodeId(rba.StorageNodeId.Bytes()),
dbx.Bwagreement_UplinkId(rba.PayerAllocation.UplinkId.Bytes()),
2019-01-28 21:16:21 +00:00
dbx.Bwagreement_Action(int64(rba.PayerAllocation.Action)),
dbx.Bwagreement_Total(rba.Total),
dbx.Bwagreement_ExpiresAt(expiration),
)
return err
}
// GetUplinkStats returns per-uplink bandwidth statistics for agreements
// created in the half-open interval (from, to].
//
// Each stat holds the summed total bytes, the number of PUT agreements, the
// number of GET agreements and the total number of agreements. Results are
// ordered by uplink ID. On error no partial results are returned. Any error
// from closing the result set is combined into the returned error.
func (b *bandwidthagreement) GetUplinkStats(ctx context.Context, from, to time.Time) (stats []bwagreement.UplinkStat, err error) {
	var uplinkSQL = fmt.Sprintf(`SELECT uplink_id, SUM(total),
COUNT(CASE WHEN action = %d THEN total ELSE null END),
COUNT(CASE WHEN action = %d THEN total ELSE null END), COUNT(*)
FROM bwagreements WHERE created_at > ?
AND created_at <= ? GROUP BY uplink_id ORDER BY uplink_id`,
		pb.BandwidthAction_PUT, pb.BandwidthAction_GET)
	// QueryContext lets the caller cancel a slow aggregation via ctx.
	rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(uplinkSQL), from, to)
	if err != nil {
		return nil, err
	}
	defer func() { err = errs.Combine(err, rows.Close()) }()

	for rows.Next() {
		var uplinkID []byte
		stat := bwagreement.UplinkStat{}
		if err := rows.Scan(&uplinkID, &stat.TotalBytes, &stat.PutActionCount, &stat.GetActionCount, &stat.TotalTransactions); err != nil {
			return nil, err
		}
		id, err := storj.NodeIDFromBytes(uplinkID)
		if err != nil {
			return nil, err
		}
		stat.NodeID = id
		stats = append(stats, stat)
	}
	// rows.Next returns false on both exhaustion and failure; only Err tells
	// them apart, so an iteration error must not be mistaken for end of data.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return stats, nil
}
// GetTotals returns, for each storage node, the summed bandwidth totals per
// action type over agreements created in the half-open interval (from, to]:
// from is excluded, to is included.
//
// Each map value is a slice of length len(pb.BandwidthAction_value) indexed by
// pb.BandwidthAction. The query fills the PUT, GET, GET_AUDIT, GET_REPAIR and
// PUT_REPAIR entries. On error no partial map is returned. Any error from
// closing the result set is combined into the returned error.
func (b *bandwidthagreement) GetTotals(ctx context.Context, from, to time.Time) (bwa map[storj.NodeID][]int64, err error) {
	var getTotalsSQL = fmt.Sprintf(`SELECT storage_node_id,
SUM(CASE WHEN action = %d THEN total ELSE 0 END),
SUM(CASE WHEN action = %d THEN total ELSE 0 END),
SUM(CASE WHEN action = %d THEN total ELSE 0 END),
SUM(CASE WHEN action = %d THEN total ELSE 0 END),
SUM(CASE WHEN action = %d THEN total ELSE 0 END)
FROM bwagreements WHERE created_at > ? AND created_at <= ?
GROUP BY storage_node_id ORDER BY storage_node_id`, pb.BandwidthAction_PUT,
		pb.BandwidthAction_GET, pb.BandwidthAction_GET_AUDIT,
		pb.BandwidthAction_GET_REPAIR, pb.BandwidthAction_PUT_REPAIR)
	// QueryContext lets the caller cancel a slow aggregation via ctx.
	rows, err := b.db.DB.QueryContext(ctx, b.db.Rebind(getTotalsSQL), from, to)
	if err != nil {
		return nil, err
	}
	defer func() { err = errs.Combine(err, rows.Close()) }()

	totals := make(map[storj.NodeID][]int64)
	for rows.Next() {
		var nodeID []byte
		// NOTE(review): indexing by enum value assumes pb.BandwidthAction values
		// are contiguous from 0, so each fits in a slice sized by the number of
		// enum entries — confirm against the .proto definition.
		data := make([]int64, len(pb.BandwidthAction_value))
		if err := rows.Scan(&nodeID, &data[pb.BandwidthAction_PUT], &data[pb.BandwidthAction_GET],
			&data[pb.BandwidthAction_GET_AUDIT], &data[pb.BandwidthAction_GET_REPAIR], &data[pb.BandwidthAction_PUT_REPAIR]); err != nil {
			return nil, err
		}
		id, err := storj.NodeIDFromBytes(nodeID)
		if err != nil {
			return nil, err
		}
		totals[id] = data
	}
	// rows.Next returns false on both exhaustion and failure; only Err tells
	// them apart, so an iteration error must not be mistaken for end of data.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return totals, nil
}
// DeletePaidAndExpired is a placeholder. It deletes nothing and always returns
// an error created with Error.New reporting that it is not implemented.
func (b *bandwidthagreement) DeletePaidAndExpired(ctx context.Context) error {
	const notImplemented = "DeletePaidAndExpired not implemented"
	// TODO: implement deletion of paid and expired BWAs
	return Error.New(notImplemented)
}