Better metainfo Create/Commit request validation (#2088)

This commit is contained in:
Michal Niewrzal 2019-06-05 18:41:02 +02:00 committed by GitHub
parent b9d586901e
commit b5ac4f3eac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 453 additions and 132 deletions

View File

@ -4,7 +4,6 @@
package metainfo
import (
"bytes"
"context"
"errors"
"strconv"
@ -18,7 +17,6 @@ import (
monkit "gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/pkg/accounting"
"storj.io/storj/pkg/auth"
"storj.io/storj/pkg/eestream"
"storj.io/storj/pkg/identity"
"storj.io/storj/pkg/macaroon"
@ -53,13 +51,14 @@ type Containment interface {
// Endpoint metainfo endpoint
type Endpoint struct {
log *zap.Logger
metainfo *Service
orders *orders.Service
cache *overlay.Cache
projectUsage *accounting.ProjectUsage
containment Containment
apiKeys APIKeys
log *zap.Logger
metainfo *Service
orders *orders.Service
cache *overlay.Cache
projectUsage *accounting.ProjectUsage
containment Containment
apiKeys APIKeys
createRequests *createRequests
}
// NewEndpoint creates new metainfo endpoint instance
@ -67,49 +66,20 @@ func NewEndpoint(log *zap.Logger, metainfo *Service, orders *orders.Service, cac
apiKeys APIKeys, projectUsage *accounting.ProjectUsage) *Endpoint {
// TODO do something with too many params
return &Endpoint{
log: log,
metainfo: metainfo,
orders: orders,
cache: cache,
containment: containment,
apiKeys: apiKeys,
projectUsage: projectUsage,
log: log,
metainfo: metainfo,
orders: orders,
cache: cache,
containment: containment,
apiKeys: apiKeys,
projectUsage: projectUsage,
createRequests: newCreateRequests(),
}
}
// Close closes resources
func (endpoint *Endpoint) Close() error { return nil }
func (endpoint *Endpoint) validateAuth(ctx context.Context, action macaroon.Action) (_ *console.APIKeyInfo, err error) {
defer mon.Task()(&ctx)(&err)
keyData, ok := auth.GetAPIKey(ctx)
if !ok {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
key, err := macaroon.ParseAPIKey(string(keyData))
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
keyInfo, err := endpoint.apiKeys.GetByHead(ctx, key.Head())
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
// Revocations are currently handled by just deleting the key.
err = key.Check(ctx, keyInfo.Secret, action, nil)
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
return keyInfo, nil
}
// SegmentInfo returns segment metadata info
func (endpoint *Endpoint) SegmentInfo(ctx context.Context, req *pb.SegmentInfoRequest) (resp *pb.SegmentInfoResponse, err error) {
defer mon.Task()(&ctx)(&err)
@ -209,6 +179,13 @@ func (endpoint *Endpoint) CreateSegment(ctx context.Context, req *pb.SegmentWrit
return nil, Error.Wrap(err)
}
if len(addressedLimits) > 0 {
endpoint.createRequests.Put(addressedLimits[0].Limit.SerialNumber, &createRequest{
Expiration: req.Expiration,
Redundancy: req.Redundancy,
})
}
return &pb.SegmentWriteResponse{AddressedLimits: addressedLimits, RootPieceId: rootPieceID}, nil
}
@ -247,7 +224,7 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}
err = endpoint.validateCommit(ctx, req)
err = endpoint.validateCommitSegment(ctx, req)
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
}
@ -288,6 +265,10 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm
return nil, status.Errorf(codes.Internal, err.Error())
}
if len(req.OriginalLimits) > 0 {
endpoint.createRequests.Remove(req.OriginalLimits[0].SerialNumber)
}
return &pb.SegmentCommitResponse{Pointer: pointer}, nil
}
@ -503,78 +484,6 @@ func (endpoint *Endpoint) filterValidPieces(ctx context.Context, pointer *pb.Poi
return nil
}
func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (err error) {
defer mon.Task()(&ctx)(&err)
if len(bucket) == 0 {
return errs.New("bucket not specified")
}
if bytes.ContainsAny(bucket, "/") {
return errs.New("bucket should not contain slash")
}
return nil
}
func (endpoint *Endpoint) validateCommit(ctx context.Context, req *pb.SegmentCommitRequest) (err error) {
defer mon.Task()(&ctx)(&err)
err = endpoint.validatePointer(ctx, req.Pointer)
if err != nil {
return err
}
if req.Pointer.Type == pb.Pointer_REMOTE {
remote := req.Pointer.Remote
if len(req.OriginalLimits) == 0 {
return Error.New("no order limits")
}
if int32(len(req.OriginalLimits)) != remote.Redundancy.Total {
return Error.New("invalid no order limit for piece")
}
for _, piece := range remote.RemotePieces {
limit := req.OriginalLimits[piece.PieceNum]
err := endpoint.orders.VerifyOrderLimitSignature(ctx, limit)
if err != nil {
return err
}
if limit == nil {
return Error.New("invalid no order limit for piece")
}
derivedPieceID := remote.RootPieceId.Derive(piece.NodeId)
if limit.PieceId.IsZero() || limit.PieceId != derivedPieceID {
return Error.New("invalid order limit piece id")
}
if bytes.Compare(piece.NodeId.Bytes(), limit.StorageNodeId.Bytes()) != 0 {
return Error.New("piece NodeID != order limit NodeID")
}
}
}
return nil
}
func (endpoint *Endpoint) validatePointer(ctx context.Context, pointer *pb.Pointer) (err error) {
defer mon.Task()(&ctx)(&err)
if pointer == nil {
return Error.New("no pointer specified")
}
// TODO does it all?
if pointer.Type == pb.Pointer_REMOTE {
if pointer.Remote == nil {
return Error.New("no remote segment specified")
}
if pointer.Remote.RemotePieces == nil {
return Error.New("no remote segment pieces specified")
}
if pointer.Remote.Redundancy == nil {
return Error.New("no redundancy scheme specified")
}
}
return nil
}
// CreatePath will create a Segment path
func CreatePath(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bucket, path []byte) (_ storj.Path, err error) {
defer mon.Task()(&ctx)(&err)
@ -597,12 +506,3 @@ func CreatePath(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bu
}
return storj.JoinPaths(entries...), nil
}
func (endpoint *Endpoint) validateRedundancy(ctx context.Context, redundancy *pb.RedundancyScheme) (err error) {
defer mon.Task()(&ctx)(&err)
// TODO more validation, use validation from eestream.NewRedundancyStrategy
if redundancy.ErasureShareSize <= 0 {
return Error.New("erasure share size cannot be less than 0")
}
return nil
}

View File

@ -9,19 +9,21 @@ import (
"testing"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/golang/protobuf/ptypes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeebo/errs"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"storj.io/storj/internal/memory"
"storj.io/storj/internal/testcontext"
"storj.io/storj/internal/testplanet"
"storj.io/storj/pkg/macaroon"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
"storj.io/storj/satellite/console"
"storj.io/storj/uplink/metainfo"
)
// mockAPIKeys is mock for api keys store of pointerdb
@ -298,7 +300,8 @@ func TestCommitSegment(t *testing.T) {
Total: 6,
ErasureShareSize: 10,
}
addresedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "bucket", "path", -1, redundancy, 1000, time.Now())
expirationDate := time.Now()
addresedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "bucket", "path", -1, redundancy, 1000, expirationDate)
require.NoError(t, err)
// create number of pieces below repair threshold
@ -310,6 +313,10 @@ func TestCommitSegment(t *testing.T) {
NodeId: limit.Limit.StorageNodeId,
}
}
expirationDateProto, err := ptypes.TimestampProto(expirationDate)
require.NoError(t, err)
pointer := &pb.Pointer{
Type: pb.Pointer_REMOTE,
Remote: &pb.RemoteSegment{
@ -317,6 +324,7 @@ func TestCommitSegment(t *testing.T) {
Redundancy: redundancy,
RemotePieces: pieces,
},
ExpirationDate: expirationDateProto,
}
limits := make([]*pb.OrderLimit2, len(addresedLimits))
@ -329,3 +337,148 @@ func TestCommitSegment(t *testing.T) {
}
})
}
func TestDoubleCommitSegment(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 6, UplinkCount: 1,
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
apiKey := planet.Uplinks[0].APIKey[planet.Satellites[0].ID()]
metainfo, err := planet.Uplinks[0].DialMetainfo(ctx, planet.Satellites[0], apiKey)
require.NoError(t, err)
pointer, limits := runCreateSegment(ctx, t, metainfo)
_, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits)
require.NoError(t, err)
_, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits)
require.Error(t, err)
require.Contains(t, err.Error(), "missing create request or request expired")
})
}
func TestCommitSegmentPointer(t *testing.T) {
// all tests needs to generate error
tests := []struct {
// defines how modify pointer before CommitSegment
Modify func(pointer *pb.Pointer)
ErrorMessage string
}{
{
Modify: func(pointer *pb.Pointer) {
pointer.ExpirationDate.Seconds += 100
},
ErrorMessage: "pointer expiration date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.MinReq += 100
},
ErrorMessage: "pointer redundancy scheme date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.RepairThreshold += 100
},
ErrorMessage: "pointer redundancy scheme date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.SuccessThreshold += 100
},
ErrorMessage: "pointer redundancy scheme date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.Total += 100
},
// this error is triggered earlier then Create/Commit RS comparison
ErrorMessage: "invalid no order limit for piece",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.ErasureShareSize += 100
},
ErrorMessage: "pointer redundancy scheme date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.Type = 100
},
ErrorMessage: "pointer redundancy scheme date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Type = pb.Pointer_INLINE
},
ErrorMessage: "pointer type is INLINE but remote segment is set",
},
}
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 6, UplinkCount: 1,
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
apiKey := planet.Uplinks[0].APIKey[planet.Satellites[0].ID()]
metainfo, err := planet.Uplinks[0].DialMetainfo(ctx, planet.Satellites[0], apiKey)
require.NoError(t, err)
for _, test := range tests {
pointer, limits := runCreateSegment(ctx, t, metainfo)
test.Modify(pointer)
_, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits)
require.Error(t, err)
require.Contains(t, err.Error(), test.ErrorMessage)
}
})
}
func runCreateSegment(ctx context.Context, t *testing.T, metainfo metainfo.Client) (*pb.Pointer, []*pb.OrderLimit2) {
pointer := createTestPointer(t)
expirationDate, err := ptypes.Timestamp(pointer.ExpirationDate)
require.NoError(t, err)
addressedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "myBucketName", "file/path", -1, pointer.Remote.Redundancy, memory.MiB.Int64(), expirationDate)
require.NoError(t, err)
pointer.Remote.RootPieceId = rootPieceID
pointer.Remote.RemotePieces[0].NodeId = addressedLimits[0].Limit.StorageNodeId
pointer.Remote.RemotePieces[1].NodeId = addressedLimits[1].Limit.StorageNodeId
limits := make([]*pb.OrderLimit2, len(addressedLimits))
for i, addressedLimit := range addressedLimits {
limits[i] = addressedLimit.Limit
}
return pointer, limits
}
func createTestPointer(t *testing.T) *pb.Pointer {
rs := &pb.RedundancyScheme{
MinReq: 1,
RepairThreshold: 1,
SuccessThreshold: 3,
Total: 4,
ErasureShareSize: 1024,
Type: pb.RedundancyScheme_RS,
}
pointer := &pb.Pointer{
Type: pb.Pointer_REMOTE,
Remote: &pb.RemoteSegment{
Redundancy: rs,
RemotePieces: []*pb.RemotePiece{
&pb.RemotePiece{
PieceNum: 0,
},
&pb.RemotePiece{
PieceNum: 1,
},
},
},
ExpirationDate: ptypes.TimestampNow(),
}
return pointer
}

View File

@ -0,0 +1,268 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package metainfo
import (
"bytes"
"context"
"sync"
"time"
"github.com/gogo/protobuf/proto"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/zeebo/errs"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"storj.io/storj/pkg/auth"
"storj.io/storj/pkg/macaroon"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
"storj.io/storj/satellite/console"
)
const requestTTL = time.Hour * 4
// TTLItem keeps association between serial number and ttl
type TTLItem struct {
serialNumber storj.SerialNumber
ttl time.Time
}
type createRequest struct {
Expiration *timestamp.Timestamp
Redundancy *pb.RedundancyScheme
ttl time.Time
}
type createRequests struct {
mu sync.RWMutex
// orders limit serial number used because with CreateSegment we don't have path yet
entries map[storj.SerialNumber]*createRequest
muTTL sync.Mutex
entriesTTL []*TTLItem
}
func newCreateRequests() *createRequests {
return &createRequests{
entries: make(map[storj.SerialNumber]*createRequest),
entriesTTL: make([]*TTLItem, 0),
}
}
func (requests *createRequests) Put(serialNumber storj.SerialNumber, createRequest *createRequest) {
ttl := time.Now().Add(requestTTL)
go func() {
requests.muTTL.Lock()
requests.entriesTTL = append(requests.entriesTTL, &TTLItem{
serialNumber: serialNumber,
ttl: ttl,
})
requests.muTTL.Unlock()
}()
createRequest.ttl = ttl
requests.mu.Lock()
requests.entries[serialNumber] = createRequest
requests.mu.Unlock()
go requests.cleanup()
}
func (requests *createRequests) Load(serialNumber storj.SerialNumber) (*createRequest, bool) {
requests.mu.RLock()
request, found := requests.entries[serialNumber]
if request != nil && request.ttl.Before(time.Now()) {
request = nil
found = false
}
requests.mu.RUnlock()
return request, found
}
func (requests *createRequests) Remove(serialNumber storj.SerialNumber) {
requests.mu.Lock()
delete(requests.entries, serialNumber)
requests.mu.Unlock()
}
func (requests *createRequests) cleanup() {
requests.muTTL.Lock()
now := time.Now()
remove := make([]storj.SerialNumber, 0)
newStart := 0
for i, item := range requests.entriesTTL {
if item.ttl.Before(now) {
remove = append(remove, item.serialNumber)
newStart = i + 1
} else {
break
}
}
requests.entriesTTL = requests.entriesTTL[newStart:]
requests.muTTL.Unlock()
for _, serialNumber := range remove {
requests.Remove(serialNumber)
}
}
func (endpoint *Endpoint) validateAuth(ctx context.Context, action macaroon.Action) (_ *console.APIKeyInfo, err error) {
defer mon.Task()(&ctx)(&err)
keyData, ok := auth.GetAPIKey(ctx)
if !ok {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
key, err := macaroon.ParseAPIKey(string(keyData))
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
keyInfo, err := endpoint.apiKeys.GetByHead(ctx, key.Head())
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
// Revocations are currently handled by just deleting the key.
err = key.Check(ctx, keyInfo.Secret, action, nil)
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
return keyInfo, nil
}
func (endpoint *Endpoint) validateCreateSegment(ctx context.Context, req *pb.SegmentWriteRequest) (err error) {
defer mon.Task()(&ctx)(&err)
err = endpoint.validateBucket(ctx, req.Bucket)
if err != nil {
return err
}
err = endpoint.validateRedundancy(ctx, req.Redundancy)
if err != nil {
return err
}
return nil
}
func (endpoint *Endpoint) validateCommitSegment(ctx context.Context, req *pb.SegmentCommitRequest) (err error) {
defer mon.Task()(&ctx)(&err)
err = endpoint.validateBucket(ctx, req.Bucket)
if err != nil {
return err
}
err = endpoint.validatePointer(ctx, req.Pointer)
if err != nil {
return err
}
if req.Pointer.Type == pb.Pointer_REMOTE {
remote := req.Pointer.Remote
if len(req.OriginalLimits) == 0 {
return Error.New("no order limits")
}
if int32(len(req.OriginalLimits)) != remote.Redundancy.Total {
return Error.New("invalid no order limit for piece")
}
for _, piece := range remote.RemotePieces {
limit := req.OriginalLimits[piece.PieceNum]
err := endpoint.orders.VerifyOrderLimitSignature(ctx, limit)
if err != nil {
return err
}
if limit == nil {
return Error.New("invalid no order limit for piece")
}
derivedPieceID := remote.RootPieceId.Derive(piece.NodeId)
if limit.PieceId.IsZero() || limit.PieceId != derivedPieceID {
return Error.New("invalid order limit piece id")
}
if bytes.Compare(piece.NodeId.Bytes(), limit.StorageNodeId.Bytes()) != 0 {
return Error.New("piece NodeID != order limit NodeID")
}
}
}
if len(req.OriginalLimits) > 0 {
createRequest, found := endpoint.createRequests.Load(req.OriginalLimits[0].SerialNumber)
switch {
case !found:
return Error.New("missing create request or request expired")
case !proto.Equal(createRequest.Expiration, req.Pointer.ExpirationDate):
return Error.New("pointer expiration date does not match requested one")
case !proto.Equal(createRequest.Redundancy, req.Pointer.Remote.Redundancy):
return Error.New("pointer redundancy scheme date does not match requested one")
}
}
return nil
}
func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (err error) {
defer mon.Task()(&ctx)(&err)
if len(bucket) == 0 {
return errs.New("bucket not specified")
}
if bytes.ContainsAny(bucket, "/") {
return errs.New("bucket should not contain slash")
}
return nil
}
func (endpoint *Endpoint) validatePointer(ctx context.Context, pointer *pb.Pointer) (err error) {
defer mon.Task()(&ctx)(&err)
if pointer == nil {
return Error.New("no pointer specified")
}
if pointer.Type == pb.Pointer_INLINE && pointer.Remote != nil {
return Error.New("pointer type is INLINE but remote segment is set")
}
// TODO does it all?
if pointer.Type == pb.Pointer_REMOTE {
if pointer.Remote == nil {
return Error.New("no remote segment specified")
}
if pointer.Remote.RemotePieces == nil {
return Error.New("no remote segment pieces specified")
}
if pointer.Remote.Redundancy == nil {
return Error.New("no redundancy scheme specified")
}
}
return nil
}
func (endpoint *Endpoint) validateRedundancy(ctx context.Context, redundancy *pb.RedundancyScheme) (err error) {
defer mon.Task()(&ctx)(&err)
// TODO more validation, use validation from eestream.NewRedundancyStrategy
if redundancy.ErasureShareSize <= 0 {
return Error.New("erasure share size cannot be less than 0")
}
return nil
}