From b5ac4f3eac5d267805fb2062e722e99528f22132 Mon Sep 17 00:00:00 2001 From: Michal Niewrzal Date: Wed, 5 Jun 2019 18:41:02 +0200 Subject: [PATCH] Better metainfo Create/Commit request validation (#2088) --- satellite/metainfo/metainfo.go | 156 +++------------- satellite/metainfo/metainfo_test.go | 161 ++++++++++++++++- satellite/metainfo/validation.go | 268 ++++++++++++++++++++++++++++ 3 files changed, 453 insertions(+), 132 deletions(-) create mode 100644 satellite/metainfo/validation.go diff --git a/satellite/metainfo/metainfo.go b/satellite/metainfo/metainfo.go index 3893dcfde..d4ec46b07 100644 --- a/satellite/metainfo/metainfo.go +++ b/satellite/metainfo/metainfo.go @@ -4,7 +4,6 @@ package metainfo import ( - "bytes" "context" "errors" "strconv" @@ -18,7 +17,6 @@ import ( monkit "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/pkg/accounting" - "storj.io/storj/pkg/auth" "storj.io/storj/pkg/eestream" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/macaroon" @@ -53,13 +51,14 @@ type Containment interface { // Endpoint metainfo endpoint type Endpoint struct { - log *zap.Logger - metainfo *Service - orders *orders.Service - cache *overlay.Cache - projectUsage *accounting.ProjectUsage - containment Containment - apiKeys APIKeys + log *zap.Logger + metainfo *Service + orders *orders.Service + cache *overlay.Cache + projectUsage *accounting.ProjectUsage + containment Containment + apiKeys APIKeys + createRequests *createRequests } // NewEndpoint creates new metainfo endpoint instance @@ -67,49 +66,20 @@ func NewEndpoint(log *zap.Logger, metainfo *Service, orders *orders.Service, cac apiKeys APIKeys, projectUsage *accounting.ProjectUsage) *Endpoint { // TODO do something with too many params return &Endpoint{ - log: log, - metainfo: metainfo, - orders: orders, - cache: cache, - containment: containment, - apiKeys: apiKeys, - projectUsage: projectUsage, + log: log, + metainfo: metainfo, + orders: orders, + cache: cache, + containment: containment, + apiKeys: apiKeys, + projectUsage: projectUsage, + createRequests: newCreateRequests(), } } // Close closes resources func (endpoint *Endpoint) Close() error { return nil } -func (endpoint *Endpoint) validateAuth(ctx context.Context, action macaroon.Action) (_ *console.APIKeyInfo, err error) { - defer mon.Task()(&ctx)(&err) - keyData, ok := auth.GetAPIKey(ctx) - if !ok { - endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential"))) - return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential") - } - - key, err := macaroon.ParseAPIKey(string(keyData)) - if err != nil { - endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential"))) - return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential") - } - - keyInfo, err := endpoint.apiKeys.GetByHead(ctx, key.Head()) - if err != nil { - endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error()))) - return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential") - } - - // Revocations are currently handled by just deleting the key. - err = key.Check(ctx, keyInfo.Secret, action, nil) - if err != nil { - endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error()))) - return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential") - } - - return keyInfo, nil -} - // SegmentInfo returns segment metadata info func (endpoint *Endpoint) SegmentInfo(ctx context.Context, req *pb.SegmentInfoRequest) (resp *pb.SegmentInfoResponse, err error) { defer mon.Task()(&ctx)(&err) @@ -209,6 +179,13 @@ func (endpoint *Endpoint) CreateSegment(ctx context.Context, req *pb.SegmentWrit return nil, Error.Wrap(err) } + if len(addressedLimits) > 0 { + endpoint.createRequests.Put(addressedLimits[0].Limit.SerialNumber, &createRequest{ + Expiration: req.Expiration, + Redundancy: req.Redundancy, + }) + } + return &pb.SegmentWriteResponse{AddressedLimits: addressedLimits, RootPieceId: rootPieceID}, nil } @@ -247,7 +224,7 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm return nil, status.Errorf(codes.InvalidArgument, err.Error()) } - err = endpoint.validateCommit(ctx, req) + err = endpoint.validateCommitSegment(ctx, req) if err != nil { return nil, status.Errorf(codes.Internal, err.Error()) } @@ -288,6 +265,10 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm return nil, status.Errorf(codes.Internal, err.Error()) } + if len(req.OriginalLimits) > 0 { + endpoint.createRequests.Remove(req.OriginalLimits[0].SerialNumber) + } + return &pb.SegmentCommitResponse{Pointer: pointer}, nil } @@ -503,78 +484,6 @@ func (endpoint *Endpoint) filterValidPieces(ctx context.Context, pointer *pb.Poi return nil } -func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (err error) { - defer mon.Task()(&ctx)(&err) - if len(bucket) == 0 { - return errs.New("bucket not specified") - } - if bytes.ContainsAny(bucket, "/") { - return errs.New("bucket should not contain slash") - } - return nil -} - -func (endpoint *Endpoint) validateCommit(ctx context.Context, req *pb.SegmentCommitRequest) (err error) { - defer mon.Task()(&ctx)(&err) - err = endpoint.validatePointer(ctx, req.Pointer) - if err != nil { - return err - } - - if req.Pointer.Type == pb.Pointer_REMOTE { - remote := req.Pointer.Remote - - if len(req.OriginalLimits) == 0 { - return Error.New("no order limits") - } - if int32(len(req.OriginalLimits)) != remote.Redundancy.Total { - return Error.New("invalid no order limit for piece") - } - - for _, piece := range remote.RemotePieces { - limit := req.OriginalLimits[piece.PieceNum] - - err := endpoint.orders.VerifyOrderLimitSignature(ctx, limit) - if err != nil { - return err - } - - if limit == nil { - return Error.New("invalid no order limit for piece") - } - derivedPieceID := remote.RootPieceId.Derive(piece.NodeId) - if limit.PieceId.IsZero() || limit.PieceId != derivedPieceID { - return Error.New("invalid order limit piece id") - } - if bytes.Compare(piece.NodeId.Bytes(), limit.StorageNodeId.Bytes()) != 0 { - return Error.New("piece NodeID != order limit NodeID") - } - } - } - return nil -} - -func (endpoint *Endpoint) validatePointer(ctx context.Context, pointer *pb.Pointer) (err error) { - defer mon.Task()(&ctx)(&err) - if pointer == nil { - return Error.New("no pointer specified") - } - - // TODO does it all? - if pointer.Type == pb.Pointer_REMOTE { - if pointer.Remote == nil { - return Error.New("no remote segment specified") - } - if pointer.Remote.RemotePieces == nil { - return Error.New("no remote segment pieces specified") - } - if pointer.Remote.Redundancy == nil { - return Error.New("no redundancy scheme specified") - } - } - return nil -} - // CreatePath will create a Segment path func CreatePath(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bucket, path []byte) (_ storj.Path, err error) { defer mon.Task()(&ctx)(&err) @@ -597,12 +506,3 @@ func CreatePath(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bu } return storj.JoinPaths(entries...), nil } - -func (endpoint *Endpoint) validateRedundancy(ctx context.Context, redundancy *pb.RedundancyScheme) (err error) { - defer mon.Task()(&ctx)(&err) - // TODO more validation, use validation from eestream.NewRedundancyStrategy - if redundancy.ErasureShareSize <= 0 { - return Error.New("erasure share size cannot be less than 0") - } - return nil -} diff --git a/satellite/metainfo/metainfo_test.go b/satellite/metainfo/metainfo_test.go index 6a651e080..c4fb70658 100644 --- a/satellite/metainfo/metainfo_test.go +++ b/satellite/metainfo/metainfo_test.go @@ -9,19 +9,21 @@ import ( "testing" "time" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - + "github.com/golang/protobuf/ptypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zeebo/errs" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "storj.io/storj/internal/memory" "storj.io/storj/internal/testcontext" "storj.io/storj/internal/testplanet" "storj.io/storj/pkg/macaroon" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/storj" "storj.io/storj/satellite/console" + "storj.io/storj/uplink/metainfo" ) // mockAPIKeys is mock for api keys store of pointerdb @@ -298,7 +300,8 @@ func TestCommitSegment(t *testing.T) { Total: 6, ErasureShareSize: 10, } - addresedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "bucket", "path", -1, redundancy, 1000, time.Now()) + expirationDate := time.Now() + addresedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "bucket", "path", -1, redundancy, 1000, expirationDate) require.NoError(t, err) // create number of pieces below repair threshold @@ -310,6 +313,10 @@ func TestCommitSegment(t *testing.T) { NodeId: limit.Limit.StorageNodeId, } } + + expirationDateProto, err := ptypes.TimestampProto(expirationDate) + require.NoError(t, err) + pointer := &pb.Pointer{ Type: pb.Pointer_REMOTE, Remote: &pb.RemoteSegment{ @@ -317,6 +324,7 @@ func TestCommitSegment(t *testing.T) { Redundancy: redundancy, RemotePieces: pieces, }, + ExpirationDate: expirationDateProto, } limits := make([]*pb.OrderLimit2, len(addresedLimits)) @@ -329,3 +337,148 @@ func TestCommitSegment(t *testing.T) { } }) } + +func TestDoubleCommitSegment(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, StorageNodeCount: 6, UplinkCount: 1, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + apiKey := planet.Uplinks[0].APIKey[planet.Satellites[0].ID()] + + metainfo, err := planet.Uplinks[0].DialMetainfo(ctx, planet.Satellites[0], apiKey) + require.NoError(t, err) + + pointer, limits := runCreateSegment(ctx, t, metainfo) + + _, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits) + require.NoError(t, err) + + _, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits) + require.Error(t, err) + require.Contains(t, err.Error(), "missing create request or request expired") + }) +} + +func TestCommitSegmentPointer(t *testing.T) { + // all tests needs to generate error + tests := []struct { + // defines how modify pointer before CommitSegment + Modify func(pointer *pb.Pointer) + ErrorMessage string + }{ + { + Modify: func(pointer *pb.Pointer) { + pointer.ExpirationDate.Seconds += 100 + }, + ErrorMessage: "pointer expiration date does not match requested one", + }, + { + Modify: func(pointer *pb.Pointer) { + pointer.Remote.Redundancy.MinReq += 100 + }, + ErrorMessage: "pointer redundancy scheme date does not match requested one", + }, + { + Modify: func(pointer *pb.Pointer) { + pointer.Remote.Redundancy.RepairThreshold += 100 + }, + ErrorMessage: "pointer redundancy scheme date does not match requested one", + }, + { + Modify: func(pointer *pb.Pointer) { + pointer.Remote.Redundancy.SuccessThreshold += 100 + }, + ErrorMessage: "pointer redundancy scheme date does not match requested one", + }, + { + Modify: func(pointer *pb.Pointer) { + pointer.Remote.Redundancy.Total += 100 + }, + // this error is triggered earlier then Create/Commit RS comparison + ErrorMessage: "invalid no order limit for piece", + }, + { + Modify: func(pointer *pb.Pointer) { + pointer.Remote.Redundancy.ErasureShareSize += 100 + }, + ErrorMessage: "pointer redundancy scheme date does not match requested one", + }, + { + Modify: func(pointer *pb.Pointer) { + pointer.Remote.Redundancy.Type = 100 + }, + ErrorMessage: "pointer redundancy scheme date does not match requested one", + }, + { + Modify: func(pointer *pb.Pointer) { + pointer.Type = pb.Pointer_INLINE + }, + ErrorMessage: "pointer type is INLINE but remote segment is set", + }, + } + + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, StorageNodeCount: 6, UplinkCount: 1, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + apiKey := planet.Uplinks[0].APIKey[planet.Satellites[0].ID()] + + metainfo, err := planet.Uplinks[0].DialMetainfo(ctx, planet.Satellites[0], apiKey) + require.NoError(t, err) + + for _, test := range tests { + pointer, limits := runCreateSegment(ctx, t, metainfo) + test.Modify(pointer) + + _, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits) + require.Error(t, err) + require.Contains(t, err.Error(), test.ErrorMessage) + } + }) +} + +func runCreateSegment(ctx context.Context, t *testing.T, metainfo metainfo.Client) (*pb.Pointer, []*pb.OrderLimit2) { + pointer := createTestPointer(t) + expirationDate, err := ptypes.Timestamp(pointer.ExpirationDate) + require.NoError(t, err) + + addressedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "myBucketName", "file/path", -1, pointer.Remote.Redundancy, memory.MiB.Int64(), expirationDate) + require.NoError(t, err) + + pointer.Remote.RootPieceId = rootPieceID + pointer.Remote.RemotePieces[0].NodeId = addressedLimits[0].Limit.StorageNodeId + pointer.Remote.RemotePieces[1].NodeId = addressedLimits[1].Limit.StorageNodeId + + limits := make([]*pb.OrderLimit2, len(addressedLimits)) + for i, addressedLimit := range addressedLimits { + limits[i] = addressedLimit.Limit + } + + return pointer, limits +} + +func createTestPointer(t *testing.T) *pb.Pointer { + rs := &pb.RedundancyScheme{ + MinReq: 1, + RepairThreshold: 1, + SuccessThreshold: 3, + Total: 4, + ErasureShareSize: 1024, + Type: pb.RedundancyScheme_RS, + } + + pointer := &pb.Pointer{ + Type: pb.Pointer_REMOTE, + Remote: &pb.RemoteSegment{ + Redundancy: rs, + RemotePieces: []*pb.RemotePiece{ + &pb.RemotePiece{ + PieceNum: 0, + }, + &pb.RemotePiece{ + PieceNum: 1, + }, + }, + }, + ExpirationDate: ptypes.TimestampNow(), + } + return pointer +} diff --git a/satellite/metainfo/validation.go b/satellite/metainfo/validation.go new file mode 100644 index 000000000..cace6dd51 --- /dev/null +++ b/satellite/metainfo/validation.go @@ -0,0 +1,268 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package metainfo + +import ( + "bytes" + "context" + "sync" + "time" + + "github.com/gogo/protobuf/proto" + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/zeebo/errs" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "storj.io/storj/pkg/auth" + "storj.io/storj/pkg/macaroon" + "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/storj" + "storj.io/storj/satellite/console" +) + +const requestTTL = time.Hour * 4 + +// TTLItem keeps association between serial number and ttl +type TTLItem struct { + serialNumber storj.SerialNumber + ttl time.Time +} + +type createRequest struct { + Expiration *timestamp.Timestamp + Redundancy *pb.RedundancyScheme + + ttl time.Time +} + +type createRequests struct { + mu sync.RWMutex + // orders limit serial number used because with CreateSegment we don't have path yet + entries map[storj.SerialNumber]*createRequest + + muTTL sync.Mutex + entriesTTL []*TTLItem +} + +func newCreateRequests() *createRequests { + return &createRequests{ + entries: make(map[storj.SerialNumber]*createRequest), + entriesTTL: make([]*TTLItem, 0), + } +} + +func (requests *createRequests) Put(serialNumber storj.SerialNumber, createRequest *createRequest) { + ttl := time.Now().Add(requestTTL) + + go func() { + requests.muTTL.Lock() + requests.entriesTTL = append(requests.entriesTTL, &TTLItem{ + serialNumber: serialNumber, + ttl: ttl, + }) + requests.muTTL.Unlock() + }() + + createRequest.ttl = ttl + requests.mu.Lock() + requests.entries[serialNumber] = createRequest + requests.mu.Unlock() + + go requests.cleanup() +} + +func (requests *createRequests) Load(serialNumber storj.SerialNumber) (*createRequest, bool) { + requests.mu.RLock() + request, found := requests.entries[serialNumber] + if request != nil && request.ttl.Before(time.Now()) { + request = nil + found = false + } + requests.mu.RUnlock() + + return request, found +} + +func (requests *createRequests) Remove(serialNumber storj.SerialNumber) { + requests.mu.Lock() + delete(requests.entries, serialNumber) + requests.mu.Unlock() +} + +func (requests *createRequests) cleanup() { + requests.muTTL.Lock() + now := time.Now() + remove := make([]storj.SerialNumber, 0) + newStart := 0 + for i, item := range requests.entriesTTL { + if item.ttl.Before(now) { + remove = append(remove, item.serialNumber) + newStart = i + 1 + } else { + break + } + } + requests.entriesTTL = requests.entriesTTL[newStart:] + requests.muTTL.Unlock() + + for _, serialNumber := range remove { + requests.Remove(serialNumber) + } +} + +func (endpoint *Endpoint) validateAuth(ctx context.Context, action macaroon.Action) (_ *console.APIKeyInfo, err error) { + defer mon.Task()(&ctx)(&err) + keyData, ok := auth.GetAPIKey(ctx) + if !ok { + endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential"))) + return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential") + } + + key, err := macaroon.ParseAPIKey(string(keyData)) + if err != nil { + endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential"))) + return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential") + } + + keyInfo, err := endpoint.apiKeys.GetByHead(ctx, key.Head()) + if err != nil { + endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error()))) + return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential") + } + + // Revocations are currently handled by just deleting the key. + err = key.Check(ctx, keyInfo.Secret, action, nil) + if err != nil { + endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error()))) + return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential") + } + + return keyInfo, nil +} + +func (endpoint *Endpoint) validateCreateSegment(ctx context.Context, req *pb.SegmentWriteRequest) (err error) { + defer mon.Task()(&ctx)(&err) + + err = endpoint.validateBucket(ctx, req.Bucket) + if err != nil { + return err + } + + err = endpoint.validateRedundancy(ctx, req.Redundancy) + if err != nil { + return err + } + + return nil +} + +func (endpoint *Endpoint) validateCommitSegment(ctx context.Context, req *pb.SegmentCommitRequest) (err error) { + defer mon.Task()(&ctx)(&err) + + err = endpoint.validateBucket(ctx, req.Bucket) + if err != nil { + return err + } + + err = endpoint.validatePointer(ctx, req.Pointer) + if err != nil { + return err + } + + if req.Pointer.Type == pb.Pointer_REMOTE { + remote := req.Pointer.Remote + + if len(req.OriginalLimits) == 0 { + return Error.New("no order limits") + } + if int32(len(req.OriginalLimits)) != remote.Redundancy.Total { + return Error.New("invalid no order limit for piece") + } + + for _, piece := range remote.RemotePieces { + limit := req.OriginalLimits[piece.PieceNum] + + err := endpoint.orders.VerifyOrderLimitSignature(ctx, limit) + if err != nil { + return err + } + + if limit == nil { + return Error.New("invalid no order limit for piece") + } + derivedPieceID := remote.RootPieceId.Derive(piece.NodeId) + if limit.PieceId.IsZero() || limit.PieceId != derivedPieceID { + return Error.New("invalid order limit piece id") + } + if bytes.Compare(piece.NodeId.Bytes(), limit.StorageNodeId.Bytes()) != 0 { + return Error.New("piece NodeID != order limit NodeID") + } + } + } + + if len(req.OriginalLimits) > 0 { + createRequest, found := endpoint.createRequests.Load(req.OriginalLimits[0].SerialNumber) + + switch { + case !found: + return Error.New("missing create request or request expired") + case !proto.Equal(createRequest.Expiration, req.Pointer.ExpirationDate): + return Error.New("pointer expiration date does not match requested one") + case !proto.Equal(createRequest.Redundancy, req.Pointer.Remote.Redundancy): + return Error.New("pointer redundancy scheme date does not match requested one") + } + } + + return nil +} + +func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (err error) { + defer mon.Task()(&ctx)(&err) + + if len(bucket) == 0 { + return errs.New("bucket not specified") + } + if bytes.ContainsAny(bucket, "/") { + return errs.New("bucket should not contain slash") + } + return nil +} + +func (endpoint *Endpoint) validatePointer(ctx context.Context, pointer *pb.Pointer) (err error) { + defer mon.Task()(&ctx)(&err) + + if pointer == nil { + return Error.New("no pointer specified") + } + + if pointer.Type == pb.Pointer_INLINE && pointer.Remote != nil { + return Error.New("pointer type is INLINE but remote segment is set") + } + + // TODO does it all? + if pointer.Type == pb.Pointer_REMOTE { + if pointer.Remote == nil { + return Error.New("no remote segment specified") + } + if pointer.Remote.RemotePieces == nil { + return Error.New("no remote segment pieces specified") + } + if pointer.Remote.Redundancy == nil { + return Error.New("no redundancy scheme specified") + } + } + return nil +} + +func (endpoint *Endpoint) validateRedundancy(ctx context.Context, redundancy *pb.RedundancyScheme) (err error) { + defer mon.Task()(&ctx)(&err) + + // TODO more validation, use validation from eestream.NewRedundancyStrategy + if redundancy.ErasureShareSize <= 0 { + return Error.New("erasure share size cannot be less than 0") + } + return nil +}