fixes to piecestore and psdb (#1380)

* replace direct reference with an interface in various places
* hide piecePath
* ensure psserver tests don't use path
* ensure psserver tests don't use sql queries directly
Egon Elbre 2019-03-01 07:46:16 +02:00 committed by GitHub
parent 2efaa27318
commit 3f3209c8d5
10 changed files with 261 additions and 310 deletions


@ -10,7 +10,6 @@ import (
"github.com/zeebo/errs" "github.com/zeebo/errs"
"go.uber.org/zap" "go.uber.org/zap"
pstore "storj.io/storj/pkg/piecestore"
"storj.io/storj/pkg/piecestore/psserver/psdb" "storj.io/storj/pkg/piecestore/psserver/psdb"
) )
@ -21,13 +20,13 @@ var ErrorCollector = errs.Class("piecestore collector")
type Collector struct { type Collector struct {
log *zap.Logger log *zap.Logger
db *psdb.DB db *psdb.DB
storage *pstore.Storage storage Storage
interval time.Duration interval time.Duration
} }
// NewCollector returns a new piece collector // NewCollector returns a new piece collector
func NewCollector(log *zap.Logger, db *psdb.DB, storage *pstore.Storage, interval time.Duration) *Collector { func NewCollector(log *zap.Logger, db *psdb.DB, storage Storage, interval time.Duration) *Collector {
return &Collector{ return &Collector{
log: log, log: log,
db: db, db: db,
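
Taking a Storage interface instead of the concrete *pstore.Storage means the collector can be exercised without touching the disk. A minimal sketch of the idea; the local interface and the in-memory type below are illustrative only (the real psserver.Storage interface is excerpted in the server.go hunk further down, and its full method set is not shown in this diff):

package collectorsketch

import "fmt"

// storage mirrors the kind of narrow dependency the collector now takes.
type storage interface {
	Size(pieceID string) (int64, error)
	Delete(pieceID string) error
}

// memStorage is a hypothetical in-memory implementation for tests.
type memStorage struct{ pieces map[string][]byte }

func (m *memStorage) Size(id string) (int64, error) {
	p, ok := m.pieces[id]
	if !ok {
		return 0, fmt.Errorf("piece %q not found", id)
	}
	return int64(len(p)), nil
}

func (m *memStorage) Delete(id string) error {
	delete(m.pieces, id)
	return nil
}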


@ -95,7 +95,7 @@ func TestMigrate(t *testing.T) {
defer func() { require.NoError(t, db.Close()) }() defer func() { require.NoError(t, db.Close()) }()
// insert the base data into sqlite // insert the base data into sqlite
_, err = db.DB.Exec(base.Script) _, err = db.RawDB().Exec(base.Script)
require.NoError(t, err) require.NoError(t, err)
// get migration for this database // get migration for this database
@ -118,19 +118,19 @@ func TestMigrate(t *testing.T) {
// insert data for new tables // insert data for new tables
if newdata := newData(expected); newdata != "" && step.Version > base.Version { if newdata := newData(expected); newdata != "" && step.Version > base.Version {
_, err = db.DB.Exec(newdata) _, err = db.RawDB().Exec(newdata)
require.NoError(t, err, tag) require.NoError(t, err, tag)
} }
// load schema from database // load schema from database
currentSchema, err := sqliteutil.QuerySchema(db.DB) currentSchema, err := sqliteutil.QuerySchema(db.RawDB())
require.NoError(t, err, tag) require.NoError(t, err, tag)
// we don't care changes in versions table // we don't care changes in versions table
currentSchema.DropTable("versions") currentSchema.DropTable("versions")
// load data from database // load data from database
currentData, err := sqliteutil.QueryData(db.DB, currentSchema) currentData, err := sqliteutil.QueryData(db.RawDB(), currentSchema)
require.NoError(t, err, tag) require.NoError(t, err, tag)
// verify schema and data // verify schema and data
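
The migration test now reaches the raw handle only through RawDB(), which psdb.go below adds explicitly for this purpose. The pattern in miniature (a sketch of the idea, not the actual psdb code):

package psdbsketch

import "database/sql"

// DB keeps its handle unexported so production code has to go through the
// typed methods; migration tests get one clearly labeled escape hatch.
type DB struct {
	db *sql.DB
}

// RawDB returns the underlying handle, intended only for migration tests.
func (db *DB) RawDB() *sql.DB { return db.db }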


@ -33,7 +33,7 @@ var (
// DB is a piece store database // DB is a piece store database
type DB struct { type DB struct {
mu sync.Mutex mu sync.Mutex
DB *sql.DB // TODO: hide db *sql.DB
} }
// Agreement is a struct that contains a bandwidth agreement and the associated signature // Agreement is a struct that contains a bandwidth agreement and the associated signature
@ -53,7 +53,7 @@ func Open(DBPath string) (db *DB, err error) {
return nil, Error.Wrap(err) return nil, Error.Wrap(err)
} }
db = &DB{ db = &DB{
DB: sqlite, db: sqlite,
} }
return db, nil return db, nil
@ -67,7 +67,7 @@ func OpenInMemory() (db *DB, err error) {
} }
db = &DB{ db = &DB{
DB: sqlite, db: sqlite,
} }
return db, nil return db, nil
@ -185,7 +185,7 @@ func (db *DB) Migration() *migrate.Migration {
// Close the database // Close the database
func (db *DB) Close() error { func (db *DB) Close() error {
return db.DB.Close() return db.db.Close()
} }
func (db *DB) locked() func() { func (db *DB) locked() func() {
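
Most psdb methods in the hunks below start with `defer db.locked()()`. The body of locked() is not part of this diff, but the idiom is presumably the usual one: lock now, return the unlock function, and let the deferred call run it at return. A sketch under that assumption:

package psdbsketch

import "sync"

type lockedDB struct {
	mu sync.Mutex
}

// locked acquires the mutex and hands back the unlock function, so callers
// can guard a whole method body with `defer db.locked()()`.
func (db *lockedDB) locked() func() {
	db.mu.Lock()
	return db.mu.Unlock
}

func (db *lockedDB) doWork() {
	defer db.locked()() // first call locks immediately, deferred call unlocks at return
}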
@ -200,7 +200,7 @@ func (db *DB) DeleteExpired(ctx context.Context) (expired []string, err error) {
// TODO: add limit // TODO: add limit
tx, err := db.DB.BeginTx(ctx, nil) tx, err := db.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -245,7 +245,7 @@ func (db *DB) WriteBandwidthAllocToDB(rba *pb.Order) error {
// If the agreements are sorted we can send them in bulk streams to the satellite // If the agreements are sorted we can send them in bulk streams to the satellite
t := time.Now() t := time.Now()
startofthedayunixsec := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()).Unix() startofthedayunixsec := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()).Unix()
_, err = db.DB.Exec(`INSERT INTO bandwidth_agreements (satellite, agreement, signature, uplink, serial_num, total, max_size, created_utc_sec, expiration_utc_sec, action, daystart_utc_sec) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, _, err = db.db.Exec(`INSERT INTO bandwidth_agreements (satellite, agreement, signature, uplink, serial_num, total, max_size, created_utc_sec, expiration_utc_sec, action, daystart_utc_sec) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
rba.PayerAllocation.SatelliteId.Bytes(), rbaBytes, rba.GetSignature(), rba.PayerAllocation.SatelliteId.Bytes(), rbaBytes, rba.GetSignature(),
rba.PayerAllocation.UplinkId.Bytes(), rba.PayerAllocation.SerialNumber, rba.PayerAllocation.UplinkId.Bytes(), rba.PayerAllocation.SerialNumber,
rba.Total, rba.PayerAllocation.MaxSize, rba.PayerAllocation.CreatedUnixSec, rba.Total, rba.PayerAllocation.MaxSize, rba.PayerAllocation.CreatedUnixSec,
@ -257,7 +257,7 @@ func (db *DB) WriteBandwidthAllocToDB(rba *pb.Order) error {
// DeleteBandwidthAllocationBySerialnum finds an allocation by signature and deletes it // DeleteBandwidthAllocationBySerialnum finds an allocation by signature and deletes it
func (db *DB) DeleteBandwidthAllocationBySerialnum(serialnum string) error { func (db *DB) DeleteBandwidthAllocationBySerialnum(serialnum string) error {
defer db.locked()() defer db.locked()()
_, err := db.DB.Exec(`DELETE FROM bandwidth_agreements WHERE serial_num=?`, serialnum) _, err := db.db.Exec(`DELETE FROM bandwidth_agreements WHERE serial_num=?`, serialnum)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
@ -268,7 +268,7 @@ func (db *DB) DeleteBandwidthAllocationBySerialnum(serialnum string) error {
func (db *DB) GetBandwidthAllocationBySignature(signature []byte) ([]*pb.Order, error) { func (db *DB) GetBandwidthAllocationBySignature(signature []byte) ([]*pb.Order, error) {
defer db.locked()() defer db.locked()()
rows, err := db.DB.Query(`SELECT agreement FROM bandwidth_agreements WHERE signature = ?`, signature) rows, err := db.db.Query(`SELECT agreement FROM bandwidth_agreements WHERE signature = ?`, signature)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -299,7 +299,7 @@ func (db *DB) GetBandwidthAllocationBySignature(signature []byte) ([]*pb.Order,
func (db *DB) GetBandwidthAllocations() (map[storj.NodeID][]*Agreement, error) { func (db *DB) GetBandwidthAllocations() (map[storj.NodeID][]*Agreement, error) {
defer db.locked()() defer db.locked()()
rows, err := db.DB.Query(`SELECT satellite, agreement FROM bandwidth_agreements`) rows, err := db.db.Query(`SELECT satellite, agreement FROM bandwidth_agreements`)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -336,7 +336,7 @@ func (db *DB) AddTTL(id string, expiration, size int64) error {
defer db.locked()() defer db.locked()()
created := time.Now().Unix() created := time.Now().Unix()
_, err := db.DB.Exec("INSERT OR REPLACE INTO ttl (id, created, expires, size) VALUES (?, ?, ?, ?)", id, created, expiration, size) _, err := db.db.Exec("INSERT OR REPLACE INTO ttl (id, created, expires, size) VALUES (?, ?, ?, ?)", id, created, expiration, size)
return err return err
} }
@ -344,34 +344,27 @@ func (db *DB) AddTTL(id string, expiration, size int64) error {
func (db *DB) GetTTLByID(id string) (expiration int64, err error) { func (db *DB) GetTTLByID(id string) (expiration int64, err error) {
defer db.locked()() defer db.locked()()
err = db.DB.QueryRow(`SELECT expires FROM ttl WHERE id=?`, id).Scan(&expiration) err = db.db.QueryRow(`SELECT expires FROM ttl WHERE id=?`, id).Scan(&expiration)
return expiration, err return expiration, err
} }
// SumTTLSizes sums the size column on the ttl table // SumTTLSizes sums the size column on the ttl table
func (db *DB) SumTTLSizes() (sum int64, err error) { func (db *DB) SumTTLSizes() (int64, error) {
defer db.locked()() defer db.locked()()
var count int var sum *int64
rows := db.DB.QueryRow("SELECT COUNT(*) as count FROM ttl") err := db.db.QueryRow(`SELECT SUM(size) FROM ttl;`).Scan(&sum)
err = rows.Scan(&count) if err == sql.ErrNoRows || sum == nil {
if err != nil {
return 0, err
}
if count == 0 {
return 0, nil return 0, nil
} }
return *sum, err
err = db.DB.QueryRow(`SELECT SUM(size) FROM ttl;`).Scan(&sum)
return sum, err
} }
// DeleteTTLByID finds the TTL in the database by id and delete it // DeleteTTLByID finds the TTL in the database by id and delete it
func (db *DB) DeleteTTLByID(id string) error { func (db *DB) DeleteTTLByID(id string) error {
defer db.locked()() defer db.locked()()
_, err := db.DB.Exec(`DELETE FROM ttl WHERE id=?`, id) _, err := db.db.Exec(`DELETE FROM ttl WHERE id=?`, id)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
@ -384,7 +377,7 @@ func (db *DB) GetBandwidthUsedByDay(t time.Time) (size int64, err error) {
} }
// GetTotalBandwidthBetween each row in the bwusagetbl contains the total bw used per day // GetTotalBandwidthBetween each row in the bwusagetbl contains the total bw used per day
func (db *DB) GetTotalBandwidthBetween(startdate time.Time, enddate time.Time) (totalbwusage int64, err error) { func (db *DB) GetTotalBandwidthBetween(startdate time.Time, enddate time.Time) (int64, error) {
defer db.locked()() defer db.locked()()
startTimeUnix := time.Date(startdate.Year(), startdate.Month(), startdate.Day(), 0, 0, 0, 0, startdate.Location()).Unix() startTimeUnix := time.Date(startdate.Year(), startdate.Month(), startdate.Day(), 0, 0, 0, 0, startdate.Location()).Unix()
@ -392,26 +385,22 @@ func (db *DB) GetTotalBandwidthBetween(startdate time.Time, enddate time.Time) (
defaultunixtime := time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.Now().Location()).Unix() defaultunixtime := time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.Now().Location()).Unix()
if (endTimeUnix < startTimeUnix) && (startTimeUnix > defaultunixtime || endTimeUnix > defaultunixtime) { if (endTimeUnix < startTimeUnix) && (startTimeUnix > defaultunixtime || endTimeUnix > defaultunixtime) {
return totalbwusage, errors.New("Invalid date range") return 0, errors.New("Invalid date range")
} }
var count int var totalUsage *int64
rows := db.DB.QueryRow("SELECT COUNT(*) as count FROM bandwidth_agreements") err := db.db.QueryRow(`SELECT SUM(total) FROM bandwidth_agreements WHERE daystart_utc_sec BETWEEN ? AND ?`, startTimeUnix, endTimeUnix).Scan(&totalUsage)
err = rows.Scan(&count) if err == sql.ErrNoRows || totalUsage == nil {
if err != nil {
return 0, err
}
if count == 0 {
return 0, nil return 0, nil
} }
return *totalUsage, err
err = db.DB.QueryRow(`SELECT SUM(total) FROM bandwidth_agreements WHERE daystart_utc_sec BETWEEN ? AND ?`, startTimeUnix, endTimeUnix).Scan(&totalbwusage)
return totalbwusage, err
} }
// RawDB returns access to the raw database, only for migration tests.
func (db *DB) RawDB() *sql.DB { return db.db }
// Begin begins transaction // Begin begins transaction
func (db *DB) Begin() (*sql.Tx, error) { return db.DB.Begin() } func (db *DB) Begin() (*sql.Tx, error) { return db.db.Begin() }
// Rebind rebind parameters // Rebind rebind parameters
func (db *DB) Rebind(s string) string { return s } func (db *DB) Rebind(s string) string { return s }
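
SumTTLSizes and GetTotalBandwidthBetween drop the preliminary COUNT(*) query and instead scan the SQL aggregate into a *int64: SUM over zero rows yields NULL, which leaves the pointer nil and is reported as zero. A self-contained sketch of the pattern (schema reduced to the one column that matters):

package main

import (
	"database/sql"
	"fmt"

	_ "github.com/mattn/go-sqlite3"
)

// sumSizes scans SUM(size) into a nullable pointer so an empty table does
// not need a separate COUNT(*) round trip.
func sumSizes(db *sql.DB) (int64, error) {
	var sum *int64
	err := db.QueryRow(`SELECT SUM(size) FROM ttl`).Scan(&sum)
	if err != nil && err != sql.ErrNoRows {
		return 0, err
	}
	if sum == nil { // SUM over zero rows is NULL
		return 0, nil
	}
	return *sum, nil
}

func main() {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		panic(err)
	}
	defer db.Close()
	if _, err := db.Exec(`CREATE TABLE ttl (id TEXT, size INTEGER)`); err != nil {
		panic(err)
	}
	total, _ := sumSizes(db) // empty table -> 0
	fmt.Println(total)
}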


@ -12,7 +12,7 @@ import (
"time" "time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap/zaptest"
"storj.io/storj/internal/teststorj" "storj.io/storj/internal/teststorj"
"storj.io/storj/pkg/pb" "storj.io/storj/pkg/pb"
@ -31,7 +31,7 @@ func newDB(t testing.TB, id string) (*psdb.DB, func()) {
db, err := psdb.Open(dbpath) db, err := psdb.Open(dbpath)
require.NoError(t, err) require.NoError(t, err)
err = db.Migration().Run(zap.NewNop(), db) err = db.Migration().Run(zaptest.NewLogger(t), db)
require.NoError(t, err) require.NoError(t, err)
return db, func() { return db, func() {
@ -67,6 +67,66 @@ func TestHappyPath(t *testing.T) {
{ID: "test", Expiration: 666}, {ID: "test", Expiration: 666},
} }
bandwidthAllocation := func(signature string, satelliteID storj.NodeID, total int64) *pb.Order {
return &pb.Order{
PayerAllocation: pb.OrderLimit{SatelliteId: satelliteID},
Total: total,
Signature: []byte(signature),
}
}
//TODO: use better data
nodeIDAB := teststorj.NodeIDFromString("AB")
allocationTests := []*pb.Order{
bandwidthAllocation("signed by test", nodeIDAB, 0),
bandwidthAllocation("signed by sigma", nodeIDAB, 10),
bandwidthAllocation("signed by sigma", nodeIDAB, 98),
bandwidthAllocation("signed by test", nodeIDAB, 3),
}
type bwUsage struct {
size int64
timenow time.Time
}
bwtests := []bwUsage{
// size is total size stored
{size: 1110, timenow: time.Now()},
}
t.Run("Empty", func(t *testing.T) {
t.Run("Bandwidth Allocation", func(t *testing.T) {
for _, test := range allocationTests {
agreements, err := db.GetBandwidthAllocationBySignature(test.Signature)
require.Len(t, agreements, 0)
require.NoError(t, err)
}
})
t.Run("Get all Bandwidth Allocations", func(t *testing.T) {
agreementGroups, err := db.GetBandwidthAllocations()
require.Len(t, agreementGroups, 0)
require.NoError(t, err)
})
t.Run("GetBandwidthUsedByDay", func(t *testing.T) {
for _, bw := range bwtests {
size, err := db.GetBandwidthUsedByDay(bw.timenow)
require.NoError(t, err)
require.Equal(t, int64(0), size)
}
})
t.Run("GetTotalBandwidthBetween", func(t *testing.T) {
for _, bw := range bwtests {
size, err := db.GetTotalBandwidthBetween(bw.timenow, bw.timenow)
require.NoError(t, err)
require.Equal(t, int64(0), size)
}
})
})
t.Run("Create", func(t *testing.T) { t.Run("Create", func(t *testing.T) {
for P := 0; P < concurrency; P++ { for P := 0; P < concurrency; P++ {
t.Run("#"+strconv.Itoa(P), func(t *testing.T) { t.Run("#"+strconv.Itoa(P), func(t *testing.T) {
@ -130,35 +190,7 @@ func TestHappyPath(t *testing.T) {
} }
}) })
bandwidthAllocation := func(signature string, satelliteID storj.NodeID, total int64) *pb.Order {
return &pb.Order{
PayerAllocation: pb.OrderLimit{SatelliteId: satelliteID},
Total: total,
Signature: []byte(signature),
}
}
//TODO: use better data
nodeIDAB := teststorj.NodeIDFromString("AB")
allocationTests := []*pb.Order{
bandwidthAllocation("signed by test", nodeIDAB, 0),
bandwidthAllocation("signed by sigma", nodeIDAB, 10),
bandwidthAllocation("signed by sigma", nodeIDAB, 98),
bandwidthAllocation("signed by test", nodeIDAB, 3),
}
type bwUsage struct {
size int64
timenow time.Time
}
bwtests := []bwUsage{
// size is total size stored
{size: 1110, timenow: time.Now()},
}
t.Run("Bandwidth Allocation", func(t *testing.T) { t.Run("Bandwidth Allocation", func(t *testing.T) {
for P := 0; P < concurrency; P++ { for P := 0; P < concurrency; P++ {
t.Run("#"+strconv.Itoa(P), func(t *testing.T) { t.Run("#"+strconv.Itoa(P), func(t *testing.T) {
t.Parallel() t.Parallel()
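
The tests also switch from zap.NewNop() to zaptest.NewLogger(t), so migration output ends up in the test log instead of being discarded. Tiny usage sketch:

package sketch

import (
	"testing"

	"go.uber.org/zap/zaptest"
)

// zaptest.NewLogger routes log output through the test's logger, so it is
// shown for failing tests (or with -v) and attributed to the right test.
func TestWithLogger(t *testing.T) {
	log := zaptest.NewLogger(t)
	log.Info("running migrations") // visible in `go test -v` output
}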


@ -7,7 +7,6 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"os"
"sync/atomic" "sync/atomic"
"github.com/zeebo/errs" "github.com/zeebo/errs"
@ -62,20 +61,13 @@ func (s *Server) Retrieve(stream pb.PieceStoreRoutes_RetrieveServer) (err error)
) )
// Get path to data being retrieved // Get path to data being retrieved
path, err := s.storage.PiecePath(id) fileSize, err := s.storage.Size(id)
if err != nil { if err != nil {
return err return err
} }
// Verify that the path exists
fileInfo, err := os.Stat(path)
if err != nil {
return RetrieveError.Wrap(err)
}
// Read the size specified // Read the size specified
totalToRead := pd.GetPieceSize() totalToRead := pd.GetPieceSize()
fileSize := fileInfo.Size()
// Read the entire file if specified -1 but make sure we do it from the correct offset // Read the entire file if specified -1 but make sure we do it from the correct offset
if pd.GetPieceSize() <= -1 || totalToRead+pd.GetOffset() > fileSize { if pd.GetPieceSize() <= -1 || totalToRead+pd.GetOffset() > fileSize {
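
Retrieve now gets the piece size straight from storage and clamps the requested read against it. The branch body is outside this excerpt, so the following helper is a hypothetical spelling-out of that clamping, not the actual code:

// clampReadSize: -1 means "the whole piece", and a request reaching past the
// end is cut down to the bytes remaining after the offset.
func clampReadSize(requested, offset, fileSize int64) int64 {
	if requested <= -1 || requested+offset > fileSize {
		return fileSize - offset
	}
	return requested
}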


@ -66,8 +66,8 @@ type Storage interface {
// Close closes the underlying database. // Close closes the underlying database.
Close() error Close() error
// PiecePath returns path of the specified piece on disk. // Size returns size of the piece
PiecePath(pieceID string) (string, error) Size(pieceID string) (int64, error)
// Info returns the current status of the disk. // Info returns the current status of the disk.
Info() (pstore.DiskInfo, error) Info() (pstore.DiskInfo, error)
} }
@ -184,12 +184,7 @@ func (s *Server) Piece(ctx context.Context, in *pb.PieceId) (*pb.PieceSummary, e
return nil, err return nil, err
} }
path, err := s.storage.PiecePath(id) size, err := s.storage.Size(id)
if err != nil {
return nil, err
}
fileInfo, err := os.Stat(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -201,7 +196,7 @@ func (s *Server) Piece(ctx context.Context, in *pb.PieceId) (*pb.PieceSummary, e
} }
s.log.Info("Successfully retrieved meta", zap.String("Piece ID", in.GetId())) s.log.Info("Successfully retrieved meta", zap.String("Piece ID", in.GetId()))
return &pb.PieceSummary{Id: in.GetId(), PieceSize: fileInfo.Size(), ExpirationUnixSec: ttl}, nil return &pb.PieceSummary{Id: in.GetId(), PieceSize: size, ExpirationUnixSec: ttl}, nil
} }
// Stats returns current statistics about the server. // Stats returns current statistics about the server.
@ -278,13 +273,10 @@ func (s *Server) Delete(ctx context.Context, in *pb.PieceDelete) (*pb.PieceDelet
} }
func (s *Server) deleteByID(id string) error { func (s *Server) deleteByID(id string) error {
if err := s.storage.Delete(id); err != nil { return errs.Combine(
return err s.DB.DeleteTTLByID(id),
} s.storage.Delete(id),
if err := s.DB.DeleteTTLByID(id); err != nil { )
return err
}
return nil
} }
func (s *Server) verifySignature(ctx context.Context, rba *pb.Order) error { func (s *Server) verifySignature(ctx context.Context, rba *pb.Order) error {
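
deleteByID above now attempts both cleanups unconditionally: the two calls are evaluated as arguments before errs.Combine runs, and Combine (github.com/zeebo/errs) folds any non-nil results into one error, returning nil only when both succeeded. A tiny illustration with plain os.Remove standing in for the two deletes:

import (
	"os"

	"github.com/zeebo/errs"
)

// removeBoth mirrors the deleteByID shape: both os.Remove calls run even if
// the first one fails, and the combined error reports whichever were non-nil.
func removeBoth(a, b string) error {
	return errs.Combine(os.Remove(a), os.Remove(b))
}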


@ -13,17 +13,13 @@ import (
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/gogo/protobuf/proto"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zeebo/errs" "github.com/zeebo/errs"
"go.uber.org/zap"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -44,71 +40,55 @@ func TestPiece(t *testing.T) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
defer ctx.Cleanup() defer ctx.Cleanup()
snID, upID := newTestID(ctx, t), newTestID(ctx, t) snID, upID := newTestIdentity(ctx, t), newTestIdentity(ctx, t)
s, c, cleanup := NewTest(ctx, t, snID, upID, []storj.NodeID{}) server, client, cleanup := NewTest(ctx, t, snID, upID, []storj.NodeID{})
defer cleanup() defer cleanup()
namespacedID, err := getNamespacedPieceID([]byte("11111111111111111111"), snID.ID.Bytes()) namespacedID, err := getNamespacedPieceID([]byte("11111111111111111111"), snID.ID.Bytes())
require.NoError(t, err) require.NoError(t, err)
if err := writeFile(s, namespacedID); err != nil { if err := writeFile(server, namespacedID); err != nil {
t.Errorf("Error: %v\nCould not create test piece", err) t.Errorf("Error: %v\nCould not create test piece", err)
return return
} }
defer func() { _ = s.storage.Delete(namespacedID) }() defer func() { _ = server.storage.Delete(namespacedID) }()
// set up test cases // set up test cases
tests := []struct { tests := []struct {
id string id string
size int64 size int64
expiration int64 expiration int64
err string errContains string
}{ }{
{ // should successfully retrieve piece meta-data { // should successfully retrieve piece meta-data
id: "11111111111111111111", id: "11111111111111111111",
size: 5, size: 5,
expiration: 9999999999, expiration: 9999999999,
err: "",
}, },
{ // server should err with nonexistent file { // server should err with nonexistent file
id: "22222222222222222222", id: "22222222222222222222",
size: 5, size: 5,
expiration: 9999999999, expiration: 9999999999,
err: fmt.Sprintf("rpc error: code = Unknown desc = stat %s: no such file or directory", func() string { errContains: "piecestore error", // TODO: fix for i18n, these message can vary per OS
namespacedID, err := getNamespacedPieceID([]byte("22222222222222222222"), snID.ID.Bytes())
require.NoError(t, err)
path, _ := s.storage.PiecePath(namespacedID)
return path
}()),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
namespacedID, err := getNamespacedPieceID([]byte(tt.id), snID.ID.Bytes()) namespacedID, err := getNamespacedPieceID([]byte(tt.id), snID.ID.Bytes())
require.NoError(t, err) require.NoError(t, err)
// simulate piece TTL entry // simulate piece TTL entry
_, err = s.DB.DB.Exec(fmt.Sprintf(`INSERT INTO ttl (id, created, expires) VALUES ("%s", "%d", "%d")`, namespacedID, 1234567890, tt.expiration)) require.NoError(t, server.DB.AddTTL(namespacedID, tt.expiration, tt.size))
require.NoError(t, err) defer func() { require.NoError(t, server.DB.DeleteTTLByID(namespacedID)) }()
defer func() {
_, err := s.DB.DB.Exec(fmt.Sprintf(`DELETE FROM ttl WHERE id="%s"`, namespacedID))
require.NoError(t, err)
}()
req := &pb.PieceId{Id: tt.id, SatelliteId: snID.ID} req := &pb.PieceId{Id: tt.id, SatelliteId: snID.ID}
resp, err := c.Piece(ctx, req) resp, err := client.Piece(ctx, req)
if tt.err != "" { if tt.errContains != "" {
require.NotNil(t, err) require.NotNil(t, err)
if runtime.GOOS == "windows" && strings.Contains(tt.err, "no such file or directory") { require.Contains(t, err.Error(), tt.errContains)
//TODO (windows): ignoring for windows due to different underlying error
return
}
require.Equal(t, tt.err, err.Error())
return return
} }
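
Asserting on a stable substring rather than the exact os.Stat message is what lets the runtime.GOOS special case above go away. The assertion shape, pulled out of the interleaved hunk for readability:

if tt.errContains != "" {
	require.NotNil(t, err)
	require.Contains(t, err.Error(), tt.errContains)
	return
}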
@ -128,26 +108,26 @@ func TestRetrieve(t *testing.T) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
defer ctx.Cleanup() defer ctx.Cleanup()
snID, upID := newTestID(ctx, t), newTestID(ctx, t) snID, upID := newTestIdentity(ctx, t), newTestIdentity(ctx, t)
s, c, cleanup := NewTest(ctx, t, snID, upID, []storj.NodeID{}) server, client, cleanup := NewTest(ctx, t, snID, upID, []storj.NodeID{})
defer cleanup() defer cleanup()
if err := writeFile(s, "11111111111111111111"); err != nil { if err := writeFile(server, "11111111111111111111"); err != nil {
t.Errorf("Error: %v\nCould not create test piece", err) t.Errorf("Error: %v\nCould not create test piece", err)
return return
} }
defer func() { _ = s.storage.Delete("11111111111111111111") }() defer func() { _ = server.storage.Delete("11111111111111111111") }()
// set up test cases // set up test cases
tests := []struct { tests := []struct {
id string id string
reqSize int64 reqSize int64
respSize int64 respSize int64
allocSize int64 allocSize int64
offset int64 offset int64
content []byte content []byte
err string errContains string
}{ }{
{ // should successfully retrieve data { // should successfully retrieve data
id: "11111111111111111111", id: "11111111111111111111",
@ -156,7 +136,6 @@ func TestRetrieve(t *testing.T) {
allocSize: 5, allocSize: 5,
offset: 0, offset: 0,
content: []byte("xyzwq"), content: []byte("xyzwq"),
err: "",
}, },
{ // should successfully retrieve data in customizeable increments { // should successfully retrieve data in customizeable increments
id: "11111111111111111111", id: "11111111111111111111",
@ -165,7 +144,6 @@ func TestRetrieve(t *testing.T) {
allocSize: 2, allocSize: 2,
offset: 0, offset: 0,
content: []byte("xyzwq"), content: []byte("xyzwq"),
err: "",
}, },
{ // should successfully retrieve data with lower allocations { // should successfully retrieve data with lower allocations
id: "11111111111111111111", id: "11111111111111111111",
@ -174,7 +152,6 @@ func TestRetrieve(t *testing.T) {
allocSize: 3, allocSize: 3,
offset: 0, offset: 0,
content: []byte("xyz"), content: []byte("xyz"),
err: "",
}, },
{ // should successfully retrieve data { // should successfully retrieve data
id: "11111111111111111111", id: "11111111111111111111",
@ -183,28 +160,24 @@ func TestRetrieve(t *testing.T) {
allocSize: 5, allocSize: 5,
offset: 0, offset: 0,
content: []byte("xyzwq"), content: []byte("xyzwq"),
err: "",
}, },
{ // server should err with invalid id { // server should err with invalid id
id: "123", id: "123",
reqSize: 5, reqSize: 5,
respSize: 5, respSize: 5,
allocSize: 5, allocSize: 5,
offset: 0, offset: 0,
content: []byte("xyzwq"), content: []byte("xyzwq"),
err: "rpc error: code = Unknown desc = piecestore error: invalid id length", errContains: "rpc error: code = Unknown desc = piecestore error: invalid id length",
}, },
{ // server should err with nonexistent file { // server should err with nonexistent file
id: "22222222222222222222", id: "22222222222222222222",
reqSize: 5, reqSize: 5,
respSize: 5, respSize: 5,
allocSize: 5, allocSize: 5,
offset: 0, offset: 0,
content: []byte("xyzwq"), content: []byte("xyzwq"),
err: fmt.Sprintf("rpc error: code = Unknown desc = retrieve error: stat %s: no such file or directory", func() string { errContains: "piecestore error",
path, _ := s.storage.PiecePath("22222222222222222222")
return path
}()),
}, },
{ // server should return expected content and respSize with offset and excess reqSize { // server should return expected content and respSize with offset and excess reqSize
id: "11111111111111111111", id: "11111111111111111111",
@ -213,7 +186,6 @@ func TestRetrieve(t *testing.T) {
allocSize: 5, allocSize: 5,
offset: 1, offset: 1,
content: []byte("yzwq"), content: []byte("yzwq"),
err: "",
}, },
{ // server should return expected content with reduced reqSize { // server should return expected content with reduced reqSize
id: "11111111111111111111", id: "11111111111111111111",
@ -222,13 +194,12 @@ func TestRetrieve(t *testing.T) {
allocSize: 5, allocSize: 5,
offset: 0, offset: 0,
content: []byte("xyzw"), content: []byte("xyzw"),
err: "",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
stream, err := c.Retrieve(ctx) stream, err := client.Retrieve(ctx)
require.NoError(t, err) require.NoError(t, err)
// send piece database // send piece database
@ -241,6 +212,7 @@ func TestRetrieve(t *testing.T) {
totalAllocated := int64(0) totalAllocated := int64(0)
var data string var data string
var totalRetrieved = int64(0) var totalRetrieved = int64(0)
var resp *pb.PieceRetrievalStream var resp *pb.PieceRetrievalStream
for totalAllocated < tt.respSize { for totalAllocated < tt.respSize {
// Send bandwidth bandwidthAllocation // Send bandwidth bandwidthAllocation
@ -253,24 +225,18 @@ func TestRetrieve(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
resp, err = stream.Recv() resp, err = stream.Recv()
if tt.err != "" { if tt.errContains != "" {
require.NotNil(t, err) require.NotNil(t, err)
if runtime.GOOS == "windows" && strings.Contains(tt.err, "no such file or directory") { require.Contains(t, err.Error(), tt.errContains)
//TODO (windows): ignoring for windows due to different underlying error
return
}
require.Equal(t, tt.err, err.Error())
return return
} }
require.NotNil(t, resp)
assert.NoError(t, err) assert.NoError(t, err)
data = fmt.Sprintf("%s%s", data, string(resp.GetContent())) data = fmt.Sprintf("%s%s", data, string(resp.GetContent()))
totalRetrieved += resp.GetPieceSize() totalRetrieved += resp.GetPieceSize()
} }
assert.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, tt.respSize, totalRetrieved) assert.Equal(t, tt.respSize, totalRetrieved)
assert.Equal(t, string(tt.content), data) assert.Equal(t, string(tt.content), data)
}) })
@ -281,7 +247,7 @@ func TestStore(t *testing.T) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
defer ctx.Cleanup() defer ctx.Cleanup()
satID := newTestID(ctx, t) satID := newTestIdentity(ctx, t)
tests := []struct { tests := []struct {
id string id string
@ -317,15 +283,14 @@ func TestStore(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
snID, upID := newTestID(ctx, t), newTestID(ctx, t) snID, upID := newTestIdentity(ctx, t), newTestIdentity(ctx, t)
s, c, cleanup := NewTest(ctx, t, snID, upID, tt.whitelist) server, client, cleanup := NewTest(ctx, t, snID, upID, tt.whitelist)
defer cleanup() defer cleanup()
db := s.DB.DB
sum := sha256.Sum256(tt.content) sum := sha256.Sum256(tt.content)
expectedHash := sum[:] expectedHash := sum[:]
stream, err := c.Store(ctx) stream, err := client.Store(ctx)
require.NoError(t, err) require.NoError(t, err)
// Create Bandwidth Allocation Data // Create Bandwidth Allocation Data
@ -358,42 +323,31 @@ func TestStore(t *testing.T) {
return return
} }
defer func() {
_, err := db.Exec(fmt.Sprintf(`DELETE FROM ttl WHERE id="%s"`, tt.id))
require.NoError(t, err)
}()
// check db to make sure agreement and signature were stored correctly
rows, err := db.Query(`SELECT agreement, signature FROM bandwidth_agreements`)
require.NoError(t, err) require.NoError(t, err)
if assert.NotNil(t, resp) {
defer func() { require.NoError(t, rows.Close()) }() assert.Equal(t, tt.message, resp.Message)
for rows.Next() { assert.Equal(t, tt.totalReceived, resp.TotalReceived)
var agreement, signature []byte assert.Equal(t, expectedHash, resp.SignedHash.Hash)
err = rows.Scan(&agreement, &signature) assert.NotNil(t, resp.SignedHash.Signature)
require.NoError(t, err)
rba := &pb.Order{}
require.NoError(t, proto.Unmarshal(agreement, rba))
require.Equal(t, msg.BandwidthAllocation.GetSignature(), signature)
require.True(t, pb.Equal(pba, &rba.PayerAllocation))
require.Equal(t, int64(len(tt.content)), rba.Total)
} }
err = rows.Err()
allocations, err := server.DB.GetBandwidthAllocationBySignature(rba.Signature)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp) require.NotNil(t, allocations)
require.Equal(t, tt.message, resp.Message) for _, allocation := range allocations {
require.Equal(t, tt.totalReceived, resp.TotalReceived) require.Equal(t, msg.BandwidthAllocation.GetSignature(), allocation.Signature)
require.Equal(t, expectedHash, resp.SignedHash.Hash) require.Equal(t, int64(len(tt.content)), rba.Total)
require.NotNil(t, resp.SignedHash.Signature) }
}) })
} }
} }
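
The stored-agreement check in TestStore likewise moves from raw SQL rows onto the psdb API. The new verification block, pulled out of the interleaved hunk above for readability:

allocations, err := server.DB.GetBandwidthAllocationBySignature(rba.Signature)
require.NoError(t, err)
require.NotNil(t, allocations)
for _, allocation := range allocations {
	require.Equal(t, msg.BandwidthAllocation.GetSignature(), allocation.Signature)
	require.Equal(t, int64(len(tt.content)), rba.Total)
}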
func TestPbaValidation(t *testing.T) { func TestPbaValidation(t *testing.T) {
t.Skip("broken")
ctx := testcontext.New(t) ctx := testcontext.New(t)
snID, upID := newTestID(ctx, t), newTestID(ctx, t) snID, upID := newTestIdentity(ctx, t), newTestIdentity(ctx, t)
satID1, satID2, satID3 := newTestID(ctx, t), newTestID(ctx, t), newTestID(ctx, t) satID1, satID2, satID3 := newTestIdentity(ctx, t), newTestIdentity(ctx, t), newTestIdentity(ctx, t)
defer ctx.Cleanup() defer ctx.Cleanup()
tests := []struct { tests := []struct {
@ -435,10 +389,10 @@ func TestPbaValidation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
s, c, cleanup := NewTest(ctx, t, snID, upID, tt.whitelist) server, client, cleanup := NewTest(ctx, t, snID, upID, tt.whitelist)
defer cleanup() defer cleanup()
stream, err := c.Store(ctx) stream, err := client.Store(ctx)
require.NoError(t, err) require.NoError(t, err)
// Create Bandwidth Allocation Data // Create Bandwidth Allocation Data
@ -453,7 +407,7 @@ func TestPbaValidation(t *testing.T) {
} }
//cleanup in case tests previously panicked //cleanup in case tests previously panicked
_ = s.storage.Delete("99999999999999999999") _ = server.storage.Delete("99999999999999999999")
// Write the buffer to the stream we opened earlier // Write the buffer to the stream we opened earlier
err = stream.Send(&pb.PieceStore{ err = stream.Send(&pb.PieceStore{
PieceData: &pb.PieceStore_PieceData{Id: "99999999999999999999", ExpirationUnixSec: 9999999999}, PieceData: &pb.PieceStore_PieceData{Id: "99999999999999999999", ExpirationUnixSec: 9999999999},
@ -468,10 +422,8 @@ func TestPbaValidation(t *testing.T) {
} }
_, err = stream.CloseAndRecv() _, err = stream.CloseAndRecv()
if err != nil { if tt.err != "" {
//require.NotNil(t, err) require.NotNil(t, err)
t.Log("Expected err string", tt.err)
t.Log("Actual err.Error:", err.Error())
require.Equal(t, tt.err, err.Error()) require.Equal(t, tt.err, err.Error())
return return
} }
@ -483,89 +435,63 @@ func TestDelete(t *testing.T) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
defer ctx.Cleanup() defer ctx.Cleanup()
snID, upID := newTestID(ctx, t), newTestID(ctx, t) snID, upID := newTestIdentity(ctx, t), newTestIdentity(ctx, t)
s, c, cleanup := NewTest(ctx, t, snID, upID, []storj.NodeID{}) server, client, cleanup := NewTest(ctx, t, snID, upID, []storj.NodeID{})
defer cleanup() defer cleanup()
db := s.DB.DB pieceID := "11111111111111111111"
namespacedID, err := getNamespacedPieceID([]byte(pieceID), snID.ID.Bytes())
require.NoError(t, err)
// set up test cases // simulate piece stored with storagenode
tests := []struct { if err := writeFile(server, namespacedID); err != nil {
id string t.Errorf("Error: %v\nCould not create test piece", err)
message string return
err string
}{
{ // should successfully delete data
id: "11111111111111111111",
message: "OK",
err: "",
},
{ // should return OK with nonexistent file
id: "22222222222222222223",
message: "OK",
err: "",
},
} }
require.NoError(t, server.DB.AddTTL(namespacedID, 1234567890, 1234567890))
defer func() { require.NoError(t, server.DB.DeleteTTLByID(namespacedID)) }()
for _, tt := range tests { resp, err := client.Delete(ctx, &pb.PieceDelete{
t.Run("", func(t *testing.T) { Id: pieceID,
// simulate piece stored with storagenode SatelliteId: snID.ID,
if err := writeFile(s, "11111111111111111111"); err != nil { })
t.Errorf("Error: %v\nCould not create test piece", err) require.NoError(t, err)
return require.Equal(t, "OK", resp.GetMessage())
}
// simulate piece TTL entry // check if file was indeed deleted
_, err := db.Exec(fmt.Sprintf(`INSERT INTO ttl (id, created, expires) VALUES ("%s", "%d", "%d")`, tt.id, 1234567890, 1234567890)) _, err = server.storage.Size(namespacedID)
require.NoError(t, err) require.Error(t, err)
defer func() { resp, err = client.Delete(ctx, &pb.PieceDelete{
_, err := db.Exec(fmt.Sprintf(`DELETE FROM ttl WHERE id="%s"`, tt.id)) Id: "22222222222222",
require.NoError(t, err) SatelliteId: snID.ID,
}() })
require.NoError(t, err)
defer func() { require.Equal(t, "OK", resp.GetMessage())
require.NoError(t, s.storage.Delete("11111111111111111111"))
}()
req := &pb.PieceDelete{Id: tt.id}
resp, err := c.Delete(ctx, req)
if tt.err != "" {
require.Equal(t, tt.err, err.Error())
return
}
require.NoError(t, err)
require.Equal(t, tt.message, resp.GetMessage())
// if test passes, check if file was indeed deleted
filePath, err := s.storage.PiecePath(tt.id)
require.NoError(t, err)
if _, err = os.Stat(filePath); os.IsExist(err) {
t.Errorf("File not deleted")
return
}
})
}
} }
func NewTest(ctx context.Context, t *testing.T, snID, upID *identity.FullIdentity, func NewTest(ctx context.Context, t *testing.T, snID, upID *identity.FullIdentity, ids []storj.NodeID) (*Server, pb.PieceStoreRoutesClient, func()) {
ids []storj.NodeID) (*Server, pb.PieceStoreRoutesClient, func()) {
//init ps server backend //init ps server backend
tmp, err := ioutil.TempDir("", "storj-piecestore") tmp, err := ioutil.TempDir("", "storj-piecestore")
require.NoError(t, err) require.NoError(t, err)
tempDBPath := filepath.Join(tmp, "test.db") tempDBPath := filepath.Join(tmp, "test.db")
tempDir := filepath.Join(tmp, "test-data", "3000") tempDir := filepath.Join(tmp, "test-data", "3000")
storage := pstore.NewStorage(tempDir) storage := pstore.NewStorage(tempDir)
psDB, err := psdb.Open(tempDBPath) psDB, err := psdb.Open(tempDBPath)
require.NoError(t, err) require.NoError(t, err)
err = psDB.Migration().Run(zap.NewNop(), psDB)
err = psDB.Migration().Run(zaptest.NewLogger(t), psDB)
require.NoError(t, err) require.NoError(t, err)
whitelist := make(map[storj.NodeID]crypto.PublicKey) whitelist := make(map[storj.NodeID]crypto.PublicKey)
for _, id := range ids { for _, id := range ids {
whitelist[id] = nil whitelist[id] = nil
} }
psServer := &Server{ psServer := &Server{
log: zaptest.NewLogger(t), log: zaptest.NewLogger(t),
storage: storage, storage: storage,
@ -575,16 +501,20 @@ func NewTest(ctx context.Context, t *testing.T, snID, upID *identity.FullIdentit
totalBwAllocated: math.MaxInt64, totalBwAllocated: math.MaxInt64,
whitelist: whitelist, whitelist: whitelist,
} }
//init ps server grpc //init ps server grpc
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err) require.NoError(t, err)
publicConfig := server.Config{Address: "127.0.0.1:0"} publicConfig := server.Config{Address: "127.0.0.1:0"}
publicOptions, err := tlsopts.NewOptions(snID, publicConfig.Config) publicOptions, err := tlsopts.NewOptions(snID, publicConfig.Config)
require.NoError(t, err) require.NoError(t, err)
grpcServer, err := server.New(publicOptions, listener, nil) grpcServer, err := server.New(publicOptions, listener, nil)
require.NoError(t, err) require.NoError(t, err)
pb.RegisterPieceStoreRoutesServer(grpcServer.GRPC(), psServer) pb.RegisterPieceStoreRoutesServer(grpcServer.GRPC(), psServer)
go func() { require.NoError(t, grpcServer.Run(ctx)) }() go func() { require.NoError(t, grpcServer.Run(ctx)) }() // TODO: wait properly for server termination
//init client //init client
tlsOptions, err := tlsopts.NewOptions(upID, tlsopts.Config{}) tlsOptions, err := tlsopts.NewOptions(upID, tlsopts.Config{})
@ -603,7 +533,7 @@ func NewTest(ctx context.Context, t *testing.T, snID, upID *identity.FullIdentit
return psServer, psClient, cleanup return psServer, psClient, cleanup
} }
func newTestID(ctx context.Context, t *testing.T) *identity.FullIdentity { func newTestIdentity(ctx context.Context, t *testing.T) *identity.FullIdentity {
id, err := testidentity.NewTestIdentity(ctx) id, err := testidentity.NewTestIdentity(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)


@ -15,6 +15,14 @@ import (
"storj.io/storj/pkg/ranger" "storj.io/storj/pkg/ranger"
) )
// IDLength -- Minimum ID length
const IDLength = 20
// Errors
var (
Error = errs.Class("piecestore error")
)
// Storage stores piecestore pieces // Storage stores piecestore pieces
type Storage struct { type Storage struct {
dir string dir string
@ -38,25 +46,15 @@ func (storage *Storage) Info() (DiskInfo, error) {
rootPath := filepath.Dir(filepath.Clean(storage.dir)) rootPath := filepath.Dir(filepath.Clean(storage.dir))
diskSpace, err := disk.Usage(rootPath) diskSpace, err := disk.Usage(rootPath)
if err != nil { if err != nil {
return DiskInfo{}, err return DiskInfo{}, Error.Wrap(err)
} }
return DiskInfo{ return DiskInfo{
AvailableSpace: int64(diskSpace.Free), AvailableSpace: int64(diskSpace.Free),
}, nil }, nil
} }
// IDLength -- Minimum ID length // piecePath creates piece storage path from id and dir
const IDLength = 20 func (storage *Storage) piecePath(pieceID string) (string, error) {
// Errors
var (
Error = errs.Class("piecestore error")
MkDir = errs.Class("piecestore MkdirAll")
Open = errs.Class("piecestore OpenFile")
)
// PiecePath creates piece storage path from id and dir
func (storage *Storage) PiecePath(pieceID string) (string, error) {
if len(pieceID) < IDLength { if len(pieceID) < IDLength {
return "", Error.New("invalid id length") return "", Error.New("invalid id length")
} }
@ -64,32 +62,49 @@ func (storage *Storage) PiecePath(pieceID string) (string, error) {
return filepath.Join(storage.dir, folder1, folder2, filename), nil return filepath.Join(storage.dir, folder1, folder2, filename), nil
} }
// Size returns piece size.
func (storage *Storage) Size(pieceID string) (int64, error) {
path, err := storage.piecePath(pieceID)
if err != nil {
return 0, err
}
fileInfo, err := os.Stat(path)
if err != nil {
return 0, Error.Wrap(err)
}
return fileInfo.Size(), nil
}
// Writer returns a writer that can be used to store piece. // Writer returns a writer that can be used to store piece.
func (storage *Storage) Writer(pieceID string) (io.WriteCloser, error) { func (storage *Storage) Writer(pieceID string) (io.WriteCloser, error) {
path, err := storage.PiecePath(pieceID) path, err := storage.piecePath(pieceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = os.MkdirAll(filepath.Dir(path), 0700); err != nil { if err = os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return nil, MkDir.Wrap(err) return nil, Error.Wrap(err)
} }
file, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) file, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600)
if err != nil { if err != nil {
return nil, Open.Wrap(err) return nil, Error.Wrap(err)
} }
return file, nil return file, nil
} }
// Reader returns a reader for the specified piece at the location // Reader returns a reader for the specified piece at the location
func (storage *Storage) Reader(ctx context.Context, pieceID string, offset int64, length int64) (io.ReadCloser, error) { func (storage *Storage) Reader(ctx context.Context, pieceID string, offset int64, length int64) (io.ReadCloser, error) {
path, err := storage.PiecePath(pieceID) path, err := storage.piecePath(pieceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
info, err := os.Stat(path) info, err := os.Stat(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if offset >= info.Size() || offset < 0 { if offset >= info.Size() || offset < 0 {
return nil, Error.New("invalid offset: %v", offset) return nil, Error.New("invalid offset: %v", offset)
} }
@ -100,23 +115,25 @@ func (storage *Storage) Reader(ctx context.Context, pieceID string, offset int64
if info.Size() < offset+length { if info.Size() < offset+length {
length = info.Size() - offset length = info.Size() - offset
} }
rr, err := ranger.FileRanger(path) rr, err := ranger.FileRanger(path)
if err != nil { if err != nil {
return nil, err return nil, Error.Wrap(err)
} }
return rr.Range(ctx, offset, length)
r, err := rr.Range(ctx, offset, length)
return r, Error.Wrap(err)
} }
// Delete deletes piece from storage // Delete deletes piece from storage
func (storage *Storage) Delete(pieceID string) error { func (storage *Storage) Delete(pieceID string) error {
path, err := storage.PiecePath(pieceID) path, err := storage.piecePath(pieceID)
if err != nil { if err != nil {
return err return Error.Wrap(err)
} }
err = os.Remove(path) err = os.Remove(path)
if os.IsNotExist(err) { if os.IsNotExist(err) {
err = nil err = nil
} }
return err return Error.Wrap(err)
} }
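
The separate MkDir and Open error classes collapse into the single piecestore Error class, and every failure path now wraps with it. A small sketch of how an errs.Class wraps and is matched, assuming the usual github.com/zeebo/errs behavior (openPiece is a hypothetical helper, not code from this commit):

import (
	"os"

	"github.com/zeebo/errs"
)

// Error is the single class everything in the package now wraps with.
var Error = errs.Class("piecestore error")

// openPiece shows the wrapping: the message becomes "piecestore error: <original>",
// and callers can test the class with Error.Has(err) without caring whether
// MkdirAll or OpenFile was the call that failed.
func openPiece(path string) (*os.File, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, Error.Wrap(err)
	}
	return f, nil
}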


@ -16,7 +16,6 @@ import (
"storj.io/storj/pkg/kademlia" "storj.io/storj/pkg/kademlia"
"storj.io/storj/pkg/pb" "storj.io/storj/pkg/pb"
"storj.io/storj/pkg/peertls/tlsopts" "storj.io/storj/pkg/peertls/tlsopts"
pstore "storj.io/storj/pkg/piecestore"
"storj.io/storj/pkg/piecestore/psserver" "storj.io/storj/pkg/piecestore/psserver"
"storj.io/storj/pkg/piecestore/psserver/agreementsender" "storj.io/storj/pkg/piecestore/psserver/agreementsender"
"storj.io/storj/pkg/piecestore/psserver/psdb" "storj.io/storj/pkg/piecestore/psserver/psdb"
@ -33,8 +32,8 @@ type DB interface {
// Close closes the database // Close closes the database
Close() error Close() error
Storage() psserver.Storage
// TODO: use better interfaces // TODO: use better interfaces
Storage() *pstore.Storage
PSDB() *psdb.DB PSDB() *psdb.DB
RoutingTable() (kdb, ndb storage.KeyValueStore) RoutingTable() (kdb, ndb storage.KeyValueStore)
} }


@ -9,6 +9,7 @@ import (
"storj.io/storj/pkg/kademlia" "storj.io/storj/pkg/kademlia"
pstore "storj.io/storj/pkg/piecestore" pstore "storj.io/storj/pkg/piecestore"
"storj.io/storj/pkg/piecestore/psserver"
"storj.io/storj/pkg/piecestore/psserver/psdb" "storj.io/storj/pkg/piecestore/psserver/psdb"
"storj.io/storj/storage" "storj.io/storj/storage"
"storj.io/storj/storage/boltdb" "storj.io/storj/storage/boltdb"
@ -29,7 +30,7 @@ type Config struct {
// DB contains access to different database tables // DB contains access to different database tables
type DB struct { type DB struct {
log *zap.Logger log *zap.Logger
storage *pstore.Storage storage psserver.Storage
psdb *psdb.DB psdb *psdb.DB
kdb, ndb storage.KeyValueStore kdb, ndb storage.KeyValueStore
} }
@ -93,7 +94,7 @@ func (db *DB) Close() error {
} }
// Storage returns piecestore location // Storage returns piecestore location
func (db *DB) Storage() *pstore.Storage { func (db *DB) Storage() psserver.Storage {
return db.storage return db.storage
} }
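
Storage() now returns the psserver.Storage interface rather than the concrete type, so the storagenode DB no longer leaks pstore into its public surface. A hypothetical compile-time guard (not part of this commit) that keeps the concrete type honest against the interface:

// placed in any file that imports both packages:
var _ psserver.Storage = (*pstore.Storage)(nil)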