storj/pkg/piecestore/psserver/readerwriter.go
Maximillian von Briesen a6c7306350
Cut off piecestore Puts if they exceed alloced bandwidth/space (#819)
* add bandwidth/storage limits to StreamWriter

* add StreamWriter tests for bandwidth/storage limits
2018-12-12 14:14:51 -05:00

108 lines
2.7 KiB
Go

// Copyright (C) 2018 Storj Labs, Inc.
// See LICENSE for copying information.
package psserver
import (
"github.com/gogo/protobuf/proto"
"github.com/zeebo/errs"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/utils"
)
var (
	// StreamWriterError is a type of error for failures in StreamWriter.
	//
	// In this file it is also the class of the "out of bandwidth" and
	// "out of space" errors produced by StreamReader.Read.
	StreamWriterError = errs.Class("stream writer error")
)
// StreamWriter -- Struct for writing piece to server upload stream
//
// StreamWriter adapts the server side of a Retrieve stream to io.Writer:
// each call to Write is sent to the client as one PieceRetrievalStream
// message.
type StreamWriter struct {
	// server is the Server this writer was created for (not referenced by Write).
	server *Server
	// stream is the outgoing Retrieve stream that Write sends messages on.
	stream pb.PieceStoreRoutes_RetrieveServer
}
// NewStreamWriter returns a new StreamWriter
//
// The returned writer records s and sends every Write on stream.
func NewStreamWriter(s *Server, stream pb.PieceStoreRoutes_RetrieveServer) *StreamWriter {
	w := new(StreamWriter)
	w.server, w.stream = s, stream
	return w
}
// Write -- Write method for piece upload to stream for Server.Retrieve
//
// Write sends b as a single PieceRetrievalStream message whose PieceSize is
// len(b) and whose Content is b itself (not a copy). On success it reports
// len(b) bytes written; if Send fails it reports 0 bytes and Send's error.
func (s *StreamWriter) Write(b []byte) (int, error) {
	size := len(b)

	err := s.stream.Send(&pb.PieceRetrievalStream{Content: b, PieceSize: int64(size)})
	if err != nil {
		return 0, err
	}

	return size, nil
}
// StreamReader is a struct for Retrieving data from server
//
// It exposes the incoming Store stream as an io.Reader (see Read) and
// remembers the bandwidth allocation with the largest Total that arrived
// on that stream.
type StreamReader struct {
	// src yields the uploaded bytes; NewStreamReader feeds it from the stream.
	src *utils.ReaderSource

	// bandwidthAllocation is the received allocation with the largest Total
	// seen so far; currentTotal is that Total.
	bandwidthAllocation *pb.RenterBandwidthAllocation
	currentTotal        int64

	// Byte budgets enforced by Read, and the number of bytes Read has
	// delivered to its callers so far.
	bandwidthRemaining, spaceRemaining, sofar int64
}
// NewStreamReader returns a new StreamReader for Server.Store
//
// The reader pulls messages from stream on demand. A message may carry piece
// data, a bandwidth allocation, or both. When an allocation is present it is
// passed to s.verifySignature and its Data payload is unmarshalled. If that
// payload's Total exceeds the largest Total seen so far, the allocation is
// kept on the reader. Any Recv, verification, or unmarshal error is returned
// to the reader unchanged. bandwidthRemaining and spaceRemaining are stored
// as the byte budgets for Read.
func NewStreamReader(s *Server, stream pb.PieceStoreRoutes_StoreServer, bandwidthRemaining, spaceRemaining int64) *StreamReader {
	reader := &StreamReader{
		bandwidthRemaining: bandwidthRemaining,
		spaceRemaining:     spaceRemaining,
	}

	next := func() ([]byte, error) {
		msg, err := stream.Recv()
		if err != nil {
			return nil, err
		}

		alloc := msg.GetBandwidthAllocation()
		if alloc == nil {
			// Pure data message: nothing to verify or record.
			return msg.GetPieceData().GetContent(), nil
		}

		if err := s.verifySignature(stream.Context(), alloc); err != nil {
			return nil, err
		}

		payload := &pb.RenterBandwidthAllocation_Data{}
		if err := proto.Unmarshal(alloc.GetData(), payload); err != nil {
			return nil, err
		}

		// Remember only the allocation covering the largest total.
		if total := payload.GetTotal(); total > reader.currentTotal {
			reader.bandwidthAllocation = alloc
			reader.currentTotal = total
		}

		return msg.GetPieceData().GetContent(), nil
	}

	reader.src = utils.NewReaderSource(next)
	return reader
}
// Read -- Read method for piece download from stream
//
// Read implements io.Reader over the upload source while enforcing the
// reader's bandwidth and space budgets. In total it never returns more than
// min(bandwidthRemaining, spaceRemaining) bytes, because each read is capped
// to the budget that is left. An upload whose size exactly equals the budget
// therefore completes normally and ends with whatever the source returns
// (typically io.EOF).
//
// Once the budget is used up, Read makes one more read from the source to
// see whether data is still arriving. If it is, that surplus is discarded
// (0 bytes are returned) together with a StreamWriterError:
// "out of bandwidth" if the bandwidth budget is exhausted, otherwise
// "out of space".
func (s *StreamReader) Read(b []byte) (int, error) {
	budget := s.bandwidthRemaining
	if s.spaceRemaining < budget {
		budget = s.spaceRemaining
	}
	left := budget - s.sofar

	if left > 0 {
		// Never hand out bytes beyond the budget: callers such as io.Copy
		// process returned bytes even when an error accompanies them.
		if int64(len(b)) > left {
			b = b[:left]
		}
		n, err := s.src.Read(b)
		s.sofar += int64(n)
		return n, err
	}

	// Budget exhausted: probe the source. Reaching the end of the stream
	// (or an empty chunk, or a source error) is acceptable. Any further
	// data means the upload is larger than what was allocated.
	n, err := s.src.Read(b)
	if n == 0 {
		return 0, err
	}
	if s.sofar >= s.bandwidthRemaining {
		return 0, StreamWriterError.New("out of bandwidth")
	}
	return 0, StreamWriterError.New("out of space")
}