269 lines
7.1 KiB
Go
269 lines
7.1 KiB
Go
|
// Copyright (C) 2019 Storj Labs, Inc.
|
||
|
// See LICENSE for copying information.
|
||
|
|
||
|
package metainfo
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/gogo/protobuf/proto"
|
||
|
"github.com/golang/protobuf/ptypes/timestamp"
|
||
|
"github.com/zeebo/errs"
|
||
|
"go.uber.org/zap"
|
||
|
"google.golang.org/grpc/codes"
|
||
|
"google.golang.org/grpc/status"
|
||
|
|
||
|
"storj.io/storj/pkg/auth"
|
||
|
"storj.io/storj/pkg/macaroon"
|
||
|
"storj.io/storj/pkg/pb"
|
||
|
"storj.io/storj/pkg/storj"
|
||
|
"storj.io/storj/satellite/console"
|
||
|
)
|
||
|
|
||
|
// requestTTL is how long a create request stays valid after it is stored.
const requestTTL = 4 * time.Hour
|
||
|
|
||
|
// TTLItem keeps association between serial number and ttl
|
||
|
type TTLItem struct {
|
||
|
serialNumber storj.SerialNumber
|
||
|
ttl time.Time
|
||
|
}
|
||
|
|
||
|
type createRequest struct {
|
||
|
Expiration *timestamp.Timestamp
|
||
|
Redundancy *pb.RedundancyScheme
|
||
|
|
||
|
ttl time.Time
|
||
|
}
|
||
|
|
||
|
type createRequests struct {
|
||
|
mu sync.RWMutex
|
||
|
// orders limit serial number used because with CreateSegment we don't have path yet
|
||
|
entries map[storj.SerialNumber]*createRequest
|
||
|
|
||
|
muTTL sync.Mutex
|
||
|
entriesTTL []*TTLItem
|
||
|
}
|
||
|
|
||
|
func newCreateRequests() *createRequests {
|
||
|
return &createRequests{
|
||
|
entries: make(map[storj.SerialNumber]*createRequest),
|
||
|
entriesTTL: make([]*TTLItem, 0),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Put stores createRequest under serialNumber. The request expires requestTTL
// from now. Put also evicts any entries whose TTL has already passed.
//
// createRequest.ttl is overwritten with the computed expiration time.
func (requests *createRequests) Put(serialNumber storj.SerialNumber, createRequest *createRequest) {
	// Compute the expiration and enqueue it under muTTL. This keeps
	// entriesTTL ordered by ttl, which cleanup relies on when it stops at the
	// first unexpired item. Appending from a separate goroutine, as before,
	// let concurrent Puts land out of order.
	requests.muTTL.Lock()
	ttl := time.Now().Add(requestTTL)
	requests.entriesTTL = append(requests.entriesTTL, &TTLItem{
		serialNumber: serialNumber,
		ttl:          ttl,
	})
	requests.muTTL.Unlock()

	createRequest.ttl = ttl

	requests.mu.Lock()
	requests.entries[serialNumber] = createRequest
	requests.mu.Unlock()

	// Evict expired entries inline rather than spawning a goroutine per Put.
	// No locks are held at this point.
	requests.cleanup()
}
|
||
|
|
||
|
// Load looks up the create request stored under serialNumber. A request whose
// TTL has already passed is reported as (nil, false), even if cleanup has not
// evicted it yet.
func (requests *createRequests) Load(serialNumber storj.SerialNumber) (*createRequest, bool) {
	requests.mu.RLock()
	defer requests.mu.RUnlock()

	entry, ok := requests.entries[serialNumber]
	if entry == nil || !entry.ttl.Before(time.Now()) {
		return entry, ok
	}
	return nil, false
}
|
||
|
|
||
|
func (requests *createRequests) Remove(serialNumber storj.SerialNumber) {
|
||
|
requests.mu.Lock()
|
||
|
delete(requests.entries, serialNumber)
|
||
|
requests.mu.Unlock()
|
||
|
}
|
||
|
|
||
|
func (requests *createRequests) cleanup() {
|
||
|
requests.muTTL.Lock()
|
||
|
now := time.Now()
|
||
|
remove := make([]storj.SerialNumber, 0)
|
||
|
newStart := 0
|
||
|
for i, item := range requests.entriesTTL {
|
||
|
if item.ttl.Before(now) {
|
||
|
remove = append(remove, item.serialNumber)
|
||
|
newStart = i + 1
|
||
|
} else {
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
requests.entriesTTL = requests.entriesTTL[newStart:]
|
||
|
requests.muTTL.Unlock()
|
||
|
|
||
|
for _, serialNumber := range remove {
|
||
|
requests.Remove(serialNumber)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// validateAuth authenticates the request in ctx and checks that its API key
// permits action.
//
// It reads the API key from ctx, parses it as a macaroon, and looks up the
// stored key info by the macaroon head. It then checks the macaroon against
// the stored secret for action, and returns the stored key info on success.
// Every failure is logged and returned to the caller as the same
// codes.Unauthenticated "Invalid API credential" status.
func (endpoint *Endpoint) validateAuth(ctx context.Context, action macaroon.Action) (_ *console.APIKeyInfo, err error) {
	defer mon.Task()(&ctx)(&err)

	// unauthorized logs reason and returns the generic error sent to clients.
	// status.Error (not Errorf) is used because reason may be an arbitrary
	// error message, and a '%' in it must not be read as a format verb.
	unauthorized := func(reason string) error {
		endpoint.log.Error("unauthorized request", zap.Error(status.Error(codes.Unauthenticated, reason)))
		return status.Error(codes.Unauthenticated, "Invalid API credential")
	}

	keyData, ok := auth.GetAPIKey(ctx)
	if !ok {
		return nil, unauthorized("Invalid API credential")
	}

	key, err := macaroon.ParseAPIKey(string(keyData))
	if err != nil {
		return nil, unauthorized("Invalid API credential")
	}

	keyInfo, err := endpoint.apiKeys.GetByHead(ctx, key.Head())
	if err != nil {
		return nil, unauthorized(err.Error())
	}

	// Revocations are currently handled by just deleting the key.
	err = key.Check(ctx, keyInfo.Secret, action, nil)
	if err != nil {
		return nil, unauthorized(err.Error())
	}

	return keyInfo, nil
}
|
||
|
|
||
|
// validateCreateSegment checks the bucket name and the redundancy scheme of a
// SegmentWriteRequest. It returns the first validation error encountered.
func (endpoint *Endpoint) validateCreateSegment(ctx context.Context, req *pb.SegmentWriteRequest) (err error) {
	defer mon.Task()(&ctx)(&err)

	if err = endpoint.validateBucket(ctx, req.Bucket); err != nil {
		return err
	}
	return endpoint.validateRedundancy(ctx, req.Redundancy)
}
|
||
|
|
||
|
func (endpoint *Endpoint) validateCommitSegment(ctx context.Context, req *pb.SegmentCommitRequest) (err error) {
|
||
|
defer mon.Task()(&ctx)(&err)
|
||
|
|
||
|
err = endpoint.validateBucket(ctx, req.Bucket)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
err = endpoint.validatePointer(ctx, req.Pointer)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if req.Pointer.Type == pb.Pointer_REMOTE {
|
||
|
remote := req.Pointer.Remote
|
||
|
|
||
|
if len(req.OriginalLimits) == 0 {
|
||
|
return Error.New("no order limits")
|
||
|
}
|
||
|
if int32(len(req.OriginalLimits)) != remote.Redundancy.Total {
|
||
|
return Error.New("invalid no order limit for piece")
|
||
|
}
|
||
|
|
||
|
for _, piece := range remote.RemotePieces {
|
||
|
limit := req.OriginalLimits[piece.PieceNum]
|
||
|
|
||
|
err := endpoint.orders.VerifyOrderLimitSignature(ctx, limit)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if limit == nil {
|
||
|
return Error.New("invalid no order limit for piece")
|
||
|
}
|
||
|
derivedPieceID := remote.RootPieceId.Derive(piece.NodeId)
|
||
|
if limit.PieceId.IsZero() || limit.PieceId != derivedPieceID {
|
||
|
return Error.New("invalid order limit piece id")
|
||
|
}
|
||
|
if bytes.Compare(piece.NodeId.Bytes(), limit.StorageNodeId.Bytes()) != 0 {
|
||
|
return Error.New("piece NodeID != order limit NodeID")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(req.OriginalLimits) > 0 {
|
||
|
createRequest, found := endpoint.createRequests.Load(req.OriginalLimits[0].SerialNumber)
|
||
|
|
||
|
switch {
|
||
|
case !found:
|
||
|
return Error.New("missing create request or request expired")
|
||
|
case !proto.Equal(createRequest.Expiration, req.Pointer.ExpirationDate):
|
||
|
return Error.New("pointer expiration date does not match requested one")
|
||
|
case !proto.Equal(createRequest.Redundancy, req.Pointer.Remote.Redundancy):
|
||
|
return Error.New("pointer redundancy scheme date does not match requested one")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// validateBucket rejects a bucket name that is empty or contains a '/'.
func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (err error) {
	defer mon.Task()(&ctx)(&err)

	switch {
	case len(bucket) == 0:
		return errs.New("bucket not specified")
	case bytes.IndexByte(bucket, '/') >= 0:
		return errs.New("bucket should not contain slash")
	default:
		return nil
	}
}
|
||
|
|
||
|
// validatePointer performs structural checks on a pointer. The pointer must
// be non-nil. An INLINE pointer must not carry a remote segment. A REMOTE
// pointer must carry a remote segment with non-nil pieces and a non-nil
// redundancy scheme. Other pointer types are not checked further.
func (endpoint *Endpoint) validatePointer(ctx context.Context, pointer *pb.Pointer) (err error) {
	defer mon.Task()(&ctx)(&err)

	if pointer == nil {
		return Error.New("no pointer specified")
	}

	// TODO does it all?
	switch pointer.Type {
	case pb.Pointer_INLINE:
		if pointer.Remote != nil {
			return Error.New("pointer type is INLINE but remote segment is set")
		}
	case pb.Pointer_REMOTE:
		remote := pointer.Remote
		switch {
		case remote == nil:
			return Error.New("no remote segment specified")
		case remote.RemotePieces == nil:
			return Error.New("no remote segment pieces specified")
		case remote.Redundancy == nil:
			return Error.New("no redundancy scheme specified")
		}
	}
	return nil
}
|
||
|
|
||
|
// validateRedundancy performs basic sanity checks on a redundancy scheme. It
// rejects a missing scheme and a non-positive erasure share size.
func (endpoint *Endpoint) validateRedundancy(ctx context.Context, redundancy *pb.RedundancyScheme) (err error) {
	defer mon.Task()(&ctx)(&err)

	// Callers may pass a request field directly (validateCreateSegment passes
	// req.Redundancy). Reject a missing scheme instead of dereferencing nil.
	if redundancy == nil {
		return Error.New("no redundancy scheme specified")
	}

	// TODO more validation, use validation from eestream.NewRedundancyStrategy
	if redundancy.ErasureShareSize <= 0 {
		return Error.New("erasure share size must be greater than 0")
	}
	return nil
}
|