From 3a34a0df7b6ba108fa35394bcb667e327f54b5d4 Mon Sep 17 00:00:00 2001
From: Jeff Wendling
Date: Mon, 15 Jul 2019 15:58:39 -0400
Subject: [PATCH] repair: fix data race in reliability cache (#2561)

---
 internal/testcontext/context.go       | 10 ++++
 pkg/datarepair/checker/online.go      | 78 ++++++++++++++++++++-------
 pkg/datarepair/checker/online_test.go | 52 ++++++++++++++++++
 3 files changed, 120 insertions(+), 20 deletions(-)
 create mode 100644 pkg/datarepair/checker/online_test.go

diff --git a/internal/testcontext/context.go b/internal/testcontext/context.go
index ed038ff26..af7aa51bd 100644
--- a/internal/testcontext/context.go
+++ b/internal/testcontext/context.go
@@ -99,6 +99,16 @@ func (ctx *Context) Go(fn func() error) {
 	})
 }
 
+// Wait blocks until all of the goroutines launched with Go are done and
+// fails the test if any of them returned an error.
+func (ctx *Context) Wait() {
+	ctx.test.Helper()
+	err := ctx.group.Wait()
+	if err != nil {
+		ctx.test.Fatal(err)
+	}
+}
+
 // Check calls fn and checks result
 func (ctx *Context) Check(fn func() error) {
 	ctx.test.Helper()
diff --git a/pkg/datarepair/checker/online.go b/pkg/datarepair/checker/online.go
index e40d51e0a..63f44c9b2 100644
--- a/pkg/datarepair/checker/online.go
+++ b/pkg/datarepair/checker/online.go
@@ -5,6 +5,8 @@ package checker
 
 import (
 	"context"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"storj.io/storj/pkg/overlay"
@@ -14,13 +16,17 @@ import (
 
 // ReliabilityCache caches the reliable nodes for the specified staleness duration
 // and updates automatically from overlay.
-//
-// ReliabilityCache is NOT safe for concurrent use.
 type ReliabilityCache struct {
-	overlay    *overlay.Cache
-	staleness  time.Duration
-	lastUpdate time.Time
-	reliable   map[storj.NodeID]struct{}
+	overlay   *overlay.Cache
+	staleness time.Duration
+	mu        sync.Mutex
+	state     atomic.Value // contains immutable *reliabilityState
+}
+
+// reliabilityState is an immutable snapshot of the reliable nodes as of the time it was created.
+type reliabilityState struct {
+	reliable map[storj.NodeID]struct{}
+	created  time.Time
 }
 
 // NewReliabilityCache creates a new reliability checking cache.
@@ -28,17 +34,37 @@ func NewReliabilityCache(overlay *overlay.Cache, staleness time.Duration) *Relia
 	return &ReliabilityCache{
 		overlay:   overlay,
 		staleness: staleness,
-		reliable:  map[storj.NodeID]struct{}{},
 	}
 }
 
 // LastUpdate returns when the cache was last updated.
-func (cache *ReliabilityCache) LastUpdate() time.Time { return cache.lastUpdate }
+func (cache *ReliabilityCache) LastUpdate() time.Time {
+	if state, ok := cache.state.Load().(*reliabilityState); ok {
+		return state.created
+	}
+	return time.Time{}
+}
 
 // MissingPieces returns piece indices that are unreliable with the given staleness period.
-func (cache *ReliabilityCache) MissingPieces(ctx context.Context, created time.Time, pieces []*pb.RemotePiece) ([]int32, error) {
-	if created.After(cache.lastUpdate) || time.Since(cache.lastUpdate) > cache.staleness {
-		err := cache.Refresh(ctx)
+func (cache *ReliabilityCache) MissingPieces(ctx context.Context, created time.Time, pieces []*pb.RemotePiece) (_ []int32, err error) {
+	defer mon.Task()(&ctx)(&err)
+
+	// This code is designed to be very fast in the case where a refresh is not needed: just an
+	// atomic load from a rarely updated piece of shared memory. The general strategy is to
+	// first check whether the current state suffices to answer the query. If not (because it
+	// does not exist, is too stale, etc.), we acquire the mutex to block other stale requests
+	// and to ensure that only one refresh is issued at a time. After acquiring the mutex, we
+	// have to double-check that the state is still stale, because another call may have beaten
+	// us to the refresh. Only then do we refresh, and only then can we answer the query.
+
+	state, ok := cache.state.Load().(*reliabilityState)
+	if !ok || created.After(state.created) || time.Since(state.created) > cache.staleness {
+		cache.mu.Lock()
+		state, ok = cache.state.Load().(*reliabilityState)
+		if !ok || created.After(state.created) || time.Since(state.created) > cache.staleness {
+			state, err = cache.refreshLocked(ctx)
+		}
+		cache.mu.Unlock()
 		if err != nil {
 			return nil, err
 		}
@@ -46,7 +72,7 @@ func (cache *ReliabilityCache) MissingPieces(ctx context.Context, created time.T
 
 	var unreliable []int32
 	for _, piece := range pieces {
-		if _, ok := cache.reliable[piece.NodeId]; !ok {
+		if _, ok := state.reliable[piece.NodeId]; !ok {
 			unreliable = append(unreliable, piece.PieceNum)
 		}
 	}
@@ -54,21 +80,33 @@
 }
 
 // Refresh refreshes the cache.
-func (cache *ReliabilityCache) Refresh(ctx context.Context) error {
-	for id := range cache.reliable {
-		delete(cache.reliable, id)
-	}
+func (cache *ReliabilityCache) Refresh(ctx context.Context) (err error) {
+	defer mon.Task()(&ctx)(&err)
 
-	cache.lastUpdate = time.Now()
+	cache.mu.Lock()
+	defer cache.mu.Unlock()
+
+	_, err = cache.refreshLocked(ctx)
+	return err
+}
+
+// refreshLocked refreshes the cache, assuming the write mutex is already held.
+func (cache *ReliabilityCache) refreshLocked(ctx context.Context) (_ *reliabilityState, err error) {
+	defer mon.Task()(&ctx)(&err)
 
 	nodes, err := cache.overlay.Reliable(ctx)
 	if err != nil {
-		return Error.Wrap(err)
+		return nil, Error.Wrap(err)
 	}
 
+	state := &reliabilityState{
+		created:  time.Now(),
+		reliable: make(map[storj.NodeID]struct{}, len(nodes)),
+	}
 	for _, id := range nodes {
-		cache.reliable[id] = struct{}{}
+		state.reliable[id] = struct{}{}
 	}
-	return nil
+	cache.state.Store(state)
+	return state, nil
 }
 
diff --git a/pkg/datarepair/checker/online_test.go b/pkg/datarepair/checker/online_test.go
new file mode 100644
index 000000000..202bfad60
--- /dev/null
+++ b/pkg/datarepair/checker/online_test.go
@@ -0,0 +1,52 @@
+// Copyright (C) 2019 Storj Labs, Inc.
+// See LICENSE for copying information.
+
+package checker
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"go.uber.org/zap"
+
+	"storj.io/storj/internal/testcontext"
+	"storj.io/storj/internal/testrand"
+	"storj.io/storj/pkg/overlay"
+	"storj.io/storj/pkg/pb"
+	"storj.io/storj/pkg/storj"
+)
+
+func TestReliabilityCache_Concurrent(t *testing.T) {
+	ctx := testcontext.New(t)
+	defer ctx.Cleanup()
+
+	ocache := overlay.NewCache(zap.NewNop(), fakeOverlayDB{}, overlay.NodeSelectionConfig{})
+	rcache := NewReliabilityCache(ocache, time.Millisecond)
+
+	for i := 0; i < 10; i++ {
+		ctx.Go(func() error {
+			for i := 0; i < 10000; i++ {
+				pieces := []*pb.RemotePiece{{NodeId: testrand.NodeID()}}
+				_, err := rcache.MissingPieces(ctx, time.Now(), pieces)
+				if err != nil {
+					return err
+				}
+			}
+			return nil
+		})
+	}
+
+	ctx.Wait()
+}
+
+type fakeOverlayDB struct{ overlay.DB }
+
+func (fakeOverlayDB) Reliable(context.Context, *overlay.NodeCriteria) (storj.NodeIDList, error) {
+	return storj.NodeIDList{
+		testrand.NodeID(),
+		testrand.NodeID(),
+		testrand.NodeID(),
+		testrand.NodeID(),
+	}, nil
+}
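
The locking scheme the patch documents in MissingPieces is a general double-checked pattern, so a minimal, self-contained sketch of the same idea follows. It is an illustration only, not Storj code: the snapshot type and the fetchIDs helper are hypothetical stand-ins for reliabilityState and overlay.Reliable. Readers take a single atomic.Value load on the fast path; the mutex serializes refreshes, and the double check after acquiring it keeps a burst of stale readers from issuing more than one refresh.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// snapshot is an immutable view of the data plus the time it was built.
type snapshot struct {
	created time.Time
	ids     map[int]struct{}
}

type cache struct {
	staleness time.Duration
	mu        sync.Mutex   // serializes refreshes; readers never take it
	state     atomic.Value // always holds a *snapshot once populated
}

// fetchIDs is a hypothetical data source standing in for overlay.Reliable.
func fetchIDs() map[int]struct{} {
	return map[int]struct{}{1: {}, 2: {}, 3: {}}
}

// get returns a snapshot no older than the staleness bound.
func (c *cache) get() *snapshot {
	// Fast path: one atomic load, safe for any number of concurrent readers.
	if s, ok := c.state.Load().(*snapshot); ok && time.Since(s.created) <= c.staleness {
		return s
	}

	// Slow path: only one goroutine refreshes at a time.
	c.mu.Lock()
	defer c.mu.Unlock()

	// Double check: another goroutine may have refreshed while we waited for the lock.
	if s, ok := c.state.Load().(*snapshot); ok && time.Since(s.created) <= c.staleness {
		return s
	}

	s := &snapshot{created: time.Now(), ids: fetchIDs()}
	c.state.Store(s) // publish the snapshot; it is never mutated afterwards
	return s
}

func main() {
	c := &cache{staleness: time.Minute}

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_, ok := c.get().ids[2]
			fmt.Println("contains 2:", ok)
		}()
	}
	wg.Wait()
}

Because a published snapshot is never mutated, readers may use whatever they load without holding any lock; only writers coordinate through the mutex. That is the property the patch relies on to remove the data race while keeping the common read path cheap.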