Better metainfo Create/Commit request validation (#2088)

Michal Niewrzal 2019-06-05 18:41:02 +02:00 committed by GitHub
parent b9d586901e
commit b5ac4f3eac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 453 additions and 132 deletions


@@ -4,7 +4,6 @@
package metainfo
import (
"bytes"
"context"
"errors"
"strconv"
@@ -18,7 +17,6 @@ import (
monkit "gopkg.in/spacemonkeygo/monkit.v2"
"storj.io/storj/pkg/accounting"
"storj.io/storj/pkg/auth"
"storj.io/storj/pkg/eestream"
"storj.io/storj/pkg/identity"
"storj.io/storj/pkg/macaroon"
@@ -60,6 +58,7 @@ type Endpoint struct {
projectUsage *accounting.ProjectUsage
containment Containment
apiKeys APIKeys
createRequests *createRequests
}
// NewEndpoint creates new metainfo endpoint instance
@@ -74,42 +73,13 @@ func NewEndpoint(log *zap.Logger, metainfo *Service, orders *orders.Service, cac
containment: containment,
apiKeys: apiKeys,
projectUsage: projectUsage,
createRequests: newCreateRequests(),
}
}
// Close closes resources
func (endpoint *Endpoint) Close() error { return nil }
func (endpoint *Endpoint) validateAuth(ctx context.Context, action macaroon.Action) (_ *console.APIKeyInfo, err error) {
defer mon.Task()(&ctx)(&err)
keyData, ok := auth.GetAPIKey(ctx)
if !ok {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
key, err := macaroon.ParseAPIKey(string(keyData))
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
keyInfo, err := endpoint.apiKeys.GetByHead(ctx, key.Head())
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
// Revocations are currently handled by just deleting the key.
err = key.Check(ctx, keyInfo.Secret, action, nil)
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
return keyInfo, nil
}
// SegmentInfo returns segment metadata info
func (endpoint *Endpoint) SegmentInfo(ctx context.Context, req *pb.SegmentInfoRequest) (resp *pb.SegmentInfoResponse, err error) {
defer mon.Task()(&ctx)(&err)
@@ -209,6 +179,13 @@ func (endpoint *Endpoint) CreateSegment(ctx context.Context, req *pb.SegmentWrit
return nil, Error.Wrap(err)
}
if len(addressedLimits) > 0 {
endpoint.createRequests.Put(addressedLimits[0].Limit.SerialNumber, &createRequest{
Expiration: req.Expiration,
Redundancy: req.Redundancy,
})
}
return &pb.SegmentWriteResponse{AddressedLimits: addressedLimits, RootPieceId: rootPieceID}, nil
}
@@ -247,7 +224,7 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}
err = endpoint.validateCommit(ctx, req)
err = endpoint.validateCommitSegment(ctx, req)
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
}
@@ -288,6 +265,10 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm
return nil, status.Errorf(codes.Internal, err.Error())
}
if len(req.OriginalLimits) > 0 {
endpoint.createRequests.Remove(req.OriginalLimits[0].SerialNumber)
}
return &pb.SegmentCommitResponse{Pointer: pointer}, nil
}
@@ -503,78 +484,6 @@ func (endpoint *Endpoint) filterValidPieces(ctx context.Context, pointer *pb.Poi
return nil
}
func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (err error) {
defer mon.Task()(&ctx)(&err)
if len(bucket) == 0 {
return errs.New("bucket not specified")
}
if bytes.ContainsAny(bucket, "/") {
return errs.New("bucket should not contain slash")
}
return nil
}
func (endpoint *Endpoint) validateCommit(ctx context.Context, req *pb.SegmentCommitRequest) (err error) {
defer mon.Task()(&ctx)(&err)
err = endpoint.validatePointer(ctx, req.Pointer)
if err != nil {
return err
}
if req.Pointer.Type == pb.Pointer_REMOTE {
remote := req.Pointer.Remote
if len(req.OriginalLimits) == 0 {
return Error.New("no order limits")
}
if int32(len(req.OriginalLimits)) != remote.Redundancy.Total {
return Error.New("invalid no order limit for piece")
}
for _, piece := range remote.RemotePieces {
limit := req.OriginalLimits[piece.PieceNum]
err := endpoint.orders.VerifyOrderLimitSignature(ctx, limit)
if err != nil {
return err
}
if limit == nil {
return Error.New("invalid no order limit for piece")
}
derivedPieceID := remote.RootPieceId.Derive(piece.NodeId)
if limit.PieceId.IsZero() || limit.PieceId != derivedPieceID {
return Error.New("invalid order limit piece id")
}
if bytes.Compare(piece.NodeId.Bytes(), limit.StorageNodeId.Bytes()) != 0 {
return Error.New("piece NodeID != order limit NodeID")
}
}
}
return nil
}
func (endpoint *Endpoint) validatePointer(ctx context.Context, pointer *pb.Pointer) (err error) {
defer mon.Task()(&ctx)(&err)
if pointer == nil {
return Error.New("no pointer specified")
}
// TODO does it all?
if pointer.Type == pb.Pointer_REMOTE {
if pointer.Remote == nil {
return Error.New("no remote segment specified")
}
if pointer.Remote.RemotePieces == nil {
return Error.New("no remote segment pieces specified")
}
if pointer.Remote.Redundancy == nil {
return Error.New("no redundancy scheme specified")
}
}
return nil
}
// CreatePath will create a Segment path
func CreatePath(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bucket, path []byte) (_ storj.Path, err error) {
defer mon.Task()(&ctx)(&err)
@@ -597,12 +506,3 @@ func CreatePath(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bu
}
return storj.JoinPaths(entries...), nil
}
func (endpoint *Endpoint) validateRedundancy(ctx context.Context, redundancy *pb.RedundancyScheme) (err error) {
defer mon.Task()(&ctx)(&err)
// TODO more validation, use validation from eestream.NewRedundancyStrategy
if redundancy.ErasureShareSize <= 0 {
return Error.New("erasure share size cannot be less than 0")
}
return nil
}


@@ -9,19 +9,21 @@ import (
"testing"
"time"
"google.golang.org/grpc/codes"
"github.com/golang/protobuf/ptypes"
"google.golang.org/grpc/status"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeebo/errs"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"storj.io/storj/internal/memory"
"storj.io/storj/internal/testcontext"
"storj.io/storj/internal/testplanet"
"storj.io/storj/pkg/macaroon"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
"storj.io/storj/satellite/console"
"storj.io/storj/uplink/metainfo"
)
// mockAPIKeys is mock for api keys store of pointerdb
@@ -298,7 +300,8 @@ func TestCommitSegment(t *testing.T) {
Total: 6,
ErasureShareSize: 10,
}
addresedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "bucket", "path", -1, redundancy, 1000, time.Now())
expirationDate := time.Now()
addresedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "bucket", "path", -1, redundancy, 1000, expirationDate)
require.NoError(t, err)
// create number of pieces below repair threshold
@@ -310,6 +313,10 @@ func TestCommitSegment(t *testing.T) {
NodeId: limit.Limit.StorageNodeId,
}
}
expirationDateProto, err := ptypes.TimestampProto(expirationDate)
require.NoError(t, err)
pointer := &pb.Pointer{
Type: pb.Pointer_REMOTE,
Remote: &pb.RemoteSegment{
@@ -317,6 +324,7 @@ func TestCommitSegment(t *testing.T) {
Redundancy: redundancy,
RemotePieces: pieces,
},
ExpirationDate: expirationDateProto,
}
limits := make([]*pb.OrderLimit2, len(addresedLimits))
@@ -329,3 +337,148 @@ func TestCommitSegment(t *testing.T) {
}
})
}
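// TestDoubleCommitSegment verifies that the same segment cannot be committed twice
// for a single create request.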
func TestDoubleCommitSegment(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 6, UplinkCount: 1,
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
apiKey := planet.Uplinks[0].APIKey[planet.Satellites[0].ID()]
metainfo, err := planet.Uplinks[0].DialMetainfo(ctx, planet.Satellites[0], apiKey)
require.NoError(t, err)
pointer, limits := runCreateSegment(ctx, t, metainfo)
_, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits)
require.NoError(t, err)
_, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits)
require.Error(t, err)
require.Contains(t, err.Error(), "missing create request or request expired")
})
}
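// TestCommitSegmentPointer verifies that CommitSegment rejects pointers that do not
// match the original create request.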
func TestCommitSegmentPointer(t *testing.T) {
// all tests need to generate an error
tests := []struct {
// defines how to modify the pointer before CommitSegment
Modify func(pointer *pb.Pointer)
ErrorMessage string
}{
{
Modify: func(pointer *pb.Pointer) {
pointer.ExpirationDate.Seconds += 100
},
ErrorMessage: "pointer expiration date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.MinReq += 100
},
ErrorMessage: "pointer redundancy scheme date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.RepairThreshold += 100
},
ErrorMessage: "pointer redundancy scheme date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.SuccessThreshold += 100
},
ErrorMessage: "pointer redundancy scheme date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.Total += 100
},
// this error is triggered earlier than the Create/Commit RS comparison
ErrorMessage: "invalid no order limit for piece",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.ErasureShareSize += 100
},
ErrorMessage: "pointer redundancy scheme date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Remote.Redundancy.Type = 100
},
ErrorMessage: "pointer redundancy scheme date does not match requested one",
},
{
Modify: func(pointer *pb.Pointer) {
pointer.Type = pb.Pointer_INLINE
},
ErrorMessage: "pointer type is INLINE but remote segment is set",
},
}
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 6, UplinkCount: 1,
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
apiKey := planet.Uplinks[0].APIKey[planet.Satellites[0].ID()]
metainfo, err := planet.Uplinks[0].DialMetainfo(ctx, planet.Satellites[0], apiKey)
require.NoError(t, err)
for _, test := range tests {
pointer, limits := runCreateSegment(ctx, t, metainfo)
test.Modify(pointer)
_, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits)
require.Error(t, err)
require.Contains(t, err.Error(), test.ErrorMessage)
}
})
}
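// runCreateSegment builds a test pointer, calls CreateSegment and returns the pointer
// populated with the resulting root piece ID and node IDs, together with the order limits.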
func runCreateSegment(ctx context.Context, t *testing.T, metainfo metainfo.Client) (*pb.Pointer, []*pb.OrderLimit2) {
pointer := createTestPointer(t)
expirationDate, err := ptypes.Timestamp(pointer.ExpirationDate)
require.NoError(t, err)
addressedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "myBucketName", "file/path", -1, pointer.Remote.Redundancy, memory.MiB.Int64(), expirationDate)
require.NoError(t, err)
pointer.Remote.RootPieceId = rootPieceID
pointer.Remote.RemotePieces[0].NodeId = addressedLimits[0].Limit.StorageNodeId
pointer.Remote.RemotePieces[1].NodeId = addressedLimits[1].Limit.StorageNodeId
limits := make([]*pb.OrderLimit2, len(addressedLimits))
for i, addressedLimit := range addressedLimits {
limits[i] = addressedLimit.Limit
}
return pointer, limits
}
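// createTestPointer returns a remote test pointer with a small redundancy scheme and two remote pieces.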
func createTestPointer(t *testing.T) *pb.Pointer {
rs := &pb.RedundancyScheme{
MinReq: 1,
RepairThreshold: 1,
SuccessThreshold: 3,
Total: 4,
ErasureShareSize: 1024,
Type: pb.RedundancyScheme_RS,
}
pointer := &pb.Pointer{
Type: pb.Pointer_REMOTE,
Remote: &pb.RemoteSegment{
Redundancy: rs,
RemotePieces: []*pb.RemotePiece{
&pb.RemotePiece{
PieceNum: 0,
},
&pb.RemotePiece{
PieceNum: 1,
},
},
},
ExpirationDate: ptypes.TimestampNow(),
}
return pointer
}


@@ -0,0 +1,268 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package metainfo
import (
"bytes"
"context"
"sync"
"time"
"github.com/gogo/protobuf/proto"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/zeebo/errs"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"storj.io/storj/pkg/auth"
"storj.io/storj/pkg/macaroon"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
"storj.io/storj/satellite/console"
)
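// requestTTL is how long a create request is kept before it expires.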
const requestTTL = time.Hour * 4
// TTLItem keeps the association between a serial number and its TTL
type TTLItem struct {
serialNumber storj.SerialNumber
ttl time.Time
}
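// createRequest holds the expiration and redundancy parameters from a CreateSegment
// request so the matching CommitSegment request can be validated against them.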
type createRequest struct {
Expiration *timestamp.Timestamp
Redundancy *pb.RedundancyScheme
ttl time.Time
}
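// createRequests is an in-memory, TTL-bounded cache of pending create requests.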
type createRequests struct {
mu sync.RWMutex
// keyed by the order limit serial number, because at CreateSegment time the path is not known yet
entries map[storj.SerialNumber]*createRequest
muTTL sync.Mutex
entriesTTL []*TTLItem
}
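// newCreateRequests initializes an empty createRequests cache.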
func newCreateRequests() *createRequests {
return &createRequests{
entries: make(map[storj.SerialNumber]*createRequest),
entriesTTL: make([]*TTLItem, 0),
}
}
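// Put stores a create request under the given serial number and schedules it to expire after requestTTL.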
func (requests *createRequests) Put(serialNumber storj.SerialNumber, createRequest *createRequest) {
ttl := time.Now().Add(requestTTL)
go func() {
requests.muTTL.Lock()
requests.entriesTTL = append(requests.entriesTTL, &TTLItem{
serialNumber: serialNumber,
ttl: ttl,
})
requests.muTTL.Unlock()
}()
createRequest.ttl = ttl
requests.mu.Lock()
requests.entries[serialNumber] = createRequest
requests.mu.Unlock()
go requests.cleanup()
}
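// Load returns the create request stored under the given serial number, if it exists and has not expired.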
func (requests *createRequests) Load(serialNumber storj.SerialNumber) (*createRequest, bool) {
requests.mu.RLock()
request, found := requests.entries[serialNumber]
if request != nil && request.ttl.Before(time.Now()) {
request = nil
found = false
}
requests.mu.RUnlock()
return request, found
}
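// Remove deletes the create request stored under the given serial number.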
func (requests *createRequests) Remove(serialNumber storj.SerialNumber) {
requests.mu.Lock()
delete(requests.entries, serialNumber)
requests.mu.Unlock()
}
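// cleanup drops expired entries from the front of the TTL list and removes the matching create requests.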
func (requests *createRequests) cleanup() {
requests.muTTL.Lock()
now := time.Now()
remove := make([]storj.SerialNumber, 0)
newStart := 0
for i, item := range requests.entriesTTL {
if item.ttl.Before(now) {
remove = append(remove, item.serialNumber)
newStart = i + 1
} else {
break
}
}
requests.entriesTTL = requests.entriesTTL[newStart:]
requests.muTTL.Unlock()
for _, serialNumber := range remove {
requests.Remove(serialNumber)
}
}
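// validateAuth validates the API key from the request context against the requested action
// and returns the matching key info.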
func (endpoint *Endpoint) validateAuth(ctx context.Context, action macaroon.Action) (_ *console.APIKeyInfo, err error) {
defer mon.Task()(&ctx)(&err)
keyData, ok := auth.GetAPIKey(ctx)
if !ok {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
key, err := macaroon.ParseAPIKey(string(keyData))
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
keyInfo, err := endpoint.apiKeys.GetByHead(ctx, key.Head())
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
// Revocations are currently handled by just deleting the key.
err = key.Check(ctx, keyInfo.Secret, action, nil)
if err != nil {
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
}
return keyInfo, nil
}
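// validateCreateSegment validates the bucket name and redundancy scheme of a CreateSegment request.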
func (endpoint *Endpoint) validateCreateSegment(ctx context.Context, req *pb.SegmentWriteRequest) (err error) {
defer mon.Task()(&ctx)(&err)
err = endpoint.validateBucket(ctx, req.Bucket)
if err != nil {
return err
}
err = endpoint.validateRedundancy(ctx, req.Redundancy)
if err != nil {
return err
}
return nil
}
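// validateCommitSegment validates the bucket, pointer and order limits of a CommitSegment request
// and ensures the pointer matches the original create request.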
func (endpoint *Endpoint) validateCommitSegment(ctx context.Context, req *pb.SegmentCommitRequest) (err error) {
defer mon.Task()(&ctx)(&err)
err = endpoint.validateBucket(ctx, req.Bucket)
if err != nil {
return err
}
err = endpoint.validatePointer(ctx, req.Pointer)
if err != nil {
return err
}
if req.Pointer.Type == pb.Pointer_REMOTE {
remote := req.Pointer.Remote
if len(req.OriginalLimits) == 0 {
return Error.New("no order limits")
}
if int32(len(req.OriginalLimits)) != remote.Redundancy.Total {
return Error.New("invalid no order limit for piece")
}
for _, piece := range remote.RemotePieces {
limit := req.OriginalLimits[piece.PieceNum]
if limit == nil {
return Error.New("invalid no order limit for piece")
}
err := endpoint.orders.VerifyOrderLimitSignature(ctx, limit)
if err != nil {
return err
}
derivedPieceID := remote.RootPieceId.Derive(piece.NodeId)
if limit.PieceId.IsZero() || limit.PieceId != derivedPieceID {
return Error.New("invalid order limit piece id")
}
if bytes.Compare(piece.NodeId.Bytes(), limit.StorageNodeId.Bytes()) != 0 {
return Error.New("piece NodeID != order limit NodeID")
}
}
}
if len(req.OriginalLimits) > 0 {
createRequest, found := endpoint.createRequests.Load(req.OriginalLimits[0].SerialNumber)
switch {
case !found:
return Error.New("missing create request or request expired")
case !proto.Equal(createRequest.Expiration, req.Pointer.ExpirationDate):
return Error.New("pointer expiration date does not match requested one")
case !proto.Equal(createRequest.Redundancy, req.Pointer.Remote.Redundancy):
return Error.New("pointer redundancy scheme date does not match requested one")
}
}
return nil
}
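// validateBucket checks that the bucket name is not empty and does not contain a slash.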
func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (err error) {
defer mon.Task()(&ctx)(&err)
if len(bucket) == 0 {
return errs.New("bucket not specified")
}
if bytes.ContainsAny(bucket, "/") {
return errs.New("bucket should not contain slash")
}
return nil
}
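// validatePointer checks that the pointer is set and consistent with its type.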
func (endpoint *Endpoint) validatePointer(ctx context.Context, pointer *pb.Pointer) (err error) {
defer mon.Task()(&ctx)(&err)
if pointer == nil {
return Error.New("no pointer specified")
}
if pointer.Type == pb.Pointer_INLINE && pointer.Remote != nil {
return Error.New("pointer type is INLINE but remote segment is set")
}
// TODO does it all?
if pointer.Type == pb.Pointer_REMOTE {
if pointer.Remote == nil {
return Error.New("no remote segment specified")
}
if pointer.Remote.RemotePieces == nil {
return Error.New("no remote segment pieces specified")
}
if pointer.Remote.Redundancy == nil {
return Error.New("no redundancy scheme specified")
}
}
return nil
}
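// validateRedundancy performs basic validation of a redundancy scheme.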
func (endpoint *Endpoint) validateRedundancy(ctx context.Context, redundancy *pb.RedundancyScheme) (err error) {
defer mon.Task()(&ctx)(&err)
// TODO more validation, use validation from eestream.NewRedundancyStrategy
if redundancy.ErasureShareSize <= 0 {
return Error.New("erasure share size cannot be less than 0")
}
return nil
}