Better metainfo Create/Commit request validation (#2088)
This commit is contained in:
parent
b9d586901e
commit
b5ac4f3eac
@ -4,7 +4,6 @@
|
|||||||
package metainfo
|
package metainfo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -18,7 +17,6 @@ import (
|
|||||||
monkit "gopkg.in/spacemonkeygo/monkit.v2"
|
monkit "gopkg.in/spacemonkeygo/monkit.v2"
|
||||||
|
|
||||||
"storj.io/storj/pkg/accounting"
|
"storj.io/storj/pkg/accounting"
|
||||||
"storj.io/storj/pkg/auth"
|
|
||||||
"storj.io/storj/pkg/eestream"
|
"storj.io/storj/pkg/eestream"
|
||||||
"storj.io/storj/pkg/identity"
|
"storj.io/storj/pkg/identity"
|
||||||
"storj.io/storj/pkg/macaroon"
|
"storj.io/storj/pkg/macaroon"
|
||||||
@ -60,6 +58,7 @@ type Endpoint struct {
|
|||||||
projectUsage *accounting.ProjectUsage
|
projectUsage *accounting.ProjectUsage
|
||||||
containment Containment
|
containment Containment
|
||||||
apiKeys APIKeys
|
apiKeys APIKeys
|
||||||
|
createRequests *createRequests
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewEndpoint creates new metainfo endpoint instance
|
// NewEndpoint creates new metainfo endpoint instance
|
||||||
@ -74,42 +73,13 @@ func NewEndpoint(log *zap.Logger, metainfo *Service, orders *orders.Service, cac
|
|||||||
containment: containment,
|
containment: containment,
|
||||||
apiKeys: apiKeys,
|
apiKeys: apiKeys,
|
||||||
projectUsage: projectUsage,
|
projectUsage: projectUsage,
|
||||||
|
createRequests: newCreateRequests(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes resources
|
// Close closes resources
|
||||||
func (endpoint *Endpoint) Close() error { return nil }
|
func (endpoint *Endpoint) Close() error { return nil }
|
||||||
|
|
||||||
func (endpoint *Endpoint) validateAuth(ctx context.Context, action macaroon.Action) (_ *console.APIKeyInfo, err error) {
|
|
||||||
defer mon.Task()(&ctx)(&err)
|
|
||||||
keyData, ok := auth.GetAPIKey(ctx)
|
|
||||||
if !ok {
|
|
||||||
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := macaroon.ParseAPIKey(string(keyData))
|
|
||||||
if err != nil {
|
|
||||||
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
|
|
||||||
}
|
|
||||||
|
|
||||||
keyInfo, err := endpoint.apiKeys.GetByHead(ctx, key.Head())
|
|
||||||
if err != nil {
|
|
||||||
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Revocations are currently handled by just deleting the key.
|
|
||||||
err = key.Check(ctx, keyInfo.Secret, action, nil)
|
|
||||||
if err != nil {
|
|
||||||
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
|
|
||||||
}
|
|
||||||
|
|
||||||
return keyInfo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SegmentInfo returns segment metadata info
|
// SegmentInfo returns segment metadata info
|
||||||
func (endpoint *Endpoint) SegmentInfo(ctx context.Context, req *pb.SegmentInfoRequest) (resp *pb.SegmentInfoResponse, err error) {
|
func (endpoint *Endpoint) SegmentInfo(ctx context.Context, req *pb.SegmentInfoRequest) (resp *pb.SegmentInfoResponse, err error) {
|
||||||
defer mon.Task()(&ctx)(&err)
|
defer mon.Task()(&ctx)(&err)
|
||||||
@ -209,6 +179,13 @@ func (endpoint *Endpoint) CreateSegment(ctx context.Context, req *pb.SegmentWrit
|
|||||||
return nil, Error.Wrap(err)
|
return nil, Error.Wrap(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(addressedLimits) > 0 {
|
||||||
|
endpoint.createRequests.Put(addressedLimits[0].Limit.SerialNumber, &createRequest{
|
||||||
|
Expiration: req.Expiration,
|
||||||
|
Redundancy: req.Redundancy,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return &pb.SegmentWriteResponse{AddressedLimits: addressedLimits, RootPieceId: rootPieceID}, nil
|
return &pb.SegmentWriteResponse{AddressedLimits: addressedLimits, RootPieceId: rootPieceID}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -247,7 +224,7 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm
|
|||||||
return nil, status.Errorf(codes.InvalidArgument, err.Error())
|
return nil, status.Errorf(codes.InvalidArgument, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = endpoint.validateCommit(ctx, req)
|
err = endpoint.validateCommitSegment(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, err.Error())
|
return nil, status.Errorf(codes.Internal, err.Error())
|
||||||
}
|
}
|
||||||
@ -288,6 +265,10 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm
|
|||||||
return nil, status.Errorf(codes.Internal, err.Error())
|
return nil, status.Errorf(codes.Internal, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(req.OriginalLimits) > 0 {
|
||||||
|
endpoint.createRequests.Remove(req.OriginalLimits[0].SerialNumber)
|
||||||
|
}
|
||||||
|
|
||||||
return &pb.SegmentCommitResponse{Pointer: pointer}, nil
|
return &pb.SegmentCommitResponse{Pointer: pointer}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -503,78 +484,6 @@ func (endpoint *Endpoint) filterValidPieces(ctx context.Context, pointer *pb.Poi
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (err error) {
|
|
||||||
defer mon.Task()(&ctx)(&err)
|
|
||||||
if len(bucket) == 0 {
|
|
||||||
return errs.New("bucket not specified")
|
|
||||||
}
|
|
||||||
if bytes.ContainsAny(bucket, "/") {
|
|
||||||
return errs.New("bucket should not contain slash")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *Endpoint) validateCommit(ctx context.Context, req *pb.SegmentCommitRequest) (err error) {
|
|
||||||
defer mon.Task()(&ctx)(&err)
|
|
||||||
err = endpoint.validatePointer(ctx, req.Pointer)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Pointer.Type == pb.Pointer_REMOTE {
|
|
||||||
remote := req.Pointer.Remote
|
|
||||||
|
|
||||||
if len(req.OriginalLimits) == 0 {
|
|
||||||
return Error.New("no order limits")
|
|
||||||
}
|
|
||||||
if int32(len(req.OriginalLimits)) != remote.Redundancy.Total {
|
|
||||||
return Error.New("invalid no order limit for piece")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, piece := range remote.RemotePieces {
|
|
||||||
limit := req.OriginalLimits[piece.PieceNum]
|
|
||||||
|
|
||||||
err := endpoint.orders.VerifyOrderLimitSignature(ctx, limit)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if limit == nil {
|
|
||||||
return Error.New("invalid no order limit for piece")
|
|
||||||
}
|
|
||||||
derivedPieceID := remote.RootPieceId.Derive(piece.NodeId)
|
|
||||||
if limit.PieceId.IsZero() || limit.PieceId != derivedPieceID {
|
|
||||||
return Error.New("invalid order limit piece id")
|
|
||||||
}
|
|
||||||
if bytes.Compare(piece.NodeId.Bytes(), limit.StorageNodeId.Bytes()) != 0 {
|
|
||||||
return Error.New("piece NodeID != order limit NodeID")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *Endpoint) validatePointer(ctx context.Context, pointer *pb.Pointer) (err error) {
|
|
||||||
defer mon.Task()(&ctx)(&err)
|
|
||||||
if pointer == nil {
|
|
||||||
return Error.New("no pointer specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO does it all?
|
|
||||||
if pointer.Type == pb.Pointer_REMOTE {
|
|
||||||
if pointer.Remote == nil {
|
|
||||||
return Error.New("no remote segment specified")
|
|
||||||
}
|
|
||||||
if pointer.Remote.RemotePieces == nil {
|
|
||||||
return Error.New("no remote segment pieces specified")
|
|
||||||
}
|
|
||||||
if pointer.Remote.Redundancy == nil {
|
|
||||||
return Error.New("no redundancy scheme specified")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatePath will create a Segment path
|
// CreatePath will create a Segment path
|
||||||
func CreatePath(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bucket, path []byte) (_ storj.Path, err error) {
|
func CreatePath(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bucket, path []byte) (_ storj.Path, err error) {
|
||||||
defer mon.Task()(&ctx)(&err)
|
defer mon.Task()(&ctx)(&err)
|
||||||
@ -597,12 +506,3 @@ func CreatePath(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bu
|
|||||||
}
|
}
|
||||||
return storj.JoinPaths(entries...), nil
|
return storj.JoinPaths(entries...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (endpoint *Endpoint) validateRedundancy(ctx context.Context, redundancy *pb.RedundancyScheme) (err error) {
|
|
||||||
defer mon.Task()(&ctx)(&err)
|
|
||||||
// TODO more validation, use validation from eestream.NewRedundancyStrategy
|
|
||||||
if redundancy.ErasureShareSize <= 0 {
|
|
||||||
return Error.New("erasure share size cannot be less than 0")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -9,19 +9,21 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"github.com/golang/protobuf/ptypes"
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/zeebo/errs"
|
"github.com/zeebo/errs"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"storj.io/storj/internal/memory"
|
||||||
"storj.io/storj/internal/testcontext"
|
"storj.io/storj/internal/testcontext"
|
||||||
"storj.io/storj/internal/testplanet"
|
"storj.io/storj/internal/testplanet"
|
||||||
"storj.io/storj/pkg/macaroon"
|
"storj.io/storj/pkg/macaroon"
|
||||||
"storj.io/storj/pkg/pb"
|
"storj.io/storj/pkg/pb"
|
||||||
"storj.io/storj/pkg/storj"
|
"storj.io/storj/pkg/storj"
|
||||||
"storj.io/storj/satellite/console"
|
"storj.io/storj/satellite/console"
|
||||||
|
"storj.io/storj/uplink/metainfo"
|
||||||
)
|
)
|
||||||
|
|
||||||
// mockAPIKeys is mock for api keys store of pointerdb
|
// mockAPIKeys is mock for api keys store of pointerdb
|
||||||
@ -298,7 +300,8 @@ func TestCommitSegment(t *testing.T) {
|
|||||||
Total: 6,
|
Total: 6,
|
||||||
ErasureShareSize: 10,
|
ErasureShareSize: 10,
|
||||||
}
|
}
|
||||||
addresedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "bucket", "path", -1, redundancy, 1000, time.Now())
|
expirationDate := time.Now()
|
||||||
|
addresedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "bucket", "path", -1, redundancy, 1000, expirationDate)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create number of pieces below repair threshold
|
// create number of pieces below repair threshold
|
||||||
@ -310,6 +313,10 @@ func TestCommitSegment(t *testing.T) {
|
|||||||
NodeId: limit.Limit.StorageNodeId,
|
NodeId: limit.Limit.StorageNodeId,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expirationDateProto, err := ptypes.TimestampProto(expirationDate)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
pointer := &pb.Pointer{
|
pointer := &pb.Pointer{
|
||||||
Type: pb.Pointer_REMOTE,
|
Type: pb.Pointer_REMOTE,
|
||||||
Remote: &pb.RemoteSegment{
|
Remote: &pb.RemoteSegment{
|
||||||
@ -317,6 +324,7 @@ func TestCommitSegment(t *testing.T) {
|
|||||||
Redundancy: redundancy,
|
Redundancy: redundancy,
|
||||||
RemotePieces: pieces,
|
RemotePieces: pieces,
|
||||||
},
|
},
|
||||||
|
ExpirationDate: expirationDateProto,
|
||||||
}
|
}
|
||||||
|
|
||||||
limits := make([]*pb.OrderLimit2, len(addresedLimits))
|
limits := make([]*pb.OrderLimit2, len(addresedLimits))
|
||||||
@ -329,3 +337,148 @@ func TestCommitSegment(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDoubleCommitSegment(t *testing.T) {
|
||||||
|
testplanet.Run(t, testplanet.Config{
|
||||||
|
SatelliteCount: 1, StorageNodeCount: 6, UplinkCount: 1,
|
||||||
|
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||||
|
apiKey := planet.Uplinks[0].APIKey[planet.Satellites[0].ID()]
|
||||||
|
|
||||||
|
metainfo, err := planet.Uplinks[0].DialMetainfo(ctx, planet.Satellites[0], apiKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pointer, limits := runCreateSegment(ctx, t, metainfo)
|
||||||
|
|
||||||
|
_, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing create request or request expired")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommitSegmentPointer(t *testing.T) {
|
||||||
|
// all tests needs to generate error
|
||||||
|
tests := []struct {
|
||||||
|
// defines how modify pointer before CommitSegment
|
||||||
|
Modify func(pointer *pb.Pointer)
|
||||||
|
ErrorMessage string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Modify: func(pointer *pb.Pointer) {
|
||||||
|
pointer.ExpirationDate.Seconds += 100
|
||||||
|
},
|
||||||
|
ErrorMessage: "pointer expiration date does not match requested one",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Modify: func(pointer *pb.Pointer) {
|
||||||
|
pointer.Remote.Redundancy.MinReq += 100
|
||||||
|
},
|
||||||
|
ErrorMessage: "pointer redundancy scheme date does not match requested one",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Modify: func(pointer *pb.Pointer) {
|
||||||
|
pointer.Remote.Redundancy.RepairThreshold += 100
|
||||||
|
},
|
||||||
|
ErrorMessage: "pointer redundancy scheme date does not match requested one",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Modify: func(pointer *pb.Pointer) {
|
||||||
|
pointer.Remote.Redundancy.SuccessThreshold += 100
|
||||||
|
},
|
||||||
|
ErrorMessage: "pointer redundancy scheme date does not match requested one",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Modify: func(pointer *pb.Pointer) {
|
||||||
|
pointer.Remote.Redundancy.Total += 100
|
||||||
|
},
|
||||||
|
// this error is triggered earlier then Create/Commit RS comparison
|
||||||
|
ErrorMessage: "invalid no order limit for piece",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Modify: func(pointer *pb.Pointer) {
|
||||||
|
pointer.Remote.Redundancy.ErasureShareSize += 100
|
||||||
|
},
|
||||||
|
ErrorMessage: "pointer redundancy scheme date does not match requested one",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Modify: func(pointer *pb.Pointer) {
|
||||||
|
pointer.Remote.Redundancy.Type = 100
|
||||||
|
},
|
||||||
|
ErrorMessage: "pointer redundancy scheme date does not match requested one",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Modify: func(pointer *pb.Pointer) {
|
||||||
|
pointer.Type = pb.Pointer_INLINE
|
||||||
|
},
|
||||||
|
ErrorMessage: "pointer type is INLINE but remote segment is set",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testplanet.Run(t, testplanet.Config{
|
||||||
|
SatelliteCount: 1, StorageNodeCount: 6, UplinkCount: 1,
|
||||||
|
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||||
|
apiKey := planet.Uplinks[0].APIKey[planet.Satellites[0].ID()]
|
||||||
|
|
||||||
|
metainfo, err := planet.Uplinks[0].DialMetainfo(ctx, planet.Satellites[0], apiKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
pointer, limits := runCreateSegment(ctx, t, metainfo)
|
||||||
|
test.Modify(pointer)
|
||||||
|
|
||||||
|
_, err = metainfo.CommitSegment(ctx, "myBucketName", "file/path", -1, pointer, limits)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), test.ErrorMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runCreateSegment(ctx context.Context, t *testing.T, metainfo metainfo.Client) (*pb.Pointer, []*pb.OrderLimit2) {
|
||||||
|
pointer := createTestPointer(t)
|
||||||
|
expirationDate, err := ptypes.Timestamp(pointer.ExpirationDate)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
addressedLimits, rootPieceID, err := metainfo.CreateSegment(ctx, "myBucketName", "file/path", -1, pointer.Remote.Redundancy, memory.MiB.Int64(), expirationDate)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pointer.Remote.RootPieceId = rootPieceID
|
||||||
|
pointer.Remote.RemotePieces[0].NodeId = addressedLimits[0].Limit.StorageNodeId
|
||||||
|
pointer.Remote.RemotePieces[1].NodeId = addressedLimits[1].Limit.StorageNodeId
|
||||||
|
|
||||||
|
limits := make([]*pb.OrderLimit2, len(addressedLimits))
|
||||||
|
for i, addressedLimit := range addressedLimits {
|
||||||
|
limits[i] = addressedLimit.Limit
|
||||||
|
}
|
||||||
|
|
||||||
|
return pointer, limits
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestPointer(t *testing.T) *pb.Pointer {
|
||||||
|
rs := &pb.RedundancyScheme{
|
||||||
|
MinReq: 1,
|
||||||
|
RepairThreshold: 1,
|
||||||
|
SuccessThreshold: 3,
|
||||||
|
Total: 4,
|
||||||
|
ErasureShareSize: 1024,
|
||||||
|
Type: pb.RedundancyScheme_RS,
|
||||||
|
}
|
||||||
|
|
||||||
|
pointer := &pb.Pointer{
|
||||||
|
Type: pb.Pointer_REMOTE,
|
||||||
|
Remote: &pb.RemoteSegment{
|
||||||
|
Redundancy: rs,
|
||||||
|
RemotePieces: []*pb.RemotePiece{
|
||||||
|
&pb.RemotePiece{
|
||||||
|
PieceNum: 0,
|
||||||
|
},
|
||||||
|
&pb.RemotePiece{
|
||||||
|
PieceNum: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpirationDate: ptypes.TimestampNow(),
|
||||||
|
}
|
||||||
|
return pointer
|
||||||
|
}
|
||||||
|
268
satellite/metainfo/validation.go
Normal file
268
satellite/metainfo/validation.go
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
// Copyright (C) 2019 Storj Labs, Inc.
|
||||||
|
// See LICENSE for copying information.
|
||||||
|
|
||||||
|
package metainfo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gogo/protobuf/proto"
|
||||||
|
"github.com/golang/protobuf/ptypes/timestamp"
|
||||||
|
"github.com/zeebo/errs"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"storj.io/storj/pkg/auth"
|
||||||
|
"storj.io/storj/pkg/macaroon"
|
||||||
|
"storj.io/storj/pkg/pb"
|
||||||
|
"storj.io/storj/pkg/storj"
|
||||||
|
"storj.io/storj/satellite/console"
|
||||||
|
)
|
||||||
|
|
||||||
|
const requestTTL = time.Hour * 4
|
||||||
|
|
||||||
|
// TTLItem keeps association between serial number and ttl
|
||||||
|
type TTLItem struct {
|
||||||
|
serialNumber storj.SerialNumber
|
||||||
|
ttl time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type createRequest struct {
|
||||||
|
Expiration *timestamp.Timestamp
|
||||||
|
Redundancy *pb.RedundancyScheme
|
||||||
|
|
||||||
|
ttl time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type createRequests struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
// orders limit serial number used because with CreateSegment we don't have path yet
|
||||||
|
entries map[storj.SerialNumber]*createRequest
|
||||||
|
|
||||||
|
muTTL sync.Mutex
|
||||||
|
entriesTTL []*TTLItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCreateRequests() *createRequests {
|
||||||
|
return &createRequests{
|
||||||
|
entries: make(map[storj.SerialNumber]*createRequest),
|
||||||
|
entriesTTL: make([]*TTLItem, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (requests *createRequests) Put(serialNumber storj.SerialNumber, createRequest *createRequest) {
|
||||||
|
ttl := time.Now().Add(requestTTL)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
requests.muTTL.Lock()
|
||||||
|
requests.entriesTTL = append(requests.entriesTTL, &TTLItem{
|
||||||
|
serialNumber: serialNumber,
|
||||||
|
ttl: ttl,
|
||||||
|
})
|
||||||
|
requests.muTTL.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
createRequest.ttl = ttl
|
||||||
|
requests.mu.Lock()
|
||||||
|
requests.entries[serialNumber] = createRequest
|
||||||
|
requests.mu.Unlock()
|
||||||
|
|
||||||
|
go requests.cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (requests *createRequests) Load(serialNumber storj.SerialNumber) (*createRequest, bool) {
|
||||||
|
requests.mu.RLock()
|
||||||
|
request, found := requests.entries[serialNumber]
|
||||||
|
if request != nil && request.ttl.Before(time.Now()) {
|
||||||
|
request = nil
|
||||||
|
found = false
|
||||||
|
}
|
||||||
|
requests.mu.RUnlock()
|
||||||
|
|
||||||
|
return request, found
|
||||||
|
}
|
||||||
|
|
||||||
|
func (requests *createRequests) Remove(serialNumber storj.SerialNumber) {
|
||||||
|
requests.mu.Lock()
|
||||||
|
delete(requests.entries, serialNumber)
|
||||||
|
requests.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (requests *createRequests) cleanup() {
|
||||||
|
requests.muTTL.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
remove := make([]storj.SerialNumber, 0)
|
||||||
|
newStart := 0
|
||||||
|
for i, item := range requests.entriesTTL {
|
||||||
|
if item.ttl.Before(now) {
|
||||||
|
remove = append(remove, item.serialNumber)
|
||||||
|
newStart = i + 1
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
requests.entriesTTL = requests.entriesTTL[newStart:]
|
||||||
|
requests.muTTL.Unlock()
|
||||||
|
|
||||||
|
for _, serialNumber := range remove {
|
||||||
|
requests.Remove(serialNumber)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *Endpoint) validateAuth(ctx context.Context, action macaroon.Action) (_ *console.APIKeyInfo, err error) {
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
keyData, ok := auth.GetAPIKey(ctx)
|
||||||
|
if !ok {
|
||||||
|
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := macaroon.ParseAPIKey(string(keyData))
|
||||||
|
if err != nil {
|
||||||
|
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, "Invalid API credential")))
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
|
||||||
|
}
|
||||||
|
|
||||||
|
keyInfo, err := endpoint.apiKeys.GetByHead(ctx, key.Head())
|
||||||
|
if err != nil {
|
||||||
|
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revocations are currently handled by just deleting the key.
|
||||||
|
err = key.Check(ctx, keyInfo.Secret, action, nil)
|
||||||
|
if err != nil {
|
||||||
|
endpoint.log.Error("unauthorized request", zap.Error(status.Errorf(codes.Unauthenticated, err.Error())))
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "Invalid API credential")
|
||||||
|
}
|
||||||
|
|
||||||
|
return keyInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *Endpoint) validateCreateSegment(ctx context.Context, req *pb.SegmentWriteRequest) (err error) {
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
err = endpoint.validateBucket(ctx, req.Bucket)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = endpoint.validateRedundancy(ctx, req.Redundancy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *Endpoint) validateCommitSegment(ctx context.Context, req *pb.SegmentCommitRequest) (err error) {
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
err = endpoint.validateBucket(ctx, req.Bucket)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = endpoint.validatePointer(ctx, req.Pointer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Pointer.Type == pb.Pointer_REMOTE {
|
||||||
|
remote := req.Pointer.Remote
|
||||||
|
|
||||||
|
if len(req.OriginalLimits) == 0 {
|
||||||
|
return Error.New("no order limits")
|
||||||
|
}
|
||||||
|
if int32(len(req.OriginalLimits)) != remote.Redundancy.Total {
|
||||||
|
return Error.New("invalid no order limit for piece")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, piece := range remote.RemotePieces {
|
||||||
|
limit := req.OriginalLimits[piece.PieceNum]
|
||||||
|
|
||||||
|
err := endpoint.orders.VerifyOrderLimitSignature(ctx, limit)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if limit == nil {
|
||||||
|
return Error.New("invalid no order limit for piece")
|
||||||
|
}
|
||||||
|
derivedPieceID := remote.RootPieceId.Derive(piece.NodeId)
|
||||||
|
if limit.PieceId.IsZero() || limit.PieceId != derivedPieceID {
|
||||||
|
return Error.New("invalid order limit piece id")
|
||||||
|
}
|
||||||
|
if bytes.Compare(piece.NodeId.Bytes(), limit.StorageNodeId.Bytes()) != 0 {
|
||||||
|
return Error.New("piece NodeID != order limit NodeID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.OriginalLimits) > 0 {
|
||||||
|
createRequest, found := endpoint.createRequests.Load(req.OriginalLimits[0].SerialNumber)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case !found:
|
||||||
|
return Error.New("missing create request or request expired")
|
||||||
|
case !proto.Equal(createRequest.Expiration, req.Pointer.ExpirationDate):
|
||||||
|
return Error.New("pointer expiration date does not match requested one")
|
||||||
|
case !proto.Equal(createRequest.Redundancy, req.Pointer.Remote.Redundancy):
|
||||||
|
return Error.New("pointer redundancy scheme date does not match requested one")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (err error) {
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
if len(bucket) == 0 {
|
||||||
|
return errs.New("bucket not specified")
|
||||||
|
}
|
||||||
|
if bytes.ContainsAny(bucket, "/") {
|
||||||
|
return errs.New("bucket should not contain slash")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *Endpoint) validatePointer(ctx context.Context, pointer *pb.Pointer) (err error) {
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
if pointer == nil {
|
||||||
|
return Error.New("no pointer specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
if pointer.Type == pb.Pointer_INLINE && pointer.Remote != nil {
|
||||||
|
return Error.New("pointer type is INLINE but remote segment is set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO does it all?
|
||||||
|
if pointer.Type == pb.Pointer_REMOTE {
|
||||||
|
if pointer.Remote == nil {
|
||||||
|
return Error.New("no remote segment specified")
|
||||||
|
}
|
||||||
|
if pointer.Remote.RemotePieces == nil {
|
||||||
|
return Error.New("no remote segment pieces specified")
|
||||||
|
}
|
||||||
|
if pointer.Remote.Redundancy == nil {
|
||||||
|
return Error.New("no redundancy scheme specified")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *Endpoint) validateRedundancy(ctx context.Context, redundancy *pb.RedundancyScheme) (err error) {
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
// TODO more validation, use validation from eestream.NewRedundancyStrategy
|
||||||
|
if redundancy.ErasureShareSize <= 0 {
|
||||||
|
return Error.New("erasure share size cannot be less than 0")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user