pkg/piecestore: use readersource (#75)

* pkg/piecestore: use readersource

* pkg/piecestore: fix linting
This commit is contained in:
JT Olio 2018-06-05 08:00:48 -06:00 committed by GitHub
parent 3a9ec8b680
commit 6be2baf9f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 171 additions and 54 deletions

View File

@ -58,7 +58,7 @@ func (client *Client) RetrievePieceRequest(id string, offset int64, length int64
return nil, err
}
return &StreamReader{stream: stream}, nil
return NewStreamReader(stream), nil
}
// DeletePieceRequest -- Delete Piece From Server

View File

@ -7,6 +7,7 @@ import (
"fmt"
"log"
"storj.io/storj/pkg/utils"
pb "storj.io/storj/protos/piecestore"
)
@ -36,37 +37,29 @@ func (s *StreamWriter) Close() error {
return nil
}
// StreamReader -- Struct for reading piece download stream from server
// StreamReader is a struct for reading piece download stream from server
type StreamReader struct {
stream pb.PieceStoreRoutes_RetrieveClient
overflowData []byte
stream pb.PieceStoreRoutes_RetrieveClient
src *utils.ReaderSource
}
// NewStreamReader creates a StreamReader
func NewStreamReader(stream pb.PieceStoreRoutes_RetrieveClient) *StreamReader {
return &StreamReader{
stream: stream,
src: utils.NewReaderSource(func() ([]byte, error) {
msg, err := stream.Recv()
if err != nil {
return nil, err
}
return msg.Content, nil
}),
}
}
// Read -- Read method for piece download stream
func (s *StreamReader) Read(b []byte) (int, error) {
// Use overflow data if we have it
if len(s.overflowData) > 0 {
n := copy(b, s.overflowData) // Copy from overflow into buffer
s.overflowData = s.overflowData[n:] // Overflow is set to whatever remains
return n, nil
}
// Receive data from server stream
msg, err := s.stream.Recv()
if err != nil {
return 0, err
}
// Copy data into buffer
n := copy(b, msg.Content)
// If left over data save it into overflow variable for next read
if n < len(msg.Content) {
s.overflowData = b[len(b):]
}
return n, nil
return s.src.Read(b)
}
// Close -- Close Read Stream

View File

@ -4,6 +4,7 @@
package server
import (
"storj.io/storj/pkg/utils"
pb "storj.io/storj/protos/piecestore"
)
@ -22,35 +23,25 @@ func (s *StreamWriter) Write(b []byte) (int, error) {
return len(b), nil
}
// StreamReader -- Struct for Retrieving data from server
// StreamReader is a struct for Retrieving data from server
type StreamReader struct {
stream pb.PieceStoreRoutes_StoreServer
overflowData []byte
src *utils.ReaderSource
}
// NewStreamReader returns a new StreamReader
func NewStreamReader(stream pb.PieceStoreRoutes_StoreServer) *StreamReader {
return &StreamReader{
src: utils.NewReaderSource(func() ([]byte, error) {
msg, err := stream.Recv()
if err != nil {
return nil, err
}
return msg.Content, nil
}),
}
}
// Read -- Read method for piece download from stream
func (s *StreamReader) Read(b []byte) (int, error) {
// Use overflow data if we have it
if len(s.overflowData) > 0 {
n := copy(b, s.overflowData) // Copy from overflow into buffer
s.overflowData = s.overflowData[n:] // Overflow is set to whatever remains
return n, nil
}
// Receive data from server stream
msg, err := s.stream.Recv()
if err != nil {
return 0, err
}
// Copy data into buffer
n := copy(b, msg.Content)
// If left over data save it into overflow variable for next read
if n < len(msg.Content) {
s.overflowData = b[len(b):]
}
return n, nil
return s.src.Read(b)
}

View File

@ -47,7 +47,7 @@ func (s *Server) Store(stream pb.PieceStoreRoutes_StoreServer) error {
}
defer storeFile.Close()
reader := &StreamReader{stream: stream}
reader := NewStreamReader(stream)
total, err := io.Copy(storeFile, reader)
if err != nil {
return err

30
pkg/utils/io.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright (C) 2018 Storj Labs, Inc.
// See LICENSE for copying information.
package utils
// ReaderSource takes a src func and turns it into an io.Reader
type ReaderSource struct {
src func() ([]byte, error)
buf []byte
err error
}
// NewReaderSource makes a new ReaderSource
func NewReaderSource(src func() ([]byte, error)) *ReaderSource {
return &ReaderSource{src: src}
}
// Read implements io.Reader
func (rs *ReaderSource) Read(p []byte) (n int, err error) {
if rs.err != nil {
return 0, rs.err
}
if len(rs.buf) == 0 {
rs.buf, rs.err = rs.src()
}
n = copy(p, rs.buf)
rs.buf = rs.buf[n:]
return n, rs.err
}

103
pkg/utils/io_test.go Normal file
View File

@ -0,0 +1,103 @@
// Copyright (C) 2018 Storj Labs, Inc.
// See LICENSE for copying information.
package utils
import (
"io"
"testing"
)
type testBytes [][]byte
func (t *testBytes) Next() (rv []byte, err error) {
if len(*t) > 0 {
rv, *t = (*t)[0], (*t)[1:]
return rv, nil
}
return nil, io.EOF
}
func TestReaderSource(t *testing.T) {
tb := testBytes([][]byte{
[]byte("hello there"),
[]byte("cool"),
[]byte("beans"),
})
rs := NewReaderSource(tb.Next)
buf := make([]byte, 1)
n, err := rs.Read(buf)
if n != 1 || err != nil || string(buf) != "h" {
t.Fatalf("unexpected result: %d, %v", n, err)
}
buf = make([]byte, 10)
n, err = rs.Read(buf)
if n != 10 || err != nil || string(buf) != "ello there" {
t.Fatalf("unexpected result: %d, %v", n, err)
}
buf = make([]byte, 5)
n, err = rs.Read(buf)
if n != 4 || err != nil || string(buf[:4]) != "cool" {
t.Fatalf("unexpected result: %d, %v", n, err)
}
n, err = rs.Read(buf)
if n != 5 || err != nil || string(buf[:5]) != "beans" {
t.Fatalf("unexpected result: %d, %v", n, err)
}
n, err = rs.Read(buf)
if n != 0 || err != io.EOF {
t.Fatalf("unexpected result: %d, %v", n, err)
}
}
type testBytesFastEOF [][]byte
func (t *testBytesFastEOF) Next() (rv []byte, err error) {
if len(*t) > 0 {
rv, *t = (*t)[0], (*t)[1:]
if len(*t) == 0 {
return rv, io.EOF
}
return rv, nil
}
return nil, io.EOF
}
func TestReaderSourceFastEOF(t *testing.T) {
tb := testBytesFastEOF([][]byte{
[]byte("hello there"),
[]byte("cool"),
[]byte("beans"),
})
rs := NewReaderSource(tb.Next)
buf := make([]byte, 1)
n, err := rs.Read(buf)
if n != 1 || err != nil || string(buf) != "h" {
t.Fatalf("unexpected result: %d, %v", n, err)
}
buf = make([]byte, 10)
n, err = rs.Read(buf)
if n != 10 || err != nil || string(buf) != "ello there" {
t.Fatalf("unexpected result: %d, %v", n, err)
}
buf = make([]byte, 5)
n, err = rs.Read(buf)
if n != 4 || err != nil || string(buf[:4]) != "cool" {
t.Fatalf("unexpected result: %d, %v", n, err)
}
n, err = rs.Read(buf)
if n != 5 || err != io.EOF || string(buf[:5]) != "beans" {
t.Fatalf("unexpected result: %d, %v", n, err)
}
}