satellite/metainfo/pointerverification: service for verifying pointers

This implements a service for pointer verification. This makes the code
slightly clearer, because verification is no longer part of metainfo.

It also adds a peer identity cache which reduces database calls and peer
identity decoding.

Change-Id: I45da40460d579c6f5fd74c69bccea215157aafda
Egon Elbre 2020-03-18 15:24:31 +02:00
parent 8597e6b512
commit eb1d8aab96
5 changed files with 333 additions and 191 deletions
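
For orientation, the sketch below shows how a caller could use the new pointerverification service. It is not part of the commit: the function name verifyPointer is hypothetical, and only the exported API visible in the diffs below (NewService, VerifySizes, SelectValidPieces) is assumed.

package example

import (
    "context"

    "storj.io/common/pb"
    "storj.io/storj/satellite/metainfo/pointerverification"
    "storj.io/storj/satellite/overlay"
)

// verifyPointer mirrors how the metainfo endpoint uses the service in this
// commit: check piece sizes first, then keep only the pieces whose limits
// and hash signatures verify. Assumes a remote (non-inline) pointer.
func verifyPointer(ctx context.Context, peerIdents overlay.PeerIdentities, pointer *pb.Pointer, limits []*pb.OrderLimit) error {
    service := pointerverification.NewService(peerIdents)

    // Reject pointers whose piece hash sizes disagree with each other or with
    // the size implied by the redundancy scheme and the segment size.
    if err := service.VerifySizes(ctx, pointer); err != nil {
        return err
    }

    // Keep only pieces whose order limits and hash signatures check out;
    // rejected pieces come back with the reason they were dropped.
    valid, invalid, err := service.SelectValidPieces(ctx, pointer, limits)
    if err != nil {
        return err
    }
    _ = invalid // the endpoint logs these and enforces the redundancy thresholds

    pointer.Remote.RemotePieces = valid
    return nil
}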


@@ -18,7 +18,6 @@ import (
 	"go.uber.org/zap"
 	"storj.io/common/errs2"
-	"storj.io/common/identity"
 	"storj.io/common/pb"
 	"storj.io/common/rpc/rpcstatus"
 	"storj.io/common/signing"
@@ -31,6 +30,7 @@ import (
 	"storj.io/storj/satellite/attribution"
 	"storj.io/storj/satellite/console"
 	"storj.io/storj/satellite/metainfo/piecedeletion"
+	"storj.io/storj/satellite/metainfo/pointerverification"
 	"storj.io/storj/satellite/orders"
 	"storj.io/storj/satellite/overlay"
 	"storj.io/storj/satellite/rewards"
@@ -39,7 +39,6 @@ import (
 )
 const (
-	pieceHashExpiration = 24 * time.Hour
 	satIDExpiration = 24 * time.Hour
 	lastSegment = -1
 	listLimit = 1000
@@ -80,7 +79,7 @@ type Endpoint struct {
 	overlay *overlay.Service
 	attributions attribution.DB
 	partners *rewards.PartnersService
-	peerIdentities overlay.PeerIdentities
+	pointerVerification *pointerverification.Service
 	projectUsage *accounting.Service
 	projects console.Projects
 	apiKeys APIKeys
@@ -108,7 +107,7 @@ func NewEndpoint(log *zap.Logger, metainfo *Service, deletePieces *piecedeletion
 		overlay: cache,
 		attributions: attributions,
 		partners: partners,
-		peerIdentities: peerIdentities,
+		pointerVerification: pointerverification.NewService(peerIdentities),
 		apiKeys: apiKeys,
 		projectUsage: projectUsage,
 		projects: projects,
@@ -502,115 +501,33 @@ func (endpoint *Endpoint) filterValidPieces(ctx context.Context, pointer *pb.Poi
 		return nil
 	}
+	// verify that the piece sizes matches what we would expect.
+	err = endpoint.pointerVerification.VerifySizes(ctx, pointer)
+	if err != nil {
+		endpoint.log.Debug("piece sizes are invalid", zap.Error(err))
+		return rpcstatus.Errorf(rpcstatus.InvalidArgument, "piece sizes are invalid: %v", err)
+	}
+	validPieces, invalidPieces, err := endpoint.pointerVerification.SelectValidPieces(ctx, pointer, originalLimits)
+	if err != nil {
+		endpoint.log.Debug("pointer verification failed", zap.Error(err))
+		return rpcstatus.Errorf(rpcstatus.InvalidArgument, "pointer verification failed: %s", err)
+	}
 	remote := pointer.Remote
-	peerIDMap, err := endpoint.mapNodesFor(ctx, remote.RemotePieces)
-	if err != nil {
-		return err
-	}
-	type invalidPiece struct {
-		NodeID storj.NodeID
-		PieceNum int32
-		Reason string
-	}
-	var (
-		remotePieces []*pb.RemotePiece
-		invalidPieces []invalidPiece
-		lastPieceSize int64
-		allSizesValid = true
-	)
-	for _, piece := range remote.RemotePieces {
-		// Verify storagenode signature on piecehash
-		peerID, ok := peerIDMap[piece.NodeId]
-		if !ok {
-			endpoint.log.Warn("Identity chain unknown for node. Piece removed from pointer",
-				zap.Stringer("Node ID", piece.NodeId),
-				zap.Int32("Piece ID", piece.PieceNum),
-			)
-			invalidPieces = append(invalidPieces, invalidPiece{
-				NodeID: piece.NodeId,
-				PieceNum: piece.PieceNum,
-				Reason: "Identity chain unknown for node",
-			})
-			continue
-		}
-		signee := signing.SigneeFromPeerIdentity(peerID)
-		limit := originalLimits[piece.PieceNum]
-		if limit == nil {
-			endpoint.log.Warn("There is not limit for the piece. Piece removed from pointer",
-				zap.Int32("Piece ID", piece.PieceNum),
-			)
-			invalidPieces = append(invalidPieces, invalidPiece{
-				NodeID: piece.NodeId,
-				PieceNum: piece.PieceNum,
-				Reason: "No order limit for validating the piece hash",
-			})
-			continue
-		}
-		err = endpoint.validatePieceHash(ctx, piece, limit, signee)
-		if err != nil {
-			endpoint.log.Warn("Problem validating piece hash. Pieces removed from pointer", zap.Error(err))
-			invalidPieces = append(invalidPieces, invalidPiece{
-				NodeID: piece.NodeId,
-				PieceNum: piece.PieceNum,
-				Reason: err.Error(),
-			})
-			continue
-		}
-		if piece.Hash.PieceSize <= 0 || (lastPieceSize > 0 && lastPieceSize != piece.Hash.PieceSize) {
-			allSizesValid = false
-			break
-		}
-		lastPieceSize = piece.Hash.PieceSize
-		remotePieces = append(remotePieces, piece)
-	}
-	if allSizesValid {
-		redundancy, err := eestream.NewRedundancyStrategyFromProto(pointer.GetRemote().GetRedundancy())
-		if err != nil {
-			endpoint.log.Debug("pointer contains an invalid redundancy strategy", zap.Error(Error.Wrap(err)))
-			return rpcstatus.Errorf(rpcstatus.InvalidArgument,
-				"invalid redundancy strategy; MinReq and/or Total are invalid: %s", err,
-			)
-		}
-		expectedPieceSize := eestream.CalcPieceSize(pointer.SegmentSize, redundancy)
-		if expectedPieceSize != lastPieceSize {
-			endpoint.log.Debug("expected piece size is different from provided",
-				zap.Int64("expectedSize", expectedPieceSize),
-				zap.Int64("actualSize", lastPieceSize),
-			)
-			return rpcstatus.Errorf(rpcstatus.InvalidArgument,
-				"expected piece size is different from provided (%d != %d)",
-				expectedPieceSize, lastPieceSize,
-			)
-		}
-	} else {
-		errMsg := "all pieces needs to have the same size"
-		endpoint.log.Debug(errMsg)
-		return rpcstatus.Error(rpcstatus.InvalidArgument, errMsg)
-	}
 	// We repair when the number of healthy files is less than or equal to the repair threshold
 	// except for the case when the repair and success thresholds are the same (a case usually seen during testing).
-	if numPieces := int32(len(remotePieces)); numPieces <= remote.Redundancy.RepairThreshold && numPieces < remote.Redundancy.SuccessThreshold {
+	if numPieces := int32(len(validPieces)); numPieces <= remote.Redundancy.RepairThreshold && numPieces < remote.Redundancy.SuccessThreshold {
 		endpoint.log.Debug("Number of valid pieces is less than or equal to the repair threshold",
 			zap.Int("totalReceivedPieces", len(remote.RemotePieces)),
-			zap.Int("validPieces", len(remotePieces)),
+			zap.Int("validPieces", len(validPieces)),
 			zap.Int("invalidPieces", len(invalidPieces)),
 			zap.Int32("repairThreshold", remote.Redundancy.RepairThreshold),
 		)
 		errMsg := fmt.Sprintf("Number of valid pieces (%d) is less than or equal to the repair threshold (%d). Found %d invalid pieces",
-			len(remotePieces),
+			len(validPieces),
 			remote.Redundancy.RepairThreshold,
 			len(remote.RemotePieces),
 		)
@@ -627,16 +544,16 @@ func (endpoint *Endpoint) filterValidPieces(ctx context.Context, pointer *pb.Poi
 		return rpcstatus.Error(rpcstatus.InvalidArgument, errMsg)
 	}
-	if int32(len(remotePieces)) < remote.Redundancy.SuccessThreshold {
+	if int32(len(validPieces)) < remote.Redundancy.SuccessThreshold {
 		endpoint.log.Debug("Number of valid pieces is less than the success threshold",
 			zap.Int("totalReceivedPieces", len(remote.RemotePieces)),
-			zap.Int("validPieces", len(remotePieces)),
+			zap.Int("validPieces", len(validPieces)),
 			zap.Int("invalidPieces", len(invalidPieces)),
 			zap.Int32("successThreshold", remote.Redundancy.SuccessThreshold),
 		)
 		errMsg := fmt.Sprintf("Number of valid pieces (%d) is less than the success threshold (%d). Found %d invalid pieces",
-			len(remotePieces),
+			len(validPieces),
 			remote.Redundancy.SuccessThreshold,
 			len(remote.RemotePieces),
 		)
@@ -653,29 +570,11 @@ func (endpoint *Endpoint) filterValidPieces(ctx context.Context, pointer *pb.Poi
 		return rpcstatus.Error(rpcstatus.InvalidArgument, errMsg)
 	}
-	remote.RemotePieces = remotePieces
+	remote.RemotePieces = validPieces
 	return nil
 }
-func (endpoint *Endpoint) mapNodesFor(ctx context.Context, pieces []*pb.RemotePiece) (map[storj.NodeID]*identity.PeerIdentity, error) {
-	nodeIDList := storj.NodeIDList{}
-	for _, piece := range pieces {
-		nodeIDList = append(nodeIDList, piece.NodeId)
-	}
-	peerIDList, err := endpoint.peerIdentities.BatchGet(ctx, nodeIDList)
-	if err != nil {
-		endpoint.log.Error("retrieving batch of the peer identities of nodes", zap.Error(Error.Wrap(err)))
-		return nil, rpcstatus.Error(rpcstatus.Internal, "retrieving nodes peer identities")
-	}
-	peerIDMap := make(map[storj.NodeID]*identity.PeerIdentity, len(peerIDList))
-	for _, peerID := range peerIDList {
-		peerIDMap[peerID.ID] = peerID
-	}
-	return peerIDMap, nil
-}
 // CreatePath creates a Segment path.
 func CreatePath(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bucket, path []byte) (_ storj.Path, err error) {
 	defer mon.Task()(&ctx)(&err)


@@ -625,13 +625,13 @@ func TestCommitSegmentPointer(t *testing.T) {
 				require.NoError(t, err)
 				pointer.Remote.RemotePieces[0].Hash = storageNodeHash
 			},
-			ErrorMessage: "all pieces needs to have the same size",
+			ErrorMessage: "piece sizes are invalid",
 		},
 		{
 			Modify: func(ctx context.Context, pointer *pb.Pointer, _ map[storj.NodeID]*identity.FullIdentity, limits []*pb.OrderLimit) {
 				pointer.SegmentSize = 100
 			},
-			ErrorMessage: "expected piece size is different from provided",
+			ErrorMessage: "piece sizes are invalid",
 		},
 		{
 			Modify: func(ctx context.Context, pointer *pb.Pointer, _ map[storj.NodeID]*identity.FullIdentity, limits []*pb.OrderLimit) {


@@ -0,0 +1,92 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.

package pointerverification

import (
	"context"
	"sync"

	"storj.io/common/identity"
	"storj.io/common/pb"
	"storj.io/common/storj"
	"storj.io/storj/satellite/overlay"
)

// IdentityCache implements caching of *identity.PeerIdentity.
type IdentityCache struct {
	db overlay.PeerIdentities

	mu     sync.RWMutex
	cached map[storj.NodeID]*identity.PeerIdentity
}

// NewIdentityCache returns an IdentityCache.
func NewIdentityCache(db overlay.PeerIdentities) *IdentityCache {
	return &IdentityCache{
		db:     db,
		cached: map[storj.NodeID]*identity.PeerIdentity{},
	}
}

// GetCached returns the peer identity in the cache.
func (cache *IdentityCache) GetCached(ctx context.Context, id storj.NodeID) *identity.PeerIdentity {
	defer mon.Task()(&ctx)(nil)

	cache.mu.RLock()
	defer cache.mu.RUnlock()

	return cache.cached[id]
}

// GetUpdated returns the identity from database and updates the cache.
func (cache *IdentityCache) GetUpdated(ctx context.Context, id storj.NodeID) (_ *identity.PeerIdentity, err error) {
	defer mon.Task()(&ctx)(&err)

	identity, err := cache.db.Get(ctx, id)
	if err != nil {
		return nil, Error.Wrap(err)
	}

	cache.mu.Lock()
	defer cache.mu.Unlock()
	cache.cached[id] = identity

	return identity, nil
}

// EnsureCached loads any missing identity into cache.
func (cache *IdentityCache) EnsureCached(ctx context.Context, pieces []*pb.RemotePiece) (err error) {
	defer mon.Task()(&ctx)(&err)

	missing := []storj.NodeID{}

	cache.mu.RLock()
	for _, piece := range pieces {
		if _, ok := cache.cached[piece.NodeId]; !ok {
			missing = append(missing, piece.NodeId)
		}
	}
	cache.mu.RUnlock()

	if len(missing) == 0 {
		return nil
	}

	// There might be a race during updating, however we'll "reupdate" later if there's a failure.
	// The common path doesn't end up here.
	identities, err := cache.db.BatchGet(ctx, missing)
	if err != nil {
		return Error.Wrap(err)
	}

	cache.mu.Lock()
	defer cache.mu.Unlock()
	for _, identity := range identities {
		cache.cached[identity.ID] = identity
	}

	return nil
}
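
As a usage note: the cache above is meant to be warmed once per pointer with a single BatchGet and then read from memory (an RWMutex lets concurrent readers proceed without blocking each other), with GetUpdated as the fallback when a cached entry might be stale. The following sketch of that pattern is not part of the commit; the helper name loadIdentities is hypothetical and only the exported API shown above is assumed.

package example

import (
	"context"

	"storj.io/common/identity"
	"storj.io/common/pb"
	"storj.io/storj/satellite/metainfo/pointerverification"
	"storj.io/storj/satellite/overlay"
)

// loadIdentities shows the intended cache usage: one batched database call
// up front, memory-only lookups afterwards, refresh only when needed.
func loadIdentities(ctx context.Context, db overlay.PeerIdentities, pieces []*pb.RemotePiece) ([]*identity.PeerIdentity, error) {
	cache := pointerverification.NewIdentityCache(db)

	// Fetch every identity that is not cached yet in a single BatchGet.
	if err := cache.EnsureCached(ctx, pieces); err != nil {
		return nil, err
	}

	identities := make([]*identity.PeerIdentity, 0, len(pieces))
	for _, piece := range pieces {
		peerIdentity := cache.GetCached(ctx, piece.NodeId)
		if peerIdentity == nil {
			// Unknown to the cache (or the batch load raced); refresh from the database.
			refreshed, err := cache.GetUpdated(ctx, piece.NodeId)
			if err != nil {
				return nil, err
			}
			peerIdentity = refreshed
		}
		identities = append(identities, peerIdentity)
	}
	return identities, nil
}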


@@ -0,0 +1,187 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.

// Package pointerverification implements verification of pointers.
package pointerverification

import (
	"context"
	"time"

	"github.com/spacemonkeygo/monkit/v3"
	"github.com/zeebo/errs"

	"storj.io/common/pb"
	"storj.io/common/signing"
	"storj.io/common/storj"
	"storj.io/storj/satellite/overlay"
	"storj.io/uplink/private/eestream"
)

var (
	mon = monkit.Package()

	// Error general pointer verification error.
	Error = errs.Class("pointer verification")
)

const pieceHashExpiration = 24 * time.Hour

// Service is a service for verifying validity of pieces.
type Service struct {
	identities *IdentityCache
}

// NewService returns a service using the provided database.
func NewService(db overlay.PeerIdentities) *Service {
	return &Service{
		identities: NewIdentityCache(db),
	}
}

// VerifySizes verifies that the remote piece sizes in pointer match each other.
func (service *Service) VerifySizes(ctx context.Context, pointer *pb.Pointer) (err error) {
	defer mon.Task()(&ctx)(&err)

	if pointer.Type != pb.Pointer_REMOTE {
		return nil
	}

	commonSize := int64(-1)
	for _, piece := range pointer.GetRemote().GetRemotePieces() {
		if piece.Hash == nil {
			continue
		}
		if piece.Hash.PieceSize <= 0 {
			return Error.New("size is invalid (%d)", piece.Hash.PieceSize)
		}
		if commonSize > 0 && commonSize != piece.Hash.PieceSize {
			return Error.New("sizes do not match (%d != %d)", commonSize, piece.Hash.PieceSize)
		}
		commonSize = piece.Hash.PieceSize
	}

	if commonSize < 0 {
		return Error.New("no remote pieces")
	}

	redundancy, err := eestream.NewRedundancyStrategyFromProto(pointer.GetRemote().GetRedundancy())
	if err != nil {
		return Error.New("invalid redundancy strategy: %v", err)
	}

	expectedSize := eestream.CalcPieceSize(pointer.SegmentSize, redundancy)
	if expectedSize != commonSize {
		return Error.New("expected size is different from provided (%d != %d)", expectedSize, commonSize)
	}

	return nil
}

// InvalidPiece is information about an invalid piece in the pointer.
type InvalidPiece struct {
	NodeID   storj.NodeID
	PieceNum int32
	Reason   error
}

// SelectValidPieces selects pieces that have correct hashes and match the original limits.
func (service *Service) SelectValidPieces(ctx context.Context, pointer *pb.Pointer, originalLimits []*pb.OrderLimit) (valid []*pb.RemotePiece, invalid []InvalidPiece, err error) {
	defer mon.Task()(&ctx)(&err)

	err = service.identities.EnsureCached(ctx, pointer.GetRemote().GetRemotePieces())
	if err != nil {
		return nil, nil, Error.Wrap(err)
	}

	for _, piece := range pointer.GetRemote().GetRemotePieces() {
		if int(piece.PieceNum) >= len(originalLimits) {
			return nil, nil, Error.New("invalid piece number")
		}
		limit := originalLimits[piece.PieceNum]
		if limit == nil {
			return nil, nil, Error.New("limit missing for piece")
		}

		// verify that the piece id, serial number etc. match in piece and limit.
		if err := VerifyPieceAndLimit(ctx, piece, limit); err != nil {
			invalid = append(invalid, InvalidPiece{
				NodeID:   piece.NodeId,
				PieceNum: piece.PieceNum,
				Reason:   err,
			})
			continue
		}

		peerIdentity := service.identities.GetCached(ctx, piece.NodeId)
		if peerIdentity == nil {
			// This shouldn't happen due to the caching in the start of the func.
			return nil, nil, Error.New("nil identity returned (%v)", piece.NodeId)
		}
		signee := signing.SigneeFromPeerIdentity(peerIdentity)

		// verify the signature
		err = signing.VerifyPieceHashSignature(ctx, signee, piece.Hash)
		if err != nil {
			// TODO: check whether the identity changed from what it was before.
			// Maybe the cache has gone stale?
			peerIdentity, err := service.identities.GetUpdated(ctx, piece.NodeId)
			if err != nil {
				return nil, nil, Error.Wrap(err)
			}
			signee := signing.SigneeFromPeerIdentity(peerIdentity)

			// let's check the signature again
			err = signing.VerifyPieceHashSignature(ctx, signee, piece.Hash)
			if err != nil {
				invalid = append(invalid, InvalidPiece{
					NodeID:   piece.NodeId,
					PieceNum: piece.PieceNum,
					Reason:   err,
				})
				continue
			}
		}

		valid = append(valid, piece)
	}

	return valid, invalid, nil
}

// VerifyPieceAndLimit verifies that the piece and limit match.
func VerifyPieceAndLimit(ctx context.Context, piece *pb.RemotePiece, limit *pb.OrderLimit) (err error) {
	defer mon.Task()(&ctx)(&err)

	// ensure that we have a hash
	if piece.Hash == nil {
		return Error.New("no piece hash. NodeID: %v, PieceNum: %d", piece.NodeId, piece.PieceNum)
	}

	// verify the timestamp
	timestamp := piece.Hash.Timestamp
	if timestamp.Before(time.Now().Add(-pieceHashExpiration)) {
		return Error.New("piece hash timestamp is too old (%v). NodeId: %v, PieceNum: %d)",
			timestamp, piece.NodeId, piece.PieceNum,
		)
	}

	// verify the piece id
	if limit.PieceId != piece.Hash.PieceId {
		return Error.New("piece hash pieceID (%v) doesn't match limit pieceID (%v). NodeID: %v, PieceNum: %d",
			piece.Hash.PieceId, limit.PieceId, piece.NodeId, piece.PieceNum,
		)
	}

	// verify the limit
	if limit.Limit < piece.Hash.PieceSize {
		return Error.New("piece hash PieceSize (%d) is larger than order limit (%d). NodeID: %v, PieceNum: %d",
			piece.Hash.PieceSize, limit.Limit, piece.NodeId, piece.PieceNum,
		)
	}

	return nil
}


@@ -20,7 +20,6 @@ import (
 	"storj.io/common/macaroon"
 	"storj.io/common/pb"
 	"storj.io/common/rpc/rpcstatus"
-	"storj.io/common/signing"
 	"storj.io/common/storj"
 	"storj.io/storj/pkg/auth"
 	"storj.io/storj/satellite/console"
@@ -409,38 +408,3 @@ func (endpoint *Endpoint) validateRedundancy(ctx context.Context, redundancy *pb
 	return nil
 }
-func (endpoint *Endpoint) validatePieceHash(ctx context.Context, piece *pb.RemotePiece, originalLimit *pb.OrderLimit, signee signing.Signee) (err error) {
-	defer mon.Task()(&ctx)(&err)
-	if piece.Hash == nil {
-		return errs.New("no piece hash. NodeID: %v, PieceNum: %d", piece.NodeId, piece.PieceNum)
-	}
-	err = signing.VerifyPieceHashSignature(ctx, signee, piece.Hash)
-	if err != nil {
-		return errs.New("piece hash signature could not be verified for node (NodeID: %v, PieceNum: %d): %+v",
-			piece.NodeId, piece.PieceNum, err,
-		)
-	}
-	timestamp := piece.Hash.Timestamp
-	if timestamp.Before(time.Now().Add(-pieceHashExpiration)) {
-		return errs.New("piece hash timestamp is too old (%v). NodeId: %v, PieceNum: %d)",
-			timestamp, piece.NodeId, piece.PieceNum,
-		)
-	}
-	switch {
-	case originalLimit.PieceId != piece.Hash.PieceId:
-		return errs.New("piece hash pieceID (%v) doesn't match limit pieceID (%v). NodeID: %v, PieceNum: %d",
-			piece.Hash.PieceId, originalLimit.PieceId, piece.NodeId, piece.PieceNum,
-		)
-	case originalLimit.Limit < piece.Hash.PieceSize:
-		return errs.New("piece hash PieceSize (%d) is larger than order limit (%d). NodeID: %v, PieceNum: %d",
-			piece.Hash.PieceSize, originalLimit.Limit, piece.NodeId, piece.PieceNum,
-		)
-	}
-	return nil
-}