From 73d5c6944a2dceeed446f8f21baa4dc916b33739 Mon Sep 17 00:00:00 2001 From: Andrew Harding Date: Wed, 14 Dec 2022 19:19:29 -0700 Subject: [PATCH] satellite/audit: merge support for reservoirs Change-Id: Ibbedd2a0043412210159fa2523f9e63d987276c3 --- satellite/audit/reservoir.go | 17 +++++++++++ satellite/audit/reservoir_test.go | 51 +++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/satellite/audit/reservoir.go b/satellite/audit/reservoir.go index ea6a05620..042f23f7c 100644 --- a/satellite/audit/reservoir.go +++ b/satellite/audit/reservoir.go @@ -8,6 +8,8 @@ import ( "math/rand" "time" + "github.com/zeebo/errs" + "storj.io/common/uuid" "storj.io/storj/satellite/metabase" "storj.io/storj/satellite/metabase/segmentloop" @@ -55,6 +57,10 @@ func (reservoir *Reservoir) Keys() []float64 { // article: https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Res func (reservoir *Reservoir) Sample(r *rand.Rand, segment *segmentloop.Segment) { k := -math.Log(r.Float64()) / float64(segment.EncryptedSize) + reservoir.sample(k, segment) +} + +func (reservoir *Reservoir) sample(k float64, segment *segmentloop.Segment) { if reservoir.index < reservoir.size { reservoir.segments[reservoir.index] = *segment reservoir.keys[reservoir.index] = k @@ -73,6 +79,17 @@ func (reservoir *Reservoir) Sample(r *rand.Rand, segment *segmentloop.Segment) { } } +// Merge merges the given reservoir into the receiver. Both reservoirs must have the same size. +func (reservoir *Reservoir) Merge(operand *Reservoir) error { + if reservoir.size != operand.size { + return errs.New("cannot merge: mismatched size: expected %d but got %d", reservoir.size, operand.size) + } + for i := int8(0); i < operand.index; i++ { + reservoir.sample(operand.keys[i], &operand.segments[i]) + } + return nil +} + // Segment is a segment to audit. 
type Segment struct { StreamID uuid.UUID diff --git a/satellite/audit/reservoir_test.go b/satellite/audit/reservoir_test.go index df327703c..7e6967fef 100644 --- a/satellite/audit/reservoir_test.go +++ b/satellite/audit/reservoir_test.go @@ -23,13 +23,11 @@ import ( func TestReservoir(t *testing.T) { rng := rand.New(rand.NewSource(time.Now().Unix())) - seg := func(n int) segmentloop.Segment { return segmentloop.Segment{StreamID: uuid.UUID{0: byte(n)}} } - for size := 0; size < maxReservoirSize; size++ { t.Run(fmt.Sprintf("size %d", size), func(t *testing.T) { samples := []segmentloop.Segment{} for i := 0; i < size; i++ { - samples = append(samples, seg(i)) + samples = append(samples, makeSegment(i)) } // If we sample N segments, less than the max, we should record all N @@ -43,6 +41,46 @@ func TestReservoir(t *testing.T) { } } +func TestReservoirMerge(t *testing.T) { + t.Run("merge successful", func(t *testing.T) { + // Use a fixed rng so we get deterministic sampling results. + segments := []segmentloop.Segment{ + makeSegment(0), makeSegment(1), makeSegment(2), + makeSegment(3), makeSegment(4), makeSegment(5), + } + rng := rand.New(rand.NewSource(999)) + r1 := NewReservoir(3) + r1.Sample(rng, &segments[0]) + r1.Sample(rng, &segments[1]) + r1.Sample(rng, &segments[2]) + + r2 := NewReservoir(3) + r2.Sample(rng, &segments[3]) + r2.Sample(rng, &segments[4]) + r2.Sample(rng, &segments[5]) + + err := r1.Merge(r2) + require.NoError(t, err) + + // Segments should contain a cross section from r1 and r2. If the rng + // changes, this result will likely change too since that will affect + // the keys, and therefore how they are merged. 
+ require.Equal(t, []segmentloop.Segment{ + segments[5], + segments[1], + segments[2], + }, r1.Segments()) + }) + + t.Run("mismatched size", func(t *testing.T) { + r1 := NewReservoir(2) + r2 := NewReservoir(1) + err := r1.Merge(r2) + require.EqualError(t, err, "cannot merge: mismatched size: expected 2 but got 1") + }) + +} + func TestReservoirWeights(t *testing.T) { var weight10StreamID = testrand.UUID() var weight5StreamID = testrand.UUID() @@ -159,3 +197,10 @@ type uint64Slice []uint64 func (us uint64Slice) Len() int { return len(us) } func (us uint64Slice) Swap(i, j int) { us[i], us[j] = us[j], us[i] } func (us uint64Slice) Less(i, j int) bool { return us[i] < us[j] } + +func makeSegment(n int) segmentloop.Segment { + return segmentloop.Segment{ + StreamID: uuid.UUID{0: byte(n)}, + EncryptedSize: int32(n * 1000), + } +}