storj/pkg/piecestore/psserver/psdb/psdb.go
2019-04-10 11:26:12 -04:00

466 lines
13 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package psdb
import (
"database/sql"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"time"
"github.com/gogo/protobuf/proto"
_ "github.com/mattn/go-sqlite3" // register sqlite to sql
"github.com/zeebo/errs"
"go.uber.org/zap"
"storj.io/storj/internal/migrate"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
)
// Error is the errs class for errors produced by the psdb package.
var Error = errs.Class("psdb")
// AgreementStatus keeps track of the payout status of a bandwidth agreement.
// It is stored in the status column of the bandwidth_agreements table.
type AgreementStatus int32

// The constants are explicitly typed so that they carry the AgreementStatus
// type (rather than defaulting to int) when passed as interface values, e.g.
// to query arguments or equality helpers.
const (
	// AgreementStatusUnsent is the status of an agreement that has not been sent yet.
	AgreementStatusUnsent AgreementStatus = iota
	// AgreementStatusSent is the status of an agreement that has been sent.
	AgreementStatusSent
	// AgreementStatusReject is the status of an agreement that was rejected.
	AgreementStatusReject
	// add new status here ...
)
// DB is a piece store database backed by a sqlite *sql.DB handle.
type DB struct {
	// mu is acquired (through locked) by most query methods so that they run
	// one at a time.
	mu sync.Mutex

	// db is the handle returned by sql.Open.
	db *sql.DB
	// dbPath is the file path given to Open; it is empty for in-memory databases.
	dbPath string
}
// Agreement pairs a decoded bandwidth agreement (order) with a signature.
type Agreement struct {
	// Agreement is the decoded order.
	Agreement pb.Order
	// Signature holds the signature bytes associated with the agreement.
	Signature []byte
}
// Open opens the piece store database stored at DBPath.
//
// The parent directory of DBPath is created (mode 0700) if it does not exist.
// The sqlite DSN requests WAL journaling via the _journal parameter. Both
// failure paths return errors wrapped in the package's Error class.
//
// Note that sql.Open does not itself connect; problems with the file itself
// surface on first use of the returned DB.
func Open(DBPath string) (db *DB, err error) {
	if err = os.MkdirAll(filepath.Dir(DBPath), 0700); err != nil {
		// Wrapped for consistency with the sql.Open failure below.
		return nil, Error.Wrap(err)
	}

	sqlite, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?%s", DBPath, "_journal=WAL"))
	if err != nil {
		return nil, Error.Wrap(err)
	}

	return &DB{
		db:     sqlite,
		dbPath: DBPath,
	}, nil
}
// OpenInMemory opens a sqlite database held only in memory. The returned DB
// has no file path (dbPath is left empty).
//
// NOTE(review): SQLite gives every connection to ":memory:" its own private
// database, and *sql.DB may open more than one connection; consider
// SetMaxOpenConns(1) if state ever appears to vanish — confirm.
func OpenInMemory() (db *DB, err error) {
	sqlite, openErr := sql.Open("sqlite3", ":memory:")
	if openErr != nil {
		return nil, openErr
	}
	db = &DB{db: sqlite}
	return db, nil
}
// Migration returns the schema migrations for the piece store database, using
// the "versions" table for version bookkeeping.
//
// Steps, in order:
//   - 0: create the ttl, bandwidth_agreements and bwusagetbl tables and an
//     index on ttl.expires.
//   - 1: add decoded columns to bandwidth_agreements, backfill them from the
//     serialized agreement of every existing row, and drop bwusagetbl.
//   - 2: add bandwidth_agreements.status, defaulting to 0.
//   - 3: index bandwidth_agreements by serial_num.
//   - 4: make every ttl entry created at or before Unix time 1553727600 expire
//     at that time (network reset).
//   - 5: run DeleteObsolete on the database path. The step is skipped when the
//     path is empty and never fails; errors are only logged.
func (db *DB) Migration() *migrate.Migration {
	migration := &migrate.Migration{
		Table: "versions",
		Steps: []*migrate.Step{
			{
				Description: "Initial setup",
				Version:     0,
				Action: migrate.SQL{
					`CREATE TABLE IF NOT EXISTS ttl (
id BLOB UNIQUE,
created INT(10),
expires INT(10),
size INT(10)
)`,
					`CREATE TABLE IF NOT EXISTS bandwidth_agreements (
satellite BLOB,
agreement BLOB,
signature BLOB
)`,
					`CREATE INDEX IF NOT EXISTS idx_ttl_expires ON ttl (
expires
)`,
					`CREATE TABLE IF NOT EXISTS bwusagetbl (
size INT(10),
daystartdate INT(10),
dayenddate INT(10)
)`,
				},
			},
			{
				Description: "Extending bandwidth_agreements table and drop bwusagetbl",
				Version:     1,
				Action: migrate.Func(func(log *zap.Logger, db migrate.DB, tx *sql.Tx) error {
					v1sql := migrate.SQL{
						`ALTER TABLE bandwidth_agreements ADD COLUMN uplink BLOB`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN serial_num BLOB`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN total INT(10)`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN max_size INT(10)`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN created_utc_sec INT(10)`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN expiration_utc_sec INT(10)`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN action INT(10)`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN daystart_utc_sec INT(10)`,
					}
					err := v1sql.Run(log, db, tx)
					if err != nil {
						return err
					}

					// Backfill the new columns by decoding the serialized
					// agreement stored in every existing row.
					err = func() (err error) {
						rows, err := tx.Query(`SELECT agreement, signature FROM bandwidth_agreements`)
						if err != nil {
							return err
						}
						// err is the named result, so a failure to close rows is
						// reported to the caller instead of being dropped.
						defer func() { err = errs.Combine(err, rows.Close()) }()

						for rows.Next() {
							var rbaBytes, signature []byte
							if err = rows.Scan(&rbaBytes, &signature); err != nil {
								return err
							}

							rba := &pb.RenterBandwidthAllocation{}
							if err = proto.Unmarshal(rbaBytes, rba); err != nil {
								return err
							}

							// NOTE(review): the calendar date is taken from the
							// local-time value of t, but midnight is built in UTC.
							t := time.Unix(rba.PayerAllocation.CreatedUnixSec, 0)
							startofthedayUnixSec := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC).Unix()

							// Rows are matched by signature, which is assumed unique.
							_, err = tx.Exec(`UPDATE bandwidth_agreements SET
uplink = ?,
serial_num = ?,
total = ?,
max_size = ?,
created_utc_sec = ?,
expiration_utc_sec = ?,
action = ?,
daystart_utc_sec = ?
WHERE signature = ?
`,
								rba.PayerAllocation.UplinkId.Bytes(), rba.PayerAllocation.SerialNumber,
								rba.Total, rba.PayerAllocation.MaxSize, rba.PayerAllocation.CreatedUnixSec,
								rba.PayerAllocation.ExpirationUnixSec, rba.PayerAllocation.GetAction(),
								startofthedayUnixSec, signature)
							if err != nil {
								return err
							}
						}
						return rows.Err()
					}()
					if err != nil {
						return err
					}

					_, err = tx.Exec(`DROP TABLE bwusagetbl;`)
					if err != nil {
						return err
					}
					return nil
				}),
			},
			{
				Description: "Add status column for bandwidth_agreements",
				Version:     2,
				Action: migrate.SQL{
					`ALTER TABLE bandwidth_agreements ADD COLUMN status INT(10) DEFAULT 0`,
				},
			},
			{
				Description: "Add index on serial number for bandwidth_agreements",
				Version:     3,
				Action: migrate.SQL{
					`CREATE INDEX IF NOT EXISTS idx_bwa_serial ON bandwidth_agreements (serial_num)`,
				},
			},
			{
				Description: "Initiate Network reset",
				Version:     4,
				Action: migrate.SQL{
					`UPDATE ttl SET expires = 1553727600 WHERE created <= 1553727600 `,
				},
			},
			{
				Description: "delete obsolete pieces",
				Version:     5,
				Action: migrate.Func(func(log *zap.Logger, mdb migrate.DB, tx *sql.Tx) error {
					// In-memory databases have no path and nothing to clean up.
					path := db.dbPath
					if path == "" {
						log.Warn("Empty path")
						return nil
					}
					// Best effort: a cleanup failure must not block the migration.
					err := db.DeleteObsolete(path)
					if err != nil {
						log.Warn("err deleting obsolete paths: ", zap.Error(err))
					}
					return nil
				}),
			},
		},
	}
	return migration
}
// Close closes the underlying *sql.DB handle and returns its error.
func (db *DB) Close() error {
	err := db.db.Close()
	return err
}
// locked acquires db.mu and hands back the function that releases it, which
// allows the one-line idiom `defer db.locked()()`.
func (db *DB) locked() func() {
	db.mu.Lock()
	unlock := db.mu.Unlock
	return unlock
}
// DeleteObsolete recursively removes every entry (file or directory) whose name
// is exactly two bytes long from the directory that contains path.
//
// A failure to list that directory is returned immediately; removal failures
// are collected and returned together after all candidates have been tried.
func (db *DB) DeleteObsolete(path string) (err error) {
	dir := filepath.Dir(path)
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		return err
	}

	var group errs.Group
	for _, entry := range entries {
		name := entry.Name()
		if len(name) != 2 {
			continue
		}
		group.Add(os.RemoveAll(filepath.Join(dir, name)))
	}
	return group.Err()
}
// WriteBandwidthAllocToDB stores a bandwidth agreement in the
// bandwidth_agreements table with status AgreementStatusUnsent.
//
// Besides the marshaled order and its signature, the satellite, uplink, serial
// number, total, max size, creation/expiration times, action and the start of
// the current day are written to dedicated columns so they can be filtered
// and aggregated without decoding the order. For example, the satellite column
// allows agreements to be grouped per satellite.
func (db *DB) WriteBandwidthAllocToDB(rba *pb.Order) error {
	rbaBytes, err := proto.Marshal(rba)
	if err != nil {
		return err
	}

	defer db.locked()()

	// NOTE(review): despite the column name, this is midnight of the current
	// day in the local time zone (time.Now().Location()), not UTC. Day-range
	// queries build their bounds in the caller's location, so changing this
	// alone would move rows between days.
	t := time.Now()
	startofthedayunixsec := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()).Unix()

	// action is stored as the numeric enum value. This matches the INT column
	// type and the values the version 1 migration backfill writes (previously
	// the enum's String() form was stored here).
	_, err = db.db.Exec(`INSERT INTO bandwidth_agreements (satellite, agreement, signature, uplink, serial_num, total, max_size, created_utc_sec, status, expiration_utc_sec, action, daystart_utc_sec) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		rba.PayerAllocation.SatelliteId.Bytes(), rbaBytes, rba.GetSignature(),
		rba.PayerAllocation.UplinkId.Bytes(), rba.PayerAllocation.SerialNumber,
		rba.Total, rba.PayerAllocation.MaxSize, rba.PayerAllocation.CreatedUnixSec, AgreementStatusUnsent,
		rba.PayerAllocation.ExpirationUnixSec, rba.PayerAllocation.GetAction(),
		startofthedayunixsec)
	return err
}
// DeleteBandwidthAllocationPayouts deletes every bandwidth agreement whose
// created_utc_sec is more than 90 days before now. Deleting nothing is not an
// error.
func (db *DB) DeleteBandwidthAllocationPayouts() error {
	defer db.locked()()

	// TODO: make the retention period a config value.
	cutoff := time.Now().Add(-90 * 24 * time.Hour).Unix()
	_, err := db.db.Exec(`DELETE FROM bandwidth_agreements WHERE created_utc_sec < ?`, cutoff)
	if err != nil && err != sql.ErrNoRows {
		return err
	}
	return nil
}
// UpdateBandwidthAllocationStatus sets the status column of every bandwidth
// agreement whose serial_num equals serialnum.
func (db *DB) UpdateBandwidthAllocationStatus(serialnum string, status AgreementStatus) (err error) {
	unlock := db.locked()
	defer unlock()

	if _, err = db.db.Exec(`UPDATE bandwidth_agreements SET status = ? WHERE serial_num = ?`, status, serialnum); err != nil {
		return err
	}
	return nil
}
// DeleteBandwidthAllocationBySerialnum deletes every bandwidth agreement whose
// serial_num equals serialnum. Deleting nothing is not an error.
func (db *DB) DeleteBandwidthAllocationBySerialnum(serialnum string) error {
	defer db.locked()()

	_, err := db.db.Exec(`DELETE FROM bandwidth_agreements WHERE serial_num=?`, serialnum)
	switch err {
	case nil, sql.ErrNoRows:
		return nil
	default:
		return err
	}
}
// GetBandwidthAllocationBySignature returns every stored order whose signature
// column equals signature. It returns an empty (non-nil) slice when none match.
//
// If scanning, decoding, or iterating fails, the orders decoded so far are
// returned together with the error.
func (db *DB) GetBandwidthAllocationBySignature(signature []byte) ([]*pb.Order, error) {
	defer db.locked()()

	rows, err := db.db.Query(`SELECT agreement FROM bandwidth_agreements WHERE signature = ?`, signature)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			zap.S().Errorf("failed to close rows when selecting from bandwidth_agreements: %+v", closeErr)
		}
	}()

	agreements := []*pb.Order{}
	for rows.Next() {
		var rbaBytes []byte
		if err := rows.Scan(&rbaBytes); err != nil {
			return agreements, err
		}
		rba := &pb.Order{}
		if err := proto.Unmarshal(rbaBytes, rba); err != nil {
			return agreements, err
		}
		agreements = append(agreements, rba)
	}
	// rows.Next returns false both at the end of the result set and on an
	// iteration error; rows.Err tells the two apart.
	return agreements, rows.Err()
}
// GetBandwidthAllocations returns all bandwidth agreements whose status is
// AgreementStatusUnsent, grouped by satellite ID. Only the Agreement field of
// each returned *Agreement is populated.
//
// On a scan, decode, or iteration error, the agreements collected so far are
// returned along with the error. An invalid satellite ID returns a nil map.
func (db *DB) GetBandwidthAllocations() (map[storj.NodeID][]*Agreement, error) {
	defer db.locked()()

	rows, err := db.db.Query(`SELECT satellite, agreement FROM bandwidth_agreements WHERE status = ?`, AgreementStatusUnsent)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			zap.S().Errorf("failed to close rows when selecting from bandwidth_agreements: %+v", closeErr)
		}
	}()

	agreements := make(map[storj.NodeID][]*Agreement)
	for rows.Next() {
		var satellite, rbaBytes []byte
		if err := rows.Scan(&satellite, &rbaBytes); err != nil {
			return agreements, err
		}

		agreement := &Agreement{}
		if err := proto.Unmarshal(rbaBytes, &agreement.Agreement); err != nil {
			return agreements, err
		}

		satelliteID, err := storj.NodeIDFromBytes(satellite)
		if err != nil {
			return nil, err
		}
		agreements[satelliteID] = append(agreements[satelliteID], agreement)
	}
	// Surface errors that terminated the iteration early.
	return agreements, rows.Err()
}
// GetBwaStatusBySerialNum returns the status of a bandwidth agreement whose
// serial_num equals serialnum. When no row matches, the error is
// sql.ErrNoRows.
func (db *DB) GetBwaStatusBySerialNum(serialnum string) (status AgreementStatus, err error) {
	defer db.locked()()

	row := db.db.QueryRow(`SELECT status FROM bandwidth_agreements WHERE serial_num=?`, serialnum)
	if err = row.Scan(&status); err != nil {
		return status, err
	}
	return status, nil
}
// AddTTL adds TTL into database by id
func (db *DB) AddTTL(id string, expiration, size int64) error {
defer db.locked()()
created := time.Now().Unix()
_, err := db.db.Exec("INSERT OR REPLACE INTO ttl (id, created, expires, size) VALUES (?, ?, ?, ?)", id, created, expiration, size)
return err
}
// GetTTLByID returns the expiration stored for id in the ttl table. When there
// is no such row, the error is sql.ErrNoRows.
func (db *DB) GetTTLByID(id string) (expiration int64, err error) {
	defer db.locked()()

	row := db.db.QueryRow(`SELECT expires FROM ttl WHERE id=?`, id)
	err = row.Scan(&expiration)
	return expiration, err
}
// SumTTLSizes returns the sum of the size column over all ttl rows, or 0 when
// the table is empty. Query errors are returned to the caller; previously any
// error other than sql.ErrNoRows was silently reported as a sum of 0.
func (db *DB) SumTTLSizes() (int64, error) {
	defer db.locked()()

	// SUM over zero rows yields NULL, hence the nullable destination.
	var sum sql.NullInt64
	err := db.db.QueryRow(`SELECT SUM(size) FROM ttl;`).Scan(&sum)
	if err == sql.ErrNoRows {
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	// sum.Int64 is 0 when the result was NULL.
	return sum.Int64, nil
}
// DeleteTTLByID removes the ttl row for id. Deleting nothing is not an error.
func (db *DB) DeleteTTLByID(id string) error {
	defer db.locked()()

	if _, err := db.db.Exec(`DELETE FROM ttl WHERE id=?`, id); err != nil && err != sql.ErrNoRows {
		return err
	}
	return nil
}
// GetBandwidthUsedByDay returns the bandwidth total recorded for the day that
// contains t, by querying GetTotalBandwidthBetween with t as both bounds.
func (db *DB) GetBandwidthUsedByDay(t time.Time) (size int64, err error) {
	size, err = db.GetTotalBandwidthBetween(t, t)
	return size, err
}
// GetTotalBandwidthBetween returns the sum of the total column over all
// bandwidth agreements whose daystart_utc_sec falls within the days of
// startdate through enddate, inclusive. Day boundaries are computed in each
// argument's own location. It returns 0 when no agreement matches.
//
// An "Invalid date range" error is returned only when enddate's day precedes
// startdate's day and at least one boundary lies after the start of today.
func (db *DB) GetTotalBandwidthBetween(startdate time.Time, enddate time.Time) (int64, error) {
	defer db.locked()()

	startTimeUnix := time.Date(startdate.Year(), startdate.Month(), startdate.Day(), 0, 0, 0, 0, startdate.Location()).Unix()
	// Hour 24 normalizes to midnight of the following day: the exclusive upper bound.
	endTimeUnix := time.Date(enddate.Year(), enddate.Month(), enddate.Day(), 24, 0, 0, 0, enddate.Location()).Unix()

	// Read the clock once so the date parts cannot straddle midnight.
	now := time.Now()
	defaultunixtime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()).Unix()
	if (endTimeUnix < startTimeUnix) && (startTimeUnix > defaultunixtime || endTimeUnix > defaultunixtime) {
		return 0, errors.New("Invalid date range")
	}

	// Half-open interval: BETWEEN is inclusive and would also count rows whose
	// day starts exactly at endTimeUnix, i.e. the day after enddate.
	var totalUsage sql.NullInt64
	err := db.db.QueryRow(`SELECT SUM(total) FROM bandwidth_agreements WHERE daystart_utc_sec >= ? AND daystart_utc_sec < ?`, startTimeUnix, endTimeUnix).Scan(&totalUsage)
	if err == sql.ErrNoRows {
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	// SUM over no rows is NULL, leaving totalUsage.Int64 at 0.
	return totalUsage.Int64, nil
}
// RawDB exposes the underlying *sql.DB. It is intended only for migration tests.
func (db *DB) RawDB() *sql.DB {
	return db.db
}

// Begin starts a transaction on the underlying *sql.DB.
func (db *DB) Begin() (*sql.Tx, error) {
	return db.db.Begin()
}

// Rebind returns s unchanged: the queries in this package already use the "?"
// placeholders sqlite expects, so no rewriting is needed.
func (db *DB) Rebind(s string) string {
	return s
}

// Schema always returns the empty string.
func (db *DB) Schema() string {
	return ""
}