Verifier should use payer bandwidth alloc from satellite (#960)
* Verifier should use payer bandwidth alloc from satellite * unit test added * fix typo * review comments applied * fix renamed field
This commit is contained in:
parent
6372190873
commit
bacc1b13b4
@ -22,6 +22,7 @@ import (
|
||||
type Stripe struct {
|
||||
Index int
|
||||
Segment *pb.Pointer
|
||||
PBA *pb.PayerBandwidthAllocation
|
||||
Authorization *pb.SignedMessage
|
||||
}
|
||||
|
||||
@ -74,7 +75,7 @@ func (cursor *Cursor) NextStripe(ctx context.Context) (stripe *Stripe, err error
|
||||
}
|
||||
|
||||
// get pointer info
|
||||
pointer, _, _, err := cursor.pointers.Get(ctx, path)
|
||||
pointer, _, pba, err := cursor.pointers.Get(ctx, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -99,8 +100,7 @@ func (cursor *Cursor) NextStripe(ctx context.Context) (stripe *Stripe, err error
|
||||
}
|
||||
|
||||
authorization := cursor.pointers.SignedMessage()
|
||||
|
||||
return &Stripe{Index: index, Segment: pointer, Authorization: authorization}, nil
|
||||
return &Stripe{Index: index, Segment: pointer, PBA: pba, Authorization: authorization}, nil
|
||||
}
|
||||
|
||||
func makeErasureScheme(rs *pb.RedundancyScheme) (eestream.ErasureScheme, error) {
|
||||
|
@ -101,8 +101,7 @@ func (service *Service) process(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
authorization := service.Cursor.pointers.SignedMessage()
|
||||
verifiedNodes, err := service.Verifier.verify(ctx, stripe.Index, stripe.Segment, authorization)
|
||||
verifiedNodes, err := service.Verifier.verify(ctx, stripe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -7,9 +7,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/vivint/infectious"
|
||||
monkit "gopkg.in/spacemonkeygo/monkit.v2"
|
||||
|
||||
@ -36,7 +34,8 @@ type Verifier struct {
|
||||
}
|
||||
|
||||
type downloader interface {
|
||||
DownloadShares(ctx context.Context, pointer *pb.Pointer, stripeIndex int, authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error)
|
||||
DownloadShares(ctx context.Context, pointer *pb.Pointer, stripeIndex int, pba *pb.PayerBandwidthAllocation,
|
||||
authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error)
|
||||
}
|
||||
|
||||
// defaultDownloader downloads shares from networked storage nodes
|
||||
@ -59,7 +58,7 @@ func NewVerifier(transport transport.Client, overlay overlay.Client, id provider
|
||||
|
||||
// getShare use piece store clients to download shares from a given node
|
||||
func (d *defaultDownloader) getShare(ctx context.Context, stripeIndex, shareSize, pieceNumber int,
|
||||
id psclient.PieceID, pieceSize int64, fromNode *pb.Node, authorization *pb.SignedMessage) (s share, err error) {
|
||||
id psclient.PieceID, pieceSize int64, fromNode *pb.Node, pba *pb.PayerBandwidthAllocation, authorization *pb.SignedMessage) (s share, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
fromNode.Type.DPanicOnInvalid("audit getShare")
|
||||
ps, err := psclient.NewPSClient(ctx, d.transport, fromNode, 0)
|
||||
@ -72,20 +71,6 @@ func (d *defaultDownloader) getShare(ctx context.Context, stripeIndex, shareSize
|
||||
return s, err
|
||||
}
|
||||
|
||||
allocationData := &pb.PayerBandwidthAllocation_Data{
|
||||
Action: pb.PayerBandwidthAllocation_GET,
|
||||
CreatedUnixSec: time.Now().Unix(),
|
||||
}
|
||||
|
||||
serializedAllocation, err := proto.Marshal(allocationData)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
pba := &pb.PayerBandwidthAllocation{
|
||||
Data: serializedAllocation,
|
||||
}
|
||||
|
||||
rr, err := ps.Get(ctx, derivedPieceID, pieceSize, pba, authorization)
|
||||
if err != nil {
|
||||
return s, err
|
||||
@ -115,7 +100,7 @@ func (d *defaultDownloader) getShare(ctx context.Context, stripeIndex, shareSize
|
||||
|
||||
// Download Shares downloads shares from the nodes where remote pieces are located
|
||||
func (d *defaultDownloader) DownloadShares(ctx context.Context, pointer *pb.Pointer,
|
||||
stripeIndex int, authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error) {
|
||||
stripeIndex int, pba *pb.PayerBandwidthAllocation, authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
var nodeIds storj.NodeIDList
|
||||
@ -142,7 +127,7 @@ func (d *defaultDownloader) DownloadShares(ctx context.Context, pointer *pb.Poin
|
||||
paddedSize := calcPadded(pointer.GetSegmentSize(), shareSize)
|
||||
pieceSize := paddedSize / int64(pointer.Remote.Redundancy.GetMinReq())
|
||||
|
||||
s, err := d.getShare(ctx, stripeIndex, shareSize, int(pieces[i].PieceNum), pieceID, pieceSize, node, authorization)
|
||||
s, err := d.getShare(ctx, stripeIndex, shareSize, int(pieces[i].PieceNum), pieceID, pieceSize, node, pba, authorization)
|
||||
if err != nil {
|
||||
s = share{
|
||||
Error: err,
|
||||
@ -207,10 +192,10 @@ func calcPadded(size int64, blockSize int) int64 {
|
||||
}
|
||||
|
||||
// verify downloads shares then verifies the data correctness at the given stripe
|
||||
func (verifier *Verifier) verify(ctx context.Context, stripeIndex int, pointer *pb.Pointer, authorization *pb.SignedMessage) (verifiedNodes *RecordAuditsInfo, err error) {
|
||||
func (verifier *Verifier) verify(ctx context.Context, stripe *Stripe) (verifiedNodes *RecordAuditsInfo, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
shares, nodes, err := verifier.downloader.DownloadShares(ctx, pointer, stripeIndex, authorization)
|
||||
shares, nodes, err := verifier.downloader.DownloadShares(ctx, stripe.Segment, stripe.Index, stripe.PBA, stripe.Authorization)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -222,6 +207,7 @@ func (verifier *Verifier) verify(ctx context.Context, stripeIndex int, pointer *
|
||||
}
|
||||
}
|
||||
|
||||
pointer := stripe.Segment
|
||||
required := int(pointer.Remote.Redundancy.GetMinReq())
|
||||
total := int(pointer.Remote.Redundancy.GetTotal())
|
||||
pieceNums, err := auditShares(ctx, required, total, shares)
|
||||
|
@ -44,7 +44,7 @@ func TestPassingAudit(t *testing.T) {
|
||||
md := mockDownloader{shares: mockShares}
|
||||
verifier := &Verifier{downloader: &md}
|
||||
pointer := makePointer(tt.nodeAmt)
|
||||
verifiedNodes, err := verifier.verify(ctx, 6, pointer, nil)
|
||||
verifiedNodes, err := verifier.verify(ctx, &Stripe{Index: 6, Segment: pointer, PBA: nil, Authorization: nil})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -88,7 +88,7 @@ func TestSomeNodesPassAudit(t *testing.T) {
|
||||
md := mockDownloader{shares: mockShares}
|
||||
verifier := &Verifier{downloader: &md}
|
||||
pointer := makePointer(tt.nodeAmt)
|
||||
verifiedNodes, err := verifier.verify(ctx, 6, pointer, nil)
|
||||
verifiedNodes, err := verifier.verify(ctx, &Stripe{Index: 6, Segment: pointer, PBA: nil, Authorization: nil})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -200,8 +200,8 @@ func TestCalcPadded(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockDownloader) DownloadShares(ctx context.Context, pointer *pb.Pointer,
|
||||
stripeIndex int, authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error) {
|
||||
func (m *mockDownloader) DownloadShares(ctx context.Context, pointer *pb.Pointer, stripeIndex int,
|
||||
pba *pb.PayerBandwidthAllocation, authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error) {
|
||||
|
||||
nodes = make(map[int]*pb.Node, 30)
|
||||
|
||||
|
@ -72,6 +72,15 @@ func NewStreamReader(s *Server, stream pb.PieceStoreRoutes_StoreServer, bandwidt
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pbaData := &pb.PayerBandwidthAllocation_Data{}
|
||||
if err = proto.Unmarshal(deserializedData.GetPayerAllocation().GetData(), pbaData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = s.verifyPayerAllocation(pbaData, pb.PayerBandwidthAllocation_PUT); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update bandwidthallocation to be stored
|
||||
if deserializedData.GetTotal() > sr.currentTotal {
|
||||
sr.bandwidthAllocation = ba
|
||||
|
@ -140,6 +140,22 @@ func (s *Server) retrieveData(ctx context.Context, stream pb.PieceStoreRoutes_Re
|
||||
return
|
||||
}
|
||||
|
||||
if allocData.GetPayerAllocation() == nil {
|
||||
allocationTracking.Fail(StoreError.New("no payer bandwidth allocation"))
|
||||
return
|
||||
}
|
||||
|
||||
pbaData := &pb.PayerBandwidthAllocation_Data{}
|
||||
if err = proto.Unmarshal(allocData.GetPayerAllocation().GetData(), pbaData); err != nil {
|
||||
allocationTracking.Fail(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = s.verifyPayerAllocation(pbaData, pb.PayerBandwidthAllocation_GET); err != nil {
|
||||
allocationTracking.Fail(err)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: break when lastTotal >= allocData.GetPayer_allocation().GetData().GetMax_size()
|
||||
|
||||
if lastTotal > allocData.GetTotal() {
|
||||
|
@ -308,6 +308,18 @@ func (s *Server) verifySignature(ctx context.Context, ba *pb.RenterBandwidthAllo
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) verifyPayerAllocation(pba *pb.PayerBandwidthAllocation_Data, action pb.PayerBandwidthAllocation_Action) (err error) {
|
||||
switch {
|
||||
case pba.SatelliteId.IsZero():
|
||||
return StoreError.New("payer bandwidth allocation: missing satellite id")
|
||||
case pba.UplinkId.IsZero():
|
||||
return StoreError.New("payer bandwidth allocation: missing uplink id")
|
||||
case pba.Action != action:
|
||||
return StoreError.New("payer bandwidth allocation: invalid action %v", pba.Action.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getBeginningOfMonth() time.Time {
|
||||
t := time.Now()
|
||||
y, m, _ := t.Date()
|
||||
|
@ -29,6 +29,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"storj.io/storj/internal/testidentity"
|
||||
"storj.io/storj/internal/teststorj"
|
||||
"storj.io/storj/pkg/pb"
|
||||
pstore "storj.io/storj/pkg/piecestore"
|
||||
"storj.io/storj/pkg/piecestore/psserver/psdb"
|
||||
@ -340,12 +341,20 @@ func TestStore(t *testing.T) {
|
||||
err = stream.Send(&pb.PieceStore{PieceData: &pb.PieceStore_PieceData{Id: tt.id, ExpirationUnixSec: tt.ttl}})
|
||||
assert.NoError(err)
|
||||
|
||||
pbad := &pb.PayerBandwidthAllocation_Data{
|
||||
SatelliteId: teststorj.NodeIDFromString("satelliteid"),
|
||||
UplinkId: teststorj.NodeIDFromString("uplinkid"),
|
||||
Action: pb.PayerBandwidthAllocation_PUT,
|
||||
}
|
||||
pbaData, err := proto.Marshal(pbad)
|
||||
assert.NoError(err)
|
||||
pba := &pb.PayerBandwidthAllocation{Data: pbaData}
|
||||
// Send Bandwidth Allocation Data
|
||||
msg := &pb.PieceStore{
|
||||
PieceData: &pb.PieceStore_PieceData{Content: tt.content},
|
||||
BandwidthAllocation: &pb.RenterBandwidthAllocation{
|
||||
Data: serializeData(&pb.RenterBandwidthAllocation_Data{
|
||||
PayerAllocation: &pb.PayerBandwidthAllocation{},
|
||||
PayerAllocation: pba,
|
||||
Total: int64(len(tt.content)),
|
||||
}),
|
||||
},
|
||||
@ -394,7 +403,7 @@ func TestStore(t *testing.T) {
|
||||
err = proto.Unmarshal(agreement, decoded)
|
||||
assert.NoError(err)
|
||||
assert.Equal(msg.BandwidthAllocation.GetSignature(), signature)
|
||||
assert.Equal(&pb.PayerBandwidthAllocation{}, decoded.GetPayerAllocation())
|
||||
assert.True(proto.Equal(pba, decoded.GetPayerAllocation()))
|
||||
assert.Equal(int64(len(tt.content)), decoded.GetTotal())
|
||||
|
||||
}
|
||||
@ -407,6 +416,86 @@ func TestStore(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPbaValidation(t *testing.T) {
|
||||
TS := NewTestServer(t)
|
||||
defer TS.Stop()
|
||||
|
||||
tests := []struct {
|
||||
satelliteID storj.NodeID
|
||||
uplinkID storj.NodeID
|
||||
action pb.PayerBandwidthAllocation_Action
|
||||
err string
|
||||
}{
|
||||
{ // missing satellite id
|
||||
satelliteID: storj.NodeID{},
|
||||
uplinkID: teststorj.NodeIDFromString("uplinkid"),
|
||||
action: pb.PayerBandwidthAllocation_PUT,
|
||||
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing satellite id",
|
||||
},
|
||||
{ // missing uplink id
|
||||
satelliteID: teststorj.NodeIDFromString("satelliteid"),
|
||||
uplinkID: storj.NodeID{},
|
||||
action: pb.PayerBandwidthAllocation_PUT,
|
||||
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing uplink id",
|
||||
},
|
||||
{ // wrong action type
|
||||
satelliteID: teststorj.NodeIDFromString("satelliteid"),
|
||||
uplinkID: teststorj.NodeIDFromString("uplinkid"),
|
||||
action: pb.PayerBandwidthAllocation_GET,
|
||||
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: invalid action GET",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("should validate payer bandwidth allocation struct", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
stream, err := TS.c.Store(ctx)
|
||||
assert.NoError(err)
|
||||
|
||||
// Write the buffer to the stream we opened earlier
|
||||
err = stream.Send(&pb.PieceStore{PieceData: &pb.PieceStore_PieceData{Id: "99999999999999999999", ExpirationUnixSec: 9999999999}})
|
||||
assert.NoError(err)
|
||||
|
||||
pbad := &pb.PayerBandwidthAllocation_Data{
|
||||
SatelliteId: tt.satelliteID,
|
||||
UplinkId: tt.uplinkID,
|
||||
Action: tt.action,
|
||||
}
|
||||
pbaData, err := proto.Marshal(pbad)
|
||||
assert.NoError(err)
|
||||
pba := &pb.PayerBandwidthAllocation{Data: pbaData}
|
||||
// Send Bandwidth Allocation Data
|
||||
content := []byte("content")
|
||||
msg := &pb.PieceStore{
|
||||
PieceData: &pb.PieceStore_PieceData{Content: content},
|
||||
BandwidthAllocation: &pb.RenterBandwidthAllocation{
|
||||
Data: serializeData(&pb.RenterBandwidthAllocation_Data{
|
||||
PayerAllocation: pba,
|
||||
Total: int64(len(content)),
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
s, err := cryptopasta.Sign(msg.BandwidthAllocation.Data, TS.k.(*ecdsa.PrivateKey))
|
||||
assert.NoError(err)
|
||||
msg.BandwidthAllocation.Signature = s
|
||||
|
||||
// Write the buffer to the stream we opened earlier
|
||||
err = stream.Send(msg)
|
||||
if err != io.EOF && err != nil {
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
_, err = stream.CloseAndRecv()
|
||||
if err != nil {
|
||||
//assert.NotNil(err)
|
||||
assert.Equal(tt.err, err.Error())
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
TS := NewTestServer(t)
|
||||
defer TS.Stop()
|
||||
|
Loading…
Reference in New Issue
Block a user