storj/pkg/piecestore/psserver/psdb/psdb.go
Egon Elbre 3f3209c8d5
fixes to piecestore and psdb (#1380)
* replace direct reference with an interface in various places
* hide piecePath
* ensure psserver tests don't use path
* ensure psserver tests don't use sql queries directly
2019-03-01 07:46:16 +02:00

410 lines
11 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package psdb
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/gogo/protobuf/proto"
_ "github.com/mattn/go-sqlite3" // register sqlite to sql
"github.com/zeebo/errs"
"go.uber.org/zap"
monkit "gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/internal/migrate"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
)
var (
mon = monkit.Package()
// Error is the default psdb errs class
Error = errs.Class("psdb")
)
// DB is the piece store database, a wrapper around a SQLite handle.
// Most methods serialize their database access through mu via locked().
type DB struct {
	mu sync.Mutex // guards db for methods that call locked()
	db *sql.DB    // underlying SQLite handle
}
// Agreement pairs a decoded bandwidth agreement with its signature.
//
// NOTE(review): GetBandwidthAllocations only selects the agreement column,
// so Signature is left nil on the values it returns.
type Agreement struct {
	Agreement pb.Order // decoded bandwidth agreement
	Signature []byte   // signature associated with the agreement
}
// Open opens the SQLite database at DBPath in WAL journal mode, first
// creating DBPath's parent directory (mode 0700) if it does not exist.
//
// All returned errors are wrapped in the psdb Error class. Note that
// sql.Open only validates its arguments; connection problems surface on
// first use of the returned DB.
func Open(DBPath string) (db *DB, err error) {
	if err = os.MkdirAll(filepath.Dir(DBPath), 0700); err != nil {
		// Wrapped for consistency with the sql.Open failure below.
		return nil, Error.Wrap(err)
	}

	sqlite, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?%s", DBPath, "_journal=WAL"))
	if err != nil {
		return nil, Error.Wrap(err)
	}

	return &DB{db: sqlite}, nil
}
// OpenInMemory opens a new in-memory SQLite database.
//
// Errors are wrapped in the psdb Error class, matching Open.
//
// NOTE(review): with the ":memory:" DSN every connection opened by the
// database/sql pool gets its own, separate empty database. This works while
// access stays serialized on a single reused connection; confirm that no
// caller (e.g. via Begin or RawDB) uses it concurrently.
func OpenInMemory() (db *DB, err error) {
	sqlite, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		return nil, Error.Wrap(err)
	}

	return &DB{db: sqlite}, nil
}
// Migration returns the schema migration for the piecestore database; the
// applied version is tracked in the "versions" table.
//
// Version 0 creates the ttl, bandwidth_agreements and bwusagetbl tables and
// an index on ttl.expires. Version 1 adds per-field columns to
// bandwidth_agreements, backfills them by decoding each stored agreement,
// and drops bwusagetbl.
func (db *DB) Migration() *migrate.Migration {
	migration := &migrate.Migration{
		Table: "versions",
		Steps: []*migrate.Step{
			{
				Description: "Initial setup",
				Version:     0,
				Action: migrate.SQL{
					`CREATE TABLE IF NOT EXISTS ttl (
id BLOB UNIQUE,
created INT(10),
expires INT(10),
size INT(10)
)`,
					`CREATE TABLE IF NOT EXISTS bandwidth_agreements (
satellite BLOB,
agreement BLOB,
signature BLOB
)`,
					`CREATE INDEX IF NOT EXISTS idx_ttl_expires ON ttl (
expires
)`,
					`CREATE TABLE IF NOT EXISTS bwusagetbl (
size INT(10),
daystartdate INT(10),
dayenddate INT(10)
)`,
				},
			},
			{
				Description: "Extending bandwidth_agreements table and drop bwusagetbl",
				Version:     1,
				Action: migrate.Func(func(log *zap.Logger, db migrate.DB, tx *sql.Tx) error {
					v1sql := migrate.SQL{
						`ALTER TABLE bandwidth_agreements ADD COLUMN uplink BLOB`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN serial_num BLOB`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN total INT(10)`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN max_size INT(10)`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN created_utc_sec INT(10)`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN expiration_utc_sec INT(10)`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN action INT(10)`,
						`ALTER TABLE bandwidth_agreements ADD COLUMN daystart_utc_sec INT(10)`,
					}
					err := v1sql.Run(log, db, tx)
					if err != nil {
						return err
					}

					// Backfill the new columns from every stored agreement.
					// The closure uses a named result so that the error from the
					// deferred rows.Close() actually reaches the caller; with an
					// unnamed result the deferred assignment was discarded.
					err = func() (err error) {
						rows, err := tx.Query(`SELECT agreement, signature FROM bandwidth_agreements`)
						if err != nil {
							return err
						}
						defer func() { err = errs.Combine(err, rows.Close()) }()

						for rows.Next() {
							var rbaBytes, signature []byte
							rba := &pb.RenterBandwidthAllocation{}
							if err = rows.Scan(&rbaBytes, &signature); err != nil {
								return err
							}

							// decode the stored agreement
							if err = proto.Unmarshal(rbaBytes, rba); err != nil {
								return err
							}

							// NOTE(review): t is in the local zone, but the day start is
							// built in UTC from t's local calendar date, whereas
							// WriteBandwidthAllocToDB uses the local zone. Left as-is
							// since it only affects already-migrated rows.
							t := time.Unix(rba.PayerAllocation.CreatedUnixSec, 0)
							startofthedayUnixSec := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC).Unix()

							// Rows are matched by signature, which is assumed to be
							// unique (the column has no UNIQUE constraint).
							_, err = tx.Exec(`UPDATE bandwidth_agreements SET
uplink = ?,
serial_num = ?,
total = ?,
max_size = ?,
created_utc_sec = ?,
expiration_utc_sec = ?,
action = ?,
daystart_utc_sec = ?
WHERE signature = ?
`,
								rba.PayerAllocation.UplinkId.Bytes(), rba.PayerAllocation.SerialNumber,
								rba.Total, rba.PayerAllocation.MaxSize, rba.PayerAllocation.CreatedUnixSec,
								rba.PayerAllocation.ExpirationUnixSec, rba.PayerAllocation.GetAction(),
								startofthedayUnixSec, signature)
							if err != nil {
								return err
							}
						}
						return rows.Err()
					}()
					if err != nil {
						return err
					}

					_, err = tx.Exec(`DROP TABLE bwusagetbl;`)
					if err != nil {
						return err
					}
					return nil
				}),
			},
		},
	}
	return migration
}
// Close closes the underlying SQLite handle and returns its error.
func (db *DB) Close() error {
	handle := db.db
	return handle.Close()
}
// locked acquires db.mu and returns the function that releases it, which
// allows the one-liner `defer db.locked()()`.
func (db *DB) locked() func() {
	mu := &db.mu
	mu.Lock()
	return mu.Unlock
}
// DeleteExpired removes every ttl row that has an expiration set
// (expires > 0) earlier than the current Unix time, and returns the ids of
// the removed rows.
//
// The SELECT and DELETE run in one transaction with the same cutoff, so the
// returned ids are exactly the deleted rows. On any error, including a failed
// commit, no ids are returned.
func (db *DB) DeleteExpired(ctx context.Context) (expired []string, err error) {
	defer mon.Task()(&ctx)(&err)
	defer db.locked()()

	// TODO: add limit
	tx, err := db.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// After a successful Commit, Rollback is a no-op returning sql.ErrTxDone,
	// so its error is intentionally ignored.
	defer func() { _ = tx.Rollback() }()

	now := time.Now().Unix()

	rows, err := tx.Query("SELECT id FROM ttl WHERE expires > 0 AND expires < ?", now)
	if err != nil {
		return nil, err
	}
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			_ = rows.Close() // the scan error is the one worth reporting
			return nil, err
		}
		expired = append(expired, id)
	}
	// Without this check an iteration error would truncate the id list while
	// the DELETE below still removes every expired row, orphaning pieces.
	if err := rows.Err(); err != nil {
		_ = rows.Close()
		return nil, err
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}

	_, err = tx.Exec(`DELETE FROM ttl WHERE expires > 0 AND expires < ?`, now)
	if err != nil {
		return nil, err
	}

	if err = tx.Commit(); err != nil {
		return nil, err
	}
	return expired, nil
}
// WriteBandwidthAllocToDB serializes rba and inserts it into the
// bandwidth_agreements table together with its individually extracted fields.
//
// The satellite id gets its own column so agreements can be grouped per
// satellite and sent to it in bulk. daystart_utc_sec is set to midnight of
// the current day, in the location of time.Now().
//
// NOTE(review): the action column is declared INT(10), and the v1 migration
// backfills it with the numeric enum, but this insert stores the enum's
// String() name. Readers elsewhere may depend on either form; confirm before
// unifying.
func (db *DB) WriteBandwidthAllocToDB(rba *pb.Order) error {
	encoded, err := proto.Marshal(rba)
	if err != nil {
		return err
	}

	unlock := db.locked()
	defer unlock()

	now := time.Now()
	year, month, day := now.Date()
	dayStart := time.Date(year, month, day, 0, 0, 0, 0, now.Location()).Unix()

	const insertAgreement = `INSERT INTO bandwidth_agreements (satellite, agreement, signature, uplink, serial_num, total, max_size, created_utc_sec, expiration_utc_sec, action, daystart_utc_sec) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
	if _, err := db.db.Exec(insertAgreement,
		rba.PayerAllocation.SatelliteId.Bytes(),
		encoded,
		rba.GetSignature(),
		rba.PayerAllocation.UplinkId.Bytes(),
		rba.PayerAllocation.SerialNumber,
		rba.Total,
		rba.PayerAllocation.MaxSize,
		rba.PayerAllocation.CreatedUnixSec,
		rba.PayerAllocation.ExpirationUnixSec,
		rba.PayerAllocation.GetAction().String(),
		dayStart,
	); err != nil {
		return err
	}
	return nil
}
// DeleteBandwidthAllocationBySerialnum deletes every bandwidth agreement whose
// serial_num column equals serialnum. Deleting a serial number that is not
// present is not an error.
func (db *DB) DeleteBandwidthAllocationBySerialnum(serialnum string) error {
	defer db.locked()()
	// Exec reports "nothing matched" as zero rows affected, never as
	// sql.ErrNoRows, so no special case is needed.
	_, err := db.db.Exec(`DELETE FROM bandwidth_agreements WHERE serial_num=?`, serialnum)
	return err
}
// GetBandwidthAllocationBySignature returns every stored agreement whose
// signature column equals signature, decoded from the agreement column.
//
// A non-nil empty slice is returned when nothing matches. On a scan, decode
// or iteration error, the agreements decoded so far are returned alongside
// the error.
func (db *DB) GetBandwidthAllocationBySignature(signature []byte) ([]*pb.Order, error) {
	defer db.locked()()

	rows, err := db.db.Query(`SELECT agreement FROM bandwidth_agreements WHERE signature = ?`, signature)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			zap.S().Errorf("failed to close rows when selecting from bandwidth_agreements: %+v", closeErr)
		}
	}()

	agreements := []*pb.Order{}
	for rows.Next() {
		var rbaBytes []byte
		if err := rows.Scan(&rbaBytes); err != nil {
			return agreements, err
		}
		rba := &pb.Order{}
		if err := proto.Unmarshal(rbaBytes, rba); err != nil {
			return agreements, err
		}
		agreements = append(agreements, rba)
	}
	// An error that ends iteration early must not be reported as a
	// successful, truncated result.
	return agreements, rows.Err()
}
// GetBandwidthAllocations returns all stored bandwidth agreements, grouped by
// the satellite id in their satellite column.
//
// On a scan, decode or iteration error, the agreements grouped so far are
// returned alongside the error. An invalid satellite id returns a nil map.
//
// NOTE(review): only the satellite and agreement columns are read, so the
// Signature field of each returned Agreement is left nil.
func (db *DB) GetBandwidthAllocations() (map[storj.NodeID][]*Agreement, error) {
	defer db.locked()()

	rows, err := db.db.Query(`SELECT satellite, agreement FROM bandwidth_agreements`)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			zap.S().Errorf("failed to close rows when selecting from bandwidth_agreements: %+v", closeErr)
		}
	}()

	agreements := make(map[storj.NodeID][]*Agreement)
	for rows.Next() {
		rbaBytes := []byte{}
		agreement := &Agreement{}
		var satellite []byte
		if err := rows.Scan(&satellite, &rbaBytes); err != nil {
			return agreements, err
		}
		if err := proto.Unmarshal(rbaBytes, &agreement.Agreement); err != nil {
			return agreements, err
		}
		satelliteID, err := storj.NodeIDFromBytes(satellite)
		if err != nil {
			return nil, err
		}
		agreements[satelliteID] = append(agreements[satelliteID], agreement)
	}
	// An error that ends iteration early must not be reported as a
	// successful, incomplete result.
	return agreements, rows.Err()
}
// AddTTL stores the expiration and size for id in the ttl table, using the
// current Unix time as its creation time. Because id is UNIQUE and the
// statement is INSERT OR REPLACE, an existing row for id is overwritten.
func (db *DB) AddTTL(id string, expiration, size int64) error {
	unlock := db.locked()
	defer unlock()

	const upsert = "INSERT OR REPLACE INTO ttl (id, created, expires, size) VALUES (?, ?, ?, ?)"
	if _, err := db.db.Exec(upsert, id, time.Now().Unix(), expiration, size); err != nil {
		return err
	}
	return nil
}
// GetTTLByID returns the expires value stored for id. When there is no row
// for id, the error is the sql.ErrNoRows reported by QueryRow's Scan.
func (db *DB) GetTTLByID(id string) (expiration int64, err error) {
	unlock := db.locked()
	defer unlock()

	row := db.db.QueryRow(`SELECT expires FROM ttl WHERE id=?`, id)
	if err = row.Scan(&expiration); err != nil {
		return expiration, err
	}
	return expiration, nil
}
// SumTTLSizes returns the sum of the size column over all ttl rows, or 0 when
// the table is empty. Query errors are returned to the caller.
func (db *DB) SumTTLSizes() (int64, error) {
	defer db.locked()()

	// SUM over no rows yields a single NULL, not zero rows. A NullInt64 lets
	// an empty table be told apart from a failed query; previously any error
	// left the pointer nil and was swallowed as (0, nil).
	var sum sql.NullInt64
	err := db.db.QueryRow(`SELECT SUM(size) FROM ttl;`).Scan(&sum)
	if err == sql.ErrNoRows {
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	return sum.Int64, nil // Int64 is 0 when the sum is NULL
}
// DeleteTTLByID deletes the ttl row for id. Deleting an id that is not
// present is not an error.
func (db *DB) DeleteTTLByID(id string) error {
	defer db.locked()()
	// Exec reports "nothing matched" as zero rows affected, never as
	// sql.ErrNoRows, so no special case is needed.
	_, err := db.db.Exec(`DELETE FROM ttl WHERE id=?`, id)
	return err
}
// GetBandwidthUsedByDay returns the bandwidth total for the single calendar
// day containing t, by delegating to GetTotalBandwidthBetween(t, t).
func (db *DB) GetBandwidthUsedByDay(t time.Time) (size int64, err error) {
	day := t
	size, err = db.GetTotalBandwidthBetween(day, day)
	return size, err
}
// GetTotalBandwidthBetween returns the sum of the total column over all
// bandwidth agreements whose daystart_utc_sec falls on a calendar day from
// startdate's day through enddate's day, inclusive. Each bound is computed in
// its argument's location. It returns 0 when no agreement matches, and query
// errors are returned to the caller.
func (db *DB) GetTotalBandwidthBetween(startdate time.Time, enddate time.Time) (int64, error) {
	defer db.locked()()

	startTimeUnix := time.Date(startdate.Year(), startdate.Month(), startdate.Day(), 0, 0, 0, 0, startdate.Location()).Unix()
	// Hour 24 normalizes to midnight at the start of the day after enddate,
	// which is the exclusive upper bound of the range.
	endTimeUnix := time.Date(enddate.Year(), enddate.Month(), enddate.Day(), 24, 0, 0, 0, enddate.Location()).Unix()

	// Read the clock once so all date components refer to the same instant
	// (repeated time.Now() calls could straddle midnight).
	now := time.Now()
	defaultunixtime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()).Unix()
	// NOTE(review): this rejects a reversed range only when a bound lies after
	// today; the intent of the second condition is unclear, so it is unchanged.
	if (endTimeUnix < startTimeUnix) && (startTimeUnix > defaultunixtime || endTimeUnix > defaultunixtime) {
		return 0, errors.New("Invalid date range")
	}

	// Use a half-open interval. BETWEEN is inclusive on both ends and would also
	// count agreements whose day starts exactly at endTimeUnix, i.e. the day
	// after enddate.
	var totalUsage sql.NullInt64
	err := db.db.QueryRow(`SELECT SUM(total) FROM bandwidth_agreements WHERE daystart_utc_sec >= ? AND daystart_utc_sec < ?`, startTimeUnix, endTimeUnix).Scan(&totalUsage)
	if err == sql.ErrNoRows {
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	return totalUsage.Int64, nil // Int64 is 0 when SUM is NULL (no rows)
}
// RawDB exposes the underlying *sql.DB. It is intended only for migration
// tests and bypasses db.mu.
func (db *DB) RawDB() *sql.DB {
	return db.db
}
// Begin starts a transaction on the underlying database. Unlike most DB
// methods, it does not acquire db.mu.
func (db *DB) Begin() (*sql.Tx, error) {
	tx, err := db.db.Begin()
	return tx, err
}
// Rebind returns the query s unchanged; the queries in this package already
// use `?` placeholders.
func (db *DB) Rebind(s string) string {
	rebound := s
	return rebound
}
// Schema returns the database schema text. This implementation always
// returns the empty string.
func (db *DB) Schema() string {
	const schema = ""
	return schema
}