satellite/gracefulexit: refactor concurrency (#3624)

Update PendingMap structure to also handle concurrency control between the sending and receiving sides of the graceful exit endpoint.
Maximillian von Briesen, 2019-11-21 17:03:16 -05:00, committed by GitHub
parent b7a8ffcdff
commit 1339252cbe
4 changed files with 495 additions and 163 deletions
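
In outline, the refactor removes the `morePiecesFlag`/`loopRunningFlag`/`errChan`/`sync.Cond` bookkeeping from `doProcess` and moves that coordination into the exported `PendingMap`: the sending side queues work with `Put` and signals completion or failure with `DoneSending`, while the receiving side blocks on `IsFinishedPromise().Wait` instead of a condition variable. A minimal, self-contained sketch of that hand-off using the API added in this commit (the piece ID, transfer contents, and print statements are illustrative placeholders, not values taken from the real endpoint):

package main

import (
	"context"
	"fmt"

	"storj.io/storj/pkg/pb"
	"storj.io/storj/private/testrand"
	"storj.io/storj/satellite/gracefulexit"
)

func main() {
	ctx := context.Background()
	pending := gracefulexit.NewPendingMap()
	pieceID := testrand.PieceID()

	// Sending side: queue a transfer, then declare that no more work is coming.
	go func() {
		_ = pending.Put(pieceID, &gracefulexit.PendingTransfer{
			Path:             []byte("testbucket/testfile"),
			PieceSize:        10,
			SatelliteMessage: &pb.SatelliteMessage{},
			OriginalPointer:  &pb.Pointer{},
			PieceNum:         1,
		})
		_ = pending.DoneSending(nil)
	}()

	// Receiving side: each iteration takes a fresh promise and blocks until
	// either new work shows up (finished == false) or the sender is done and
	// the map is empty (finished == true).
	for {
		finished, err := pending.IsFinishedPromise().Wait(ctx)
		if err != nil {
			return // context canceled
		}
		if finished {
			break
		}
		if transfer, ok := pending.Get(pieceID); ok {
			fmt.Printf("would send transfer for %s (%d bytes)\n", transfer.Path, transfer.PieceSize)
			_ = pending.Delete(pieceID) // in the endpoint this happens after the storagenode responds
		}
	}
	fmt.Println("all transfers accounted for")
}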

@@ -18,6 +18,7 @@ import (
 "storj.io/storj/pkg/rpc/rpcstatus"
 "storj.io/storj/pkg/signing"
 "storj.io/storj/pkg/storj"
+"storj.io/storj/private/errs2"
 "storj.io/storj/private/sync2"
 "storj.io/storj/satellite/metainfo"
 "storj.io/storj/satellite/orders"
@@ -59,61 +60,6 @@ type Endpoint struct {
 recvTimeout time.Duration
 }
-type pendingTransfer struct {
-path []byte
-pieceSize int64
-satelliteMessage *pb.SatelliteMessage
-originalPointer *pb.Pointer
-pieceNum int32
-}
-// pendingMap for managing concurrent access to the pending transfer map.
-type pendingMap struct {
-mu sync.RWMutex
-data map[storj.PieceID]*pendingTransfer
-}
-// newPendingMap creates a new pendingMap and instantiates the map.
-func newPendingMap() *pendingMap {
-newData := make(map[storj.PieceID]*pendingTransfer)
-return &pendingMap{
-data: newData,
-}
-}
-// put adds to the map.
-func (pm *pendingMap) put(pieceID storj.PieceID, pendingTransfer *pendingTransfer) {
-pm.mu.Lock()
-defer pm.mu.Unlock()
-pm.data[pieceID] = pendingTransfer
-}
-// get returns the pending transfer item from the map, if it exists.
-func (pm *pendingMap) get(pieceID storj.PieceID) (pendingTransfer *pendingTransfer, ok bool) {
-pm.mu.RLock()
-defer pm.mu.RUnlock()
-pendingTransfer, ok = pm.data[pieceID]
-return pendingTransfer, ok
-}
-// length returns the number of elements in the map.
-func (pm *pendingMap) length() int {
-pm.mu.RLock()
-defer pm.mu.RUnlock()
-return len(pm.data)
-}
-// delete removes the pending transfer item from the map.
-func (pm *pendingMap) delete(pieceID storj.PieceID) {
-pm.mu.Lock()
-defer pm.mu.Unlock()
-delete(pm.data, pieceID)
-}
 // connectionsTracker for tracking ongoing connections on this api server
 type connectionsTracker struct {
 mu sync.RWMutex
@@ -224,83 +170,60 @@ func (endpoint *Endpoint) doProcess(stream processStream) (err error) {
 return nil
 }
-// these are used to synchronize the "incomplete transfer loop" with the main thread (storagenode receive loop)
-morePiecesFlag := true
-loopRunningFlag := true
-errChan := make(chan error, 1)
-processMu := &sync.Mutex{}
-processCond := sync.NewCond(processMu)
-handleError := func(err error) error {
-errChan <- err
-close(errChan)
-return rpcstatus.Error(rpcstatus.Internal, Error.Wrap(err).Error())
-}
 // maps pieceIDs to pendingTransfers to keep track of ongoing piece transfer requests
-pending := newPendingMap()
+// and handles concurrency between sending logic and receiving logic
+pending := NewPendingMap()
 var group errgroup.Group
 group.Go(func() error {
 incompleteLoop := sync2.NewCycle(endpoint.interval)
-defer func() {
-processMu.Lock()
-loopRunningFlag = false
-processCond.Broadcast()
-processMu.Unlock()
-}()
+// we cancel this context in all situations where we want to exit the loop
 ctx, cancel := context.WithCancel(ctx)
-return incompleteLoop.Run(ctx, func(ctx context.Context) error {
-if pending.length() == 0 {
+loopErr := incompleteLoop.Run(ctx, func(ctx context.Context) error {
+if pending.Length() == 0 {
 incomplete, err := endpoint.db.GetIncompleteNotFailed(ctx, nodeID, endpoint.config.EndpointBatchSize, 0)
 if err != nil {
-return handleError(err)
+cancel()
+return pending.DoneSending(err)
 }
 if len(incomplete) == 0 {
 incomplete, err = endpoint.db.GetIncompleteFailed(ctx, nodeID, endpoint.config.MaxFailuresPerPiece, endpoint.config.EndpointBatchSize, 0)
 if err != nil {
-return handleError(err)
+cancel()
+return pending.DoneSending(err)
 }
 }
 if len(incomplete) == 0 {
 endpoint.log.Debug("no more pieces to transfer for node", zap.Stringer("Node ID", nodeID))
-processMu.Lock()
-morePiecesFlag = false
-processMu.Unlock()
 cancel()
-return nil
+return pending.DoneSending(nil)
 }
 for _, inc := range incomplete {
 err = endpoint.processIncomplete(ctx, stream, pending, inc)
 if err != nil {
-return handleError(err)
+cancel()
+return pending.DoneSending(err)
 }
 }
-if pending.length() > 0 {
-processCond.Broadcast()
-}
 }
 return nil
 })
+return errs2.IgnoreCanceled(loopErr)
 })
 for {
-select {
-case <-errChan:
-return group.Wait()
-default:
-}
+finishedPromise := pending.IsFinishedPromise()
+finished, err := finishedPromise.Wait(ctx)
+if err != nil {
+return rpcstatus.Error(rpcstatus.Internal, err.Error())
 }
-pendingCount := pending.length()
-processMu.Lock()
-// if there are no more transfers and the pending queue is empty, send complete
-if !morePiecesFlag && pendingCount == 0 {
-processMu.Unlock()
+// if there is no more work to receive send complete
+if finished {
 isDisqualified, err := endpoint.handleDisqualifiedNode(ctx, nodeID)
 if err != nil {
 return rpcstatus.Error(rpcstatus.Internal, err.Error())
@@ -320,24 +243,8 @@ func (endpoint *Endpoint) doProcess(stream processStream) (err error) {
 return rpcstatus.Error(rpcstatus.Internal, err.Error())
 }
 break
-} else if pendingCount == 0 {
-// otherwise, wait for incomplete loop
-processCond.Wait()
-select {
-case <-ctx.Done():
-processMu.Unlock()
-return ctx.Err()
-default:
 }
-// if pending count is still 0 and the loop has exited, return
-if pending.length() == 0 && !loopRunningFlag {
-processMu.Unlock()
-continue
-}
-}
-processMu.Unlock()
 done := make(chan struct{})
 var request *pb.StorageNodeMessage
 var recvErr error
@@ -409,14 +316,13 @@ func (endpoint *Endpoint) doProcess(stream processStream) (err error) {
 if err := group.Wait(); err != nil {
 if !errs.Is(err, context.Canceled) {
 return rpcstatus.Error(rpcstatus.Internal, Error.Wrap(err).Error())
 }
 }
 return nil
 }
-func (endpoint *Endpoint) processIncomplete(ctx context.Context, stream processStream, pending *pendingMap, incomplete *TransferQueueItem) error {
+func (endpoint *Endpoint) processIncomplete(ctx context.Context, stream processStream, pending *PendingMap, incomplete *TransferQueueItem) error {
 nodeID := incomplete.NodeID
 if incomplete.OrderLimitSendCount >= endpoint.config.MaxOrderLimitSendCount {
@@ -531,23 +437,23 @@ func (endpoint *Endpoint) processIncomplete(ctx context.Context, stream processS
 }
 // update pending queue with the transfer item
-pending.put(pieceID, &pendingTransfer{
-path: incomplete.Path,
-pieceSize: pieceSize,
-satelliteMessage: transferMsg,
-originalPointer: pointer,
-pieceNum: incomplete.PieceNum,
+err = pending.Put(pieceID, &PendingTransfer{
+Path: incomplete.Path,
+PieceSize: pieceSize,
+SatelliteMessage: transferMsg,
+OriginalPointer: pointer,
+PieceNum: incomplete.PieceNum,
 })
-return nil
+return err
 }
-func (endpoint *Endpoint) handleSucceeded(ctx context.Context, stream processStream, pending *pendingMap, exitingNodeID storj.NodeID, message *pb.StorageNodeMessage_Succeeded) (err error) {
+func (endpoint *Endpoint) handleSucceeded(ctx context.Context, stream processStream, pending *PendingMap, exitingNodeID storj.NodeID, message *pb.StorageNodeMessage_Succeeded) (err error) {
 defer mon.Task()(&ctx)(&err)
 originalPieceID := message.Succeeded.OriginalPieceId
-transfer, ok := pending.get(originalPieceID)
+transfer, ok := pending.Get(originalPieceID)
 if !ok {
 endpoint.log.Error("Could not find transfer item in pending queue", zap.Stringer("Piece ID", originalPieceID))
 return Error.New("Could not find transfer item in pending queue")
@@ -558,7 +464,7 @@ func (endpoint *Endpoint) handleSucceeded(ctx context.Context, stream processStr
 return Error.Wrap(err)
 }
-receivingNodeID := transfer.satelliteMessage.GetTransferPiece().GetAddressedOrderLimit().GetLimit().StorageNodeId
+receivingNodeID := transfer.SatelliteMessage.GetTransferPiece().GetAddressedOrderLimit().GetLimit().StorageNodeId
 // get peerID and signee for new storage node
 peerID, err := endpoint.peerIdentities.Get(ctx, receivingNodeID)
 if err != nil {
@@ -569,17 +475,17 @@ func (endpoint *Endpoint) handleSucceeded(ctx context.Context, stream processStr
 if err != nil {
 return Error.Wrap(err)
 }
-transferQueueItem, err := endpoint.db.GetTransferQueueItem(ctx, exitingNodeID, transfer.path, transfer.pieceNum)
+transferQueueItem, err := endpoint.db.GetTransferQueueItem(ctx, exitingNodeID, transfer.Path, transfer.PieceNum)
 if err != nil {
 return Error.Wrap(err)
 }
-err = endpoint.updatePointer(ctx, transfer.originalPointer, exitingNodeID, receivingNodeID, string(transfer.path), transfer.pieceNum, transferQueueItem.RootPieceID)
+err = endpoint.updatePointer(ctx, transfer.OriginalPointer, exitingNodeID, receivingNodeID, string(transfer.Path), transfer.PieceNum, transferQueueItem.RootPieceID)
 if err != nil {
 // remove the piece from the pending queue so it gets retried
-pending.delete(originalPieceID)
-return Error.Wrap(err)
+deleteErr := pending.Delete(originalPieceID)
+return Error.Wrap(errs.Combine(err, deleteErr))
 }
 var failed int64
@@ -587,17 +493,20 @@ func (endpoint *Endpoint) handleSucceeded(ctx context.Context, stream processStr
 failed = -1
 }
-err = endpoint.db.IncrementProgress(ctx, exitingNodeID, transfer.pieceSize, 1, failed)
+err = endpoint.db.IncrementProgress(ctx, exitingNodeID, transfer.PieceSize, 1, failed)
 if err != nil {
 return Error.Wrap(err)
 }
-err = endpoint.db.DeleteTransferQueueItem(ctx, exitingNodeID, transfer.path, transfer.pieceNum)
+err = endpoint.db.DeleteTransferQueueItem(ctx, exitingNodeID, transfer.Path, transfer.PieceNum)
 if err != nil {
 return Error.Wrap(err)
 }
-pending.delete(originalPieceID)
+err = pending.Delete(originalPieceID)
+if err != nil {
+return err
+}
 deleteMsg := &pb.SatelliteMessage{
 Message: &pb.SatelliteMessage_DeletePiece{
@@ -616,19 +525,19 @@ func (endpoint *Endpoint) handleSucceeded(ctx context.Context, stream processStr
 return nil
 }
-func (endpoint *Endpoint) handleFailed(ctx context.Context, pending *pendingMap, nodeID storj.NodeID, message *pb.StorageNodeMessage_Failed) (err error) {
+func (endpoint *Endpoint) handleFailed(ctx context.Context, pending *PendingMap, nodeID storj.NodeID, message *pb.StorageNodeMessage_Failed) (err error) {
 defer mon.Task()(&ctx)(&err)
 endpoint.log.Warn("transfer failed", zap.Stringer("Piece ID", message.Failed.OriginalPieceId), zap.Stringer("transfer error", message.Failed.GetError()))
 mon.Meter("graceful_exit_transfer_piece_fail").Mark(1) //locked
 pieceID := message.Failed.OriginalPieceId
-transfer, ok := pending.get(pieceID)
+transfer, ok := pending.Get(pieceID)
 if !ok {
 endpoint.log.Debug("could not find transfer message in pending queue. skipping.", zap.Stringer("Piece ID", pieceID))
 // TODO we should probably error out here so we don't get stuck in a loop with a SN that is not behaving properl
 }
-transferQueueItem, err := endpoint.db.GetTransferQueueItem(ctx, nodeID, transfer.path, transfer.pieceNum)
+transferQueueItem, err := endpoint.db.GetTransferQueueItem(ctx, nodeID, transfer.Path, transfer.PieceNum)
 if err != nil {
 return Error.Wrap(err)
 }
@@ -644,38 +553,36 @@ func (endpoint *Endpoint) handleFailed(ctx context.Context, pending *pendingMap,
 // Remove the queue item and remove the node from the pointer.
 // If the pointer is not piece hash verified, do not count this as a failure.
 if pb.TransferFailed_Error(errorCode) == pb.TransferFailed_NOT_FOUND {
-endpoint.log.Debug("piece not found on node", zap.Stringer("node ID", nodeID), zap.ByteString("path", transfer.path), zap.Int32("piece num", transfer.pieceNum))
-pointer, err := endpoint.metainfo.Get(ctx, string(transfer.path))
+endpoint.log.Debug("piece not found on node", zap.Stringer("node ID", nodeID), zap.ByteString("path", transfer.Path), zap.Int32("piece num", transfer.PieceNum))
+pointer, err := endpoint.metainfo.Get(ctx, string(transfer.Path))
 if err != nil {
 return Error.Wrap(err)
 }
 remote := pointer.GetRemote()
 if remote == nil {
-err = endpoint.db.DeleteTransferQueueItem(ctx, nodeID, transfer.path, transfer.pieceNum)
+err = endpoint.db.DeleteTransferQueueItem(ctx, nodeID, transfer.Path, transfer.PieceNum)
 if err != nil {
 return Error.Wrap(err)
 }
-pending.delete(pieceID)
-return nil
+return pending.Delete(pieceID)
 }
 pieces := remote.GetRemotePieces()
 var nodePiece *pb.RemotePiece
 for _, piece := range pieces {
-if piece.NodeId == nodeID && piece.PieceNum == transfer.pieceNum {
+if piece.NodeId == nodeID && piece.PieceNum == transfer.PieceNum {
 nodePiece = piece
 }
 }
 if nodePiece == nil {
-err = endpoint.db.DeleteTransferQueueItem(ctx, nodeID, transfer.path, transfer.pieceNum)
+err = endpoint.db.DeleteTransferQueueItem(ctx, nodeID, transfer.Path, transfer.PieceNum)
 if err != nil {
 return Error.Wrap(err)
 }
-pending.delete(pieceID)
-return nil
+return pending.Delete(pieceID)
 }
-_, err = endpoint.metainfo.UpdatePieces(ctx, string(transfer.path), pointer, nil, []*pb.RemotePiece{nodePiece})
+_, err = endpoint.metainfo.UpdatePieces(ctx, string(transfer.Path), pointer, nil, []*pb.RemotePiece{nodePiece})
 if err != nil {
 return Error.Wrap(err)
 }
@@ -689,13 +596,11 @@ func (endpoint *Endpoint) handleFailed(ctx context.Context, pending *pendingMap,
 }
 }
-err = endpoint.db.DeleteTransferQueueItem(ctx, nodeID, transfer.path, transfer.pieceNum)
+err = endpoint.db.DeleteTransferQueueItem(ctx, nodeID, transfer.Path, transfer.PieceNum)
 if err != nil {
 return Error.Wrap(err)
 }
-pending.delete(pieceID)
-return nil
+return pending.Delete(pieceID)
 }
 transferQueueItem.LastFailedAt = &now
@@ -714,9 +619,7 @@ func (endpoint *Endpoint) handleFailed(ctx context.Context, pending *pendingMap,
 }
 }
-pending.delete(pieceID)
-return nil
+return pending.Delete(pieceID)
 }
 func (endpoint *Endpoint) handleDisqualifiedNode(ctx context.Context, nodeID storj.NodeID) (isDisqualified bool, err error) {
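
As the new version of the loop above shows, the sending goroutine no longer funnels failures through `errChan`/`handleError`; it cancels its own cycle context and reports the error with `pending.DoneSending(err)`, while `errs2.IgnoreCanceled(loopErr)` keeps that deliberate cancellation from surfacing as a failure from the errgroup. A reduced sketch of how an error raised on the sending side reaches the receiving side through the promise (the database error here is a stand-in, not a real endpoint failure):

package main

import (
	"context"
	"fmt"

	"github.com/zeebo/errs"

	"storj.io/storj/satellite/gracefulexit"
)

func main() {
	ctx := context.Background()
	pending := gracefulexit.NewPendingMap()

	// Sending side: a (stand-in) database failure ends the sending loop;
	// DoneSending records the error and wakes any waiting promise.
	go func() {
		dbErr := errs.New("GetIncompleteNotFailed failed")
		_ = pending.DoneSending(dbErr)
	}()

	// Receiving side: Wait reports finished == true together with the sending
	// side's error, which doProcess then wraps into an Internal RPC status.
	finished, err := pending.IsFinishedPromise().Wait(ctx)
	fmt.Println(finished, err) // true, plus the stand-in database error
}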

@@ -0,0 +1,172 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package gracefulexit
import (
"context"
"sync"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj"
)
// PendingFinishedPromise for waiting for information about finished state.
type PendingFinishedPromise struct {
addedWorkChan chan struct{}
finishCalledChan chan struct{}
finishErr error
}
func newFinishedPromise() *PendingFinishedPromise {
return &PendingFinishedPromise{
addedWorkChan: make(chan struct{}),
finishCalledChan: make(chan struct{}),
}
}
// Wait should be called (once) after acquiring the finished promise and will return whether the pending map is finished.
func (promise *PendingFinishedPromise) Wait(ctx context.Context) (bool, error) {
select {
case <-ctx.Done():
return true, ctx.Err()
case <-promise.addedWorkChan:
return false, nil
case <-promise.finishCalledChan:
return true, promise.finishErr
}
}
func (promise *PendingFinishedPromise) addedWork() {
close(promise.addedWorkChan)
}
func (promise *PendingFinishedPromise) finishCalled(err error) {
promise.finishErr = err
close(promise.finishCalledChan)
}
// PendingTransfer is the representation of work on the pending map.
// It contains information about a transfer request that has been sent to a storagenode by the satellite.
type PendingTransfer struct {
Path []byte
PieceSize int64
SatelliteMessage *pb.SatelliteMessage
OriginalPointer *pb.Pointer
PieceNum int32
}
// PendingMap for managing concurrent access to the pending transfer map.
type PendingMap struct {
mu sync.RWMutex
data map[storj.PieceID]*PendingTransfer
doneSending bool
doneSendingErr error
finishedPromise *PendingFinishedPromise
}
// NewPendingMap creates a new PendingMap.
func NewPendingMap() *PendingMap {
newData := make(map[storj.PieceID]*PendingTransfer)
return &PendingMap{
data: newData,
}
}
// Put adds work to the map. If there is already work associated with this piece ID it returns an error.
// If there is a PendingFinishedPromise waiting, that promise is updated to return false.
func (pm *PendingMap) Put(pieceID storj.PieceID, pendingTransfer *PendingTransfer) error {
pm.mu.Lock()
defer pm.mu.Unlock()
if pm.doneSending {
return Error.New("cannot add work: pending map already finished")
}
if _, ok := pm.data[pieceID]; ok {
return Error.New("piece ID already exists in pending map")
}
pm.data[pieceID] = pendingTransfer
if pm.finishedPromise != nil {
pm.finishedPromise.addedWork()
pm.finishedPromise = nil
}
return nil
}
// Get returns the pending transfer item from the map, if it exists.
func (pm *PendingMap) Get(pieceID storj.PieceID) (*PendingTransfer, bool) {
pm.mu.RLock()
defer pm.mu.RUnlock()
pendingTransfer, ok := pm.data[pieceID]
return pendingTransfer, ok
}
// Length returns the number of elements in the map.
func (pm *PendingMap) Length() int {
pm.mu.RLock()
defer pm.mu.RUnlock()
return len(pm.data)
}
// Delete removes the pending transfer item from the map and returns an error if the data does not exist.
func (pm *PendingMap) Delete(pieceID storj.PieceID) error {
pm.mu.Lock()
defer pm.mu.Unlock()
if _, ok := pm.data[pieceID]; !ok {
return Error.New("piece ID does not exist in pending map")
}
delete(pm.data, pieceID)
return nil
}
// IsFinishedPromise returns a promise for the caller to wait on to determine the finished status of the pending map.
// If we have enough information to determine the finished status, we update the promise to have an answer immediately.
// Otherwise, we attach the promise to the pending map to be updated and cleared by either Put or DoneSending (whichever happens first).
func (pm *PendingMap) IsFinishedPromise() *PendingFinishedPromise {
pm.mu.Lock()
defer pm.mu.Unlock()
if pm.finishedPromise != nil {
return pm.finishedPromise
}
newPromise := newFinishedPromise()
if len(pm.data) > 0 {
newPromise.addedWork()
return newPromise
}
if pm.doneSending {
newPromise.finishCalled(pm.doneSendingErr)
return newPromise
}
pm.finishedPromise = newPromise
return newPromise
}
// DoneSending is called (with an optional error) when no more work will be added to the map.
// If DoneSending has already been called, an error is returned.
// If a PendingFinishedPromise is waiting on a response, it is updated to return true.
func (pm *PendingMap) DoneSending(err error) error {
pm.mu.Lock()
defer pm.mu.Unlock()
if pm.doneSending {
return Error.New("DoneSending() already called on pending map")
}
if pm.finishedPromise != nil {
pm.finishedPromise.finishCalled(err)
pm.finishedPromise = nil
}
pm.doneSending = true
pm.doneSendingErr = err
return nil
}
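
Compared to the old unexported `pendingMap`, whose `put` and `delete` returned nothing, the exported methods above enforce a stricter contract: `Put` rejects duplicate piece IDs and any work added after `DoneSending`, and `Delete` rejects unknown piece IDs. That is why the endpoint changes above now propagate these errors (`return err`, `errs.Combine(err, deleteErr)`). A small sketch of the contract (piece IDs and transfer contents are placeholders):

package main

import (
	"fmt"

	"storj.io/storj/pkg/pb"
	"storj.io/storj/private/testrand"
	"storj.io/storj/satellite/gracefulexit"
)

func main() {
	pending := gracefulexit.NewPendingMap()
	pieceID := testrand.PieceID()
	transfer := &gracefulexit.PendingTransfer{
		Path:             []byte("testbucket/testfile"),
		PieceSize:        10,
		SatelliteMessage: &pb.SatelliteMessage{},
		OriginalPointer:  &pb.Pointer{},
		PieceNum:         1,
	}

	fmt.Println(pending.Put(pieceID, transfer)) // nil
	fmt.Println(pending.Put(pieceID, transfer)) // non-nil: piece ID already exists
	fmt.Println(pending.Delete(pieceID))        // nil
	fmt.Println(pending.Delete(pieceID))        // non-nil: piece ID no longer in the map

	_ = pending.DoneSending(nil)
	fmt.Println(pending.Put(testrand.PieceID(), transfer)) // non-nil: map already finished
}

The tests that follow exercise exactly these cases, plus the blocking behavior of IsFinishedPromise.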

@@ -0,0 +1,257 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package gracefulexit_test
import (
"bytes"
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/zeebo/errs"
"golang.org/x/sync/errgroup"
"storj.io/storj/pkg/pb"
"storj.io/storj/private/errs2"
"storj.io/storj/private/sync2"
"storj.io/storj/private/testcontext"
"storj.io/storj/private/testrand"
"storj.io/storj/satellite/gracefulexit"
)
func TestPendingBasic(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
newWork := &gracefulexit.PendingTransfer{
Path: []byte("testbucket/testfile"),
PieceSize: 10,
SatelliteMessage: &pb.SatelliteMessage{},
OriginalPointer: &pb.Pointer{},
PieceNum: 1,
}
pieceID := testrand.PieceID()
pending := gracefulexit.NewPendingMap()
// put should work
err := pending.Put(pieceID, newWork)
require.NoError(t, err)
// put should return an error if the item already exists
err = pending.Put(pieceID, newWork)
require.Error(t, err)
// get should work
w, ok := pending.Get(pieceID)
require.True(t, ok)
require.True(t, bytes.Equal(newWork.Path, w.Path))
invalidPieceID := testrand.PieceID()
_, ok = pending.Get(invalidPieceID)
require.False(t, ok)
// IsFinished: there is remaining work to be done -> return false immediately
finishedPromise := pending.IsFinishedPromise()
finished, err := finishedPromise.Wait(ctx)
require.False(t, finished)
require.NoError(t, err)
// finished should work
err = pending.DoneSending(nil)
require.NoError(t, err)
// finished should error if already called
err = pending.DoneSending(nil)
require.Error(t, err)
// should not be allowed to Put new work after finished called
err = pending.Put(testrand.PieceID(), newWork)
require.Error(t, err)
// IsFinished: Finish has been called and there is remaining work -> return false
finishedPromise = pending.IsFinishedPromise()
finished, err = finishedPromise.Wait(ctx)
require.False(t, finished)
require.NoError(t, err)
// delete should work
err = pending.Delete(pieceID)
require.NoError(t, err)
_, ok = pending.Get(pieceID)
require.False(t, ok)
// delete should return an error if the item does not exist
err = pending.Delete(pieceID)
require.Error(t, err)
// IsFinished: Finish has been called and there is no remaining work -> return true
finishedPromise = pending.IsFinishedPromise()
finished, err = finishedPromise.Wait(ctx)
require.True(t, finished)
require.NoError(t, err)
}
// TestPendingIsFinishedWorkAdded ensures that pending.IsFinished blocks if there is no work, then returns false when new work is added
func TestPendingIsFinishedWorkAdded(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
newWork := &gracefulexit.PendingTransfer{
Path: []byte("testbucket/testfile"),
PieceSize: 10,
SatelliteMessage: &pb.SatelliteMessage{},
OriginalPointer: &pb.Pointer{},
PieceNum: 1,
}
pieceID := testrand.PieceID()
pending := gracefulexit.NewPendingMap()
fence := sync2.Fence{}
var group errgroup.Group
group.Go(func() error {
// expect no work
size := pending.Length()
require.EqualValues(t, size, 0)
finishedPromise := pending.IsFinishedPromise()
// wait for work to be added
fence.Release()
finished, err := finishedPromise.Wait(ctx)
require.False(t, finished)
require.NoError(t, err)
// expect new work was added
size = pending.Length()
require.EqualValues(t, size, 1)
return nil
})
group.Go(func() error {
// wait for IsFinishedPromise call before adding work
require.True(t, fence.Wait(ctx))
err := pending.Put(pieceID, newWork)
require.NoError(t, err)
return nil
})
require.NoError(t, group.Wait())
}
// TestPendingIsFinishedDoneSendingCalled ensures that pending.IsFinished blocks if there is no work, then returns true when DoneSending is called
func TestPendingIsFinishedDoneSendingCalled(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
pending := gracefulexit.NewPendingMap()
fence := sync2.Fence{}
var group errgroup.Group
group.Go(func() error {
finishedPromise := pending.IsFinishedPromise()
fence.Release()
finished, err := finishedPromise.Wait(ctx)
require.True(t, finished)
require.NoError(t, err)
return nil
})
group.Go(func() error {
// wait for IsFinishedPromise call before finishing
require.True(t, fence.Wait(ctx))
err := pending.DoneSending(nil)
require.NoError(t, err)
return nil
})
require.NoError(t, group.Wait())
}
// TestPendingIsFinishedCtxCanceled ensures that pending.IsFinished blocks if there is no work, then returns true when context is canceled
func TestPendingIsFinishedCtxCanceled(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
pending := gracefulexit.NewPendingMap()
ctx2, cancel := context.WithCancel(ctx)
fence := sync2.Fence{}
var group errgroup.Group
group.Go(func() error {
finishedPromise := pending.IsFinishedPromise()
fence.Release()
finished, err := finishedPromise.Wait(ctx2)
require.True(t, finished)
require.Error(t, err)
require.True(t, errs2.IsCanceled(err))
return nil
})
group.Go(func() error {
// wait for IsFinishedPromise call before canceling
require.True(t, fence.Wait(ctx))
cancel()
return nil
})
require.NoError(t, group.Wait())
}
// TestPendingIsFinishedDoneSendingCalledError ensures that pending.IsFinished blocks if there is no work, then returns true with an error when DoneSending is called with an error
func TestPendingIsFinishedDoneSendingCalledError(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
pending := gracefulexit.NewPendingMap()
finishErr := errs.New("test error")
fence := sync2.Fence{}
var group errgroup.Group
group.Go(func() error {
finishedPromise := pending.IsFinishedPromise()
fence.Release()
finished, err := finishedPromise.Wait(ctx)
require.True(t, finished)
require.Error(t, err)
require.Equal(t, finishErr, err)
return nil
})
group.Go(func() error {
// wait for IsFinishedPromise call before finishing
require.True(t, fence.Wait(ctx))
err := pending.DoneSending(finishErr)
require.NoError(t, err)
return nil
})
require.NoError(t, group.Wait())
}
// TestPendingIsFinishedDoneSendingCalledError2 ensures that pending.IsFinished returns an error if DoneSending was already called with an error.
func TestPendingIsFinishedDoneSendingCalledError2(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
pending := gracefulexit.NewPendingMap()
finishErr := errs.New("test error")
err := pending.DoneSending(finishErr)
require.NoError(t, err)
finishedPromise := pending.IsFinishedPromise()
finished, err := finishedPromise.Wait(ctx)
require.True(t, finished)
require.Error(t, err)
require.Equal(t, finishErr, err)
}

@@ -12,30 +12,30 @@ import (
 "storj.io/storj/pkg/signing"
 )
-func (endpoint *Endpoint) validatePendingTransfer(ctx context.Context, transfer *pendingTransfer) error {
-if transfer.satelliteMessage == nil {
+func (endpoint *Endpoint) validatePendingTransfer(ctx context.Context, transfer *PendingTransfer) error {
+if transfer.SatelliteMessage == nil {
 return Error.New("Satellite message cannot be nil")
 }
-if transfer.satelliteMessage.GetTransferPiece() == nil {
+if transfer.SatelliteMessage.GetTransferPiece() == nil {
 return Error.New("Satellite message transfer piece cannot be nil")
 }
-if transfer.satelliteMessage.GetTransferPiece().GetAddressedOrderLimit() == nil {
+if transfer.SatelliteMessage.GetTransferPiece().GetAddressedOrderLimit() == nil {
 return Error.New("Addressed order limit on transfer piece cannot be nil")
 }
-if transfer.satelliteMessage.GetTransferPiece().GetAddressedOrderLimit().GetLimit() == nil {
+if transfer.SatelliteMessage.GetTransferPiece().GetAddressedOrderLimit().GetLimit() == nil {
 return Error.New("Addressed order limit on transfer piece cannot be nil")
 }
-if transfer.path == nil {
+if transfer.Path == nil {
 return Error.New("Transfer path cannot be nil")
 }
-if transfer.originalPointer == nil || transfer.originalPointer.GetRemote() == nil {
+if transfer.OriginalPointer == nil || transfer.OriginalPointer.GetRemote() == nil {
 return Error.New("could not get remote pointer from transfer item")
 }
 return nil
 }
-func (endpoint *Endpoint) verifyPieceTransferred(ctx context.Context, message *pb.StorageNodeMessage_Succeeded, transfer *pendingTransfer, receivingNodePeerID *identity.PeerIdentity) error {
+func (endpoint *Endpoint) verifyPieceTransferred(ctx context.Context, message *pb.StorageNodeMessage_Succeeded, transfer *PendingTransfer, receivingNodePeerID *identity.PeerIdentity) error {
 originalOrderLimit := message.Succeeded.GetOriginalOrderLimit()
 if originalOrderLimit == nil {
 return ErrInvalidArgument.New("Original order limit cannot be nil")
@@ -70,8 +70,8 @@ func (endpoint *Endpoint) verifyPieceTransferred(ctx context.Context, message *p
 return ErrInvalidArgument.New("Invalid original piece ID")
 }
-receivingNodeID := transfer.satelliteMessage.GetTransferPiece().GetAddressedOrderLimit().GetLimit().StorageNodeId
-calculatedNewPieceID := transfer.originalPointer.GetRemote().RootPieceId.Derive(receivingNodeID, transfer.pieceNum)
+receivingNodeID := transfer.SatelliteMessage.GetTransferPiece().GetAddressedOrderLimit().GetLimit().StorageNodeId
+calculatedNewPieceID := transfer.OriginalPointer.GetRemote().RootPieceId.Derive(receivingNodeID, transfer.PieceNum)
 if calculatedNewPieceID != replacementPieceHash.PieceId {
 return ErrInvalidArgument.New("Invalid replacement piece ID")
 }