Verifier should use payer bandwidth alloc from satellite (#960)

* Verifier should use payer bandwidth alloc from satellite

* unit test added

* fix typo

* review comments applied

* fix renamed field
This commit is contained in:
Michal Niewrzal 2019-01-06 19:51:01 +01:00 committed by GitHub
parent 6372190873
commit bacc1b13b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 144 additions and 33 deletions

View File

@ -22,6 +22,7 @@ import (
type Stripe struct { type Stripe struct {
Index int Index int
Segment *pb.Pointer Segment *pb.Pointer
PBA *pb.PayerBandwidthAllocation
Authorization *pb.SignedMessage Authorization *pb.SignedMessage
} }
@ -74,7 +75,7 @@ func (cursor *Cursor) NextStripe(ctx context.Context) (stripe *Stripe, err error
} }
// get pointer info // get pointer info
pointer, _, _, err := cursor.pointers.Get(ctx, path) pointer, _, pba, err := cursor.pointers.Get(ctx, path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -99,8 +100,7 @@ func (cursor *Cursor) NextStripe(ctx context.Context) (stripe *Stripe, err error
} }
authorization := cursor.pointers.SignedMessage() authorization := cursor.pointers.SignedMessage()
return &Stripe{Index: index, Segment: pointer, PBA: pba, Authorization: authorization}, nil
return &Stripe{Index: index, Segment: pointer, Authorization: authorization}, nil
} }
func makeErasureScheme(rs *pb.RedundancyScheme) (eestream.ErasureScheme, error) { func makeErasureScheme(rs *pb.RedundancyScheme) (eestream.ErasureScheme, error) {

View File

@ -101,8 +101,7 @@ func (service *Service) process(ctx context.Context) error {
return nil return nil
} }
authorization := service.Cursor.pointers.SignedMessage() verifiedNodes, err := service.Verifier.verify(ctx, stripe)
verifiedNodes, err := service.Verifier.verify(ctx, stripe.Index, stripe.Segment, authorization)
if err != nil { if err != nil {
return err return err
} }

View File

@ -7,9 +7,7 @@ import (
"bytes" "bytes"
"context" "context"
"io" "io"
"time"
"github.com/gogo/protobuf/proto"
"github.com/vivint/infectious" "github.com/vivint/infectious"
monkit "gopkg.in/spacemonkeygo/monkit.v2" monkit "gopkg.in/spacemonkeygo/monkit.v2"
@ -36,7 +34,8 @@ type Verifier struct {
} }
type downloader interface { type downloader interface {
DownloadShares(ctx context.Context, pointer *pb.Pointer, stripeIndex int, authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error) DownloadShares(ctx context.Context, pointer *pb.Pointer, stripeIndex int, pba *pb.PayerBandwidthAllocation,
authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error)
} }
// defaultDownloader downloads shares from networked storage nodes // defaultDownloader downloads shares from networked storage nodes
@ -59,7 +58,7 @@ func NewVerifier(transport transport.Client, overlay overlay.Client, id provider
// getShare use piece store clients to download shares from a given node // getShare use piece store clients to download shares from a given node
func (d *defaultDownloader) getShare(ctx context.Context, stripeIndex, shareSize, pieceNumber int, func (d *defaultDownloader) getShare(ctx context.Context, stripeIndex, shareSize, pieceNumber int,
id psclient.PieceID, pieceSize int64, fromNode *pb.Node, authorization *pb.SignedMessage) (s share, err error) { id psclient.PieceID, pieceSize int64, fromNode *pb.Node, pba *pb.PayerBandwidthAllocation, authorization *pb.SignedMessage) (s share, err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
fromNode.Type.DPanicOnInvalid("audit getShare") fromNode.Type.DPanicOnInvalid("audit getShare")
ps, err := psclient.NewPSClient(ctx, d.transport, fromNode, 0) ps, err := psclient.NewPSClient(ctx, d.transport, fromNode, 0)
@ -72,20 +71,6 @@ func (d *defaultDownloader) getShare(ctx context.Context, stripeIndex, shareSize
return s, err return s, err
} }
allocationData := &pb.PayerBandwidthAllocation_Data{
Action: pb.PayerBandwidthAllocation_GET,
CreatedUnixSec: time.Now().Unix(),
}
serializedAllocation, err := proto.Marshal(allocationData)
if err != nil {
return s, err
}
pba := &pb.PayerBandwidthAllocation{
Data: serializedAllocation,
}
rr, err := ps.Get(ctx, derivedPieceID, pieceSize, pba, authorization) rr, err := ps.Get(ctx, derivedPieceID, pieceSize, pba, authorization)
if err != nil { if err != nil {
return s, err return s, err
@ -115,7 +100,7 @@ func (d *defaultDownloader) getShare(ctx context.Context, stripeIndex, shareSize
// Download Shares downloads shares from the nodes where remote pieces are located // Download Shares downloads shares from the nodes where remote pieces are located
func (d *defaultDownloader) DownloadShares(ctx context.Context, pointer *pb.Pointer, func (d *defaultDownloader) DownloadShares(ctx context.Context, pointer *pb.Pointer,
stripeIndex int, authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error) { stripeIndex int, pba *pb.PayerBandwidthAllocation, authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
var nodeIds storj.NodeIDList var nodeIds storj.NodeIDList
@ -142,7 +127,7 @@ func (d *defaultDownloader) DownloadShares(ctx context.Context, pointer *pb.Poin
paddedSize := calcPadded(pointer.GetSegmentSize(), shareSize) paddedSize := calcPadded(pointer.GetSegmentSize(), shareSize)
pieceSize := paddedSize / int64(pointer.Remote.Redundancy.GetMinReq()) pieceSize := paddedSize / int64(pointer.Remote.Redundancy.GetMinReq())
s, err := d.getShare(ctx, stripeIndex, shareSize, int(pieces[i].PieceNum), pieceID, pieceSize, node, authorization) s, err := d.getShare(ctx, stripeIndex, shareSize, int(pieces[i].PieceNum), pieceID, pieceSize, node, pba, authorization)
if err != nil { if err != nil {
s = share{ s = share{
Error: err, Error: err,
@ -207,10 +192,10 @@ func calcPadded(size int64, blockSize int) int64 {
} }
// verify downloads shares then verifies the data correctness at the given stripe // verify downloads shares then verifies the data correctness at the given stripe
func (verifier *Verifier) verify(ctx context.Context, stripeIndex int, pointer *pb.Pointer, authorization *pb.SignedMessage) (verifiedNodes *RecordAuditsInfo, err error) { func (verifier *Verifier) verify(ctx context.Context, stripe *Stripe) (verifiedNodes *RecordAuditsInfo, err error) {
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
shares, nodes, err := verifier.downloader.DownloadShares(ctx, pointer, stripeIndex, authorization) shares, nodes, err := verifier.downloader.DownloadShares(ctx, stripe.Segment, stripe.Index, stripe.PBA, stripe.Authorization)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -222,6 +207,7 @@ func (verifier *Verifier) verify(ctx context.Context, stripeIndex int, pointer *
} }
} }
pointer := stripe.Segment
required := int(pointer.Remote.Redundancy.GetMinReq()) required := int(pointer.Remote.Redundancy.GetMinReq())
total := int(pointer.Remote.Redundancy.GetTotal()) total := int(pointer.Remote.Redundancy.GetTotal())
pieceNums, err := auditShares(ctx, required, total, shares) pieceNums, err := auditShares(ctx, required, total, shares)

View File

@ -44,7 +44,7 @@ func TestPassingAudit(t *testing.T) {
md := mockDownloader{shares: mockShares} md := mockDownloader{shares: mockShares}
verifier := &Verifier{downloader: &md} verifier := &Verifier{downloader: &md}
pointer := makePointer(tt.nodeAmt) pointer := makePointer(tt.nodeAmt)
verifiedNodes, err := verifier.verify(ctx, 6, pointer, nil) verifiedNodes, err := verifier.verify(ctx, &Stripe{Index: 6, Segment: pointer, PBA: nil, Authorization: nil})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -88,7 +88,7 @@ func TestSomeNodesPassAudit(t *testing.T) {
md := mockDownloader{shares: mockShares} md := mockDownloader{shares: mockShares}
verifier := &Verifier{downloader: &md} verifier := &Verifier{downloader: &md}
pointer := makePointer(tt.nodeAmt) pointer := makePointer(tt.nodeAmt)
verifiedNodes, err := verifier.verify(ctx, 6, pointer, nil) verifiedNodes, err := verifier.verify(ctx, &Stripe{Index: 6, Segment: pointer, PBA: nil, Authorization: nil})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -200,8 +200,8 @@ func TestCalcPadded(t *testing.T) {
} }
} }
func (m *mockDownloader) DownloadShares(ctx context.Context, pointer *pb.Pointer, func (m *mockDownloader) DownloadShares(ctx context.Context, pointer *pb.Pointer, stripeIndex int,
stripeIndex int, authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error) { pba *pb.PayerBandwidthAllocation, authorization *pb.SignedMessage) (shares map[int]share, nodes map[int]*pb.Node, err error) {
nodes = make(map[int]*pb.Node, 30) nodes = make(map[int]*pb.Node, 30)

View File

@ -72,6 +72,15 @@ func NewStreamReader(s *Server, stream pb.PieceStoreRoutes_StoreServer, bandwidt
return nil, err return nil, err
} }
pbaData := &pb.PayerBandwidthAllocation_Data{}
if err = proto.Unmarshal(deserializedData.GetPayerAllocation().GetData(), pbaData); err != nil {
return nil, err
}
if err = s.verifyPayerAllocation(pbaData, pb.PayerBandwidthAllocation_PUT); err != nil {
return nil, err
}
// Update bandwidthallocation to be stored // Update bandwidthallocation to be stored
if deserializedData.GetTotal() > sr.currentTotal { if deserializedData.GetTotal() > sr.currentTotal {
sr.bandwidthAllocation = ba sr.bandwidthAllocation = ba

View File

@ -140,6 +140,22 @@ func (s *Server) retrieveData(ctx context.Context, stream pb.PieceStoreRoutes_Re
return return
} }
if allocData.GetPayerAllocation() == nil {
allocationTracking.Fail(StoreError.New("no payer bandwidth allocation"))
return
}
pbaData := &pb.PayerBandwidthAllocation_Data{}
if err = proto.Unmarshal(allocData.GetPayerAllocation().GetData(), pbaData); err != nil {
allocationTracking.Fail(err)
return
}
if err = s.verifyPayerAllocation(pbaData, pb.PayerBandwidthAllocation_GET); err != nil {
allocationTracking.Fail(err)
return
}
// TODO: break when lastTotal >= allocData.GetPayer_allocation().GetData().GetMax_size() // TODO: break when lastTotal >= allocData.GetPayer_allocation().GetData().GetMax_size()
if lastTotal > allocData.GetTotal() { if lastTotal > allocData.GetTotal() {

View File

@ -308,6 +308,18 @@ func (s *Server) verifySignature(ctx context.Context, ba *pb.RenterBandwidthAllo
return nil return nil
} }
func (s *Server) verifyPayerAllocation(pba *pb.PayerBandwidthAllocation_Data, action pb.PayerBandwidthAllocation_Action) (err error) {
switch {
case pba.SatelliteId.IsZero():
return StoreError.New("payer bandwidth allocation: missing satellite id")
case pba.UplinkId.IsZero():
return StoreError.New("payer bandwidth allocation: missing uplink id")
case pba.Action != action:
return StoreError.New("payer bandwidth allocation: invalid action %v", pba.Action.String())
}
return nil
}
func getBeginningOfMonth() time.Time { func getBeginningOfMonth() time.Time {
t := time.Now() t := time.Now()
y, m, _ := t.Date() y, m, _ := t.Date()

View File

@ -29,6 +29,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"storj.io/storj/internal/testidentity" "storj.io/storj/internal/testidentity"
"storj.io/storj/internal/teststorj"
"storj.io/storj/pkg/pb" "storj.io/storj/pkg/pb"
pstore "storj.io/storj/pkg/piecestore" pstore "storj.io/storj/pkg/piecestore"
"storj.io/storj/pkg/piecestore/psserver/psdb" "storj.io/storj/pkg/piecestore/psserver/psdb"
@ -340,12 +341,20 @@ func TestStore(t *testing.T) {
err = stream.Send(&pb.PieceStore{PieceData: &pb.PieceStore_PieceData{Id: tt.id, ExpirationUnixSec: tt.ttl}}) err = stream.Send(&pb.PieceStore{PieceData: &pb.PieceStore_PieceData{Id: tt.id, ExpirationUnixSec: tt.ttl}})
assert.NoError(err) assert.NoError(err)
pbad := &pb.PayerBandwidthAllocation_Data{
SatelliteId: teststorj.NodeIDFromString("satelliteid"),
UplinkId: teststorj.NodeIDFromString("uplinkid"),
Action: pb.PayerBandwidthAllocation_PUT,
}
pbaData, err := proto.Marshal(pbad)
assert.NoError(err)
pba := &pb.PayerBandwidthAllocation{Data: pbaData}
// Send Bandwidth Allocation Data // Send Bandwidth Allocation Data
msg := &pb.PieceStore{ msg := &pb.PieceStore{
PieceData: &pb.PieceStore_PieceData{Content: tt.content}, PieceData: &pb.PieceStore_PieceData{Content: tt.content},
BandwidthAllocation: &pb.RenterBandwidthAllocation{ BandwidthAllocation: &pb.RenterBandwidthAllocation{
Data: serializeData(&pb.RenterBandwidthAllocation_Data{ Data: serializeData(&pb.RenterBandwidthAllocation_Data{
PayerAllocation: &pb.PayerBandwidthAllocation{}, PayerAllocation: pba,
Total: int64(len(tt.content)), Total: int64(len(tt.content)),
}), }),
}, },
@ -394,7 +403,7 @@ func TestStore(t *testing.T) {
err = proto.Unmarshal(agreement, decoded) err = proto.Unmarshal(agreement, decoded)
assert.NoError(err) assert.NoError(err)
assert.Equal(msg.BandwidthAllocation.GetSignature(), signature) assert.Equal(msg.BandwidthAllocation.GetSignature(), signature)
assert.Equal(&pb.PayerBandwidthAllocation{}, decoded.GetPayerAllocation()) assert.True(proto.Equal(pba, decoded.GetPayerAllocation()))
assert.Equal(int64(len(tt.content)), decoded.GetTotal()) assert.Equal(int64(len(tt.content)), decoded.GetTotal())
} }
@ -407,6 +416,86 @@ func TestStore(t *testing.T) {
} }
} }
func TestPbaValidation(t *testing.T) {
TS := NewTestServer(t)
defer TS.Stop()
tests := []struct {
satelliteID storj.NodeID
uplinkID storj.NodeID
action pb.PayerBandwidthAllocation_Action
err string
}{
{ // missing satellite id
satelliteID: storj.NodeID{},
uplinkID: teststorj.NodeIDFromString("uplinkid"),
action: pb.PayerBandwidthAllocation_PUT,
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing satellite id",
},
{ // missing uplink id
satelliteID: teststorj.NodeIDFromString("satelliteid"),
uplinkID: storj.NodeID{},
action: pb.PayerBandwidthAllocation_PUT,
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing uplink id",
},
{ // wrong action type
satelliteID: teststorj.NodeIDFromString("satelliteid"),
uplinkID: teststorj.NodeIDFromString("uplinkid"),
action: pb.PayerBandwidthAllocation_GET,
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: invalid action GET",
},
}
for _, tt := range tests {
t.Run("should validate payer bandwidth allocation struct", func(t *testing.T) {
assert := assert.New(t)
stream, err := TS.c.Store(ctx)
assert.NoError(err)
// Write the buffer to the stream we opened earlier
err = stream.Send(&pb.PieceStore{PieceData: &pb.PieceStore_PieceData{Id: "99999999999999999999", ExpirationUnixSec: 9999999999}})
assert.NoError(err)
pbad := &pb.PayerBandwidthAllocation_Data{
SatelliteId: tt.satelliteID,
UplinkId: tt.uplinkID,
Action: tt.action,
}
pbaData, err := proto.Marshal(pbad)
assert.NoError(err)
pba := &pb.PayerBandwidthAllocation{Data: pbaData}
// Send Bandwidth Allocation Data
content := []byte("content")
msg := &pb.PieceStore{
PieceData: &pb.PieceStore_PieceData{Content: content},
BandwidthAllocation: &pb.RenterBandwidthAllocation{
Data: serializeData(&pb.RenterBandwidthAllocation_Data{
PayerAllocation: pba,
Total: int64(len(content)),
}),
},
}
s, err := cryptopasta.Sign(msg.BandwidthAllocation.Data, TS.k.(*ecdsa.PrivateKey))
assert.NoError(err)
msg.BandwidthAllocation.Signature = s
// Write the buffer to the stream we opened earlier
err = stream.Send(msg)
if err != io.EOF && err != nil {
assert.NoError(err)
}
_, err = stream.CloseAndRecv()
if err != nil {
//assert.NotNil(err)
assert.Equal(tt.err, err.Error())
return
}
})
}
}
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
TS := NewTestServer(t) TS := NewTestServer(t)
defer TS.Stop() defer TS.Stop()