satellite/audit: prevent accessing unset reservoir segments
This change fixes access to unset segments and keys on the reservoir, which occurred when the reservoir size was less than the max OR the number of sampled segments was smaller than the reservoir size. It does so by tucking the segments and keys away behind methods that return properly sized slices of the segments/keys arrays. It also fixes a bug in the housekeeping of the internal index variable, which tracks how many items in the array have been populated. As part of this fix, the type of index changes to int8, which reduces the size of the reservoir struct by 8 bytes. The tests have been updated to provide better coverage for this case.

Change-Id: I3ceb17b692fe456fc4c1ca5d67d35c96aeb0a169
This commit is contained in:
parent
5c3a148d6e
commit
93fad70e4b
@ -68,11 +68,12 @@ func (chore *Chore) Run(ctx context.Context) (err error) {
|
||||
// Add reservoir segments to queue in pseudorandom order.
|
||||
for i := 0; i < chore.config.Slots; i++ {
|
||||
for _, res := range collector.Reservoirs {
|
||||
segments := res.Segments()
|
||||
// Skip reservoir if no segment at this index.
|
||||
if len(res.Segments) <= i {
|
||||
if len(segments) <= i {
|
||||
continue
|
||||
}
|
||||
segment := res.Segments[i]
|
||||
segment := segments[i]
|
||||
segmentKey := SegmentKey{
|
||||
StreamID: segment.StreamID,
|
||||
Position: segment.Position.Encode(),
|
||||
|
@ -58,14 +58,14 @@ func TestAuditCollector(t *testing.T) {
|
||||
for _, node := range planet.StorageNodes {
|
||||
// expect a reservoir for every node
|
||||
require.NotNil(t, observer.Reservoirs[node.ID()])
|
||||
require.True(t, len(observer.Reservoirs[node.ID()].Segments) > 1)
|
||||
require.True(t, len(observer.Reservoirs[node.ID()].Segments()) > 1)
|
||||
|
||||
// Require that the number of segments is <= 3 even though the Collector was instantiated with 4
|
||||
// because the maxReservoirSize is currently 3.
|
||||
require.True(t, len(observer.Reservoirs[node.ID()].Segments) <= 3)
|
||||
require.True(t, len(observer.Reservoirs[node.ID()].Segments()) <= 3)
|
||||
|
||||
repeats := make(map[audit.Segment]bool)
|
||||
for _, loopSegment := range observer.Reservoirs[node.ID()].Segments {
|
||||
for _, loopSegment := range observer.Reservoirs[node.ID()].Segments() {
|
||||
segment := audit.NewSegment(loopSegment)
|
||||
assert.False(t, repeats[segment], "expected every item in reservoir to be unique")
|
||||
repeats[segment] = true
|
||||
|
@ -17,10 +17,10 @@ const maxReservoirSize = 3
|
||||
|
||||
// Reservoir holds a certain number of segments to reflect a random sample.
|
||||
type Reservoir struct {
|
||||
Segments [maxReservoirSize]segmentloop.Segment
|
||||
Keys [maxReservoirSize]float64
|
||||
segments [maxReservoirSize]segmentloop.Segment
|
||||
keys [maxReservoirSize]float64
|
||||
size int8
|
||||
index int64
|
||||
index int8
|
||||
}
|
||||
|
||||
// NewReservoir instantiates a Reservoir.
|
||||
@ -36,6 +36,16 @@ func NewReservoir(size int) *Reservoir {
|
||||
}
|
||||
}
|
||||
|
||||
// Segments returns the segments picked by the reservoir.
|
||||
func (reservoir *Reservoir) Segments() []segmentloop.Segment {
|
||||
return reservoir.segments[:reservoir.index]
|
||||
}
|
||||
|
||||
// Keys returns the keys for the segments picked by the reservoir.
|
||||
func (reservoir *Reservoir) Keys() []float64 {
|
||||
return reservoir.keys[:reservoir.index]
|
||||
}
|
||||
|
||||
// Sample tries to ensure that each segment passed in has a chance (proportional
|
||||
// to its size) to be in the reservoir when sampling is complete.
|
||||
//
|
||||
@ -45,22 +55,22 @@ func NewReservoir(size int) *Reservoir {
|
||||
// article: https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Res
|
||||
func (reservoir *Reservoir) Sample(r *rand.Rand, segment *segmentloop.Segment) {
|
||||
k := -math.Log(r.Float64()) / float64(segment.EncryptedSize)
|
||||
if reservoir.index < int64(reservoir.size) {
|
||||
reservoir.Segments[reservoir.index] = *segment
|
||||
reservoir.Keys[reservoir.index] = k
|
||||
if reservoir.index < reservoir.size {
|
||||
reservoir.segments[reservoir.index] = *segment
|
||||
reservoir.keys[reservoir.index] = k
|
||||
reservoir.index++
|
||||
} else {
|
||||
max := 0
|
||||
for i := 1; i < int(reservoir.size); i++ {
|
||||
if reservoir.Keys[i] > reservoir.Keys[max] {
|
||||
max := int8(0)
|
||||
for i := int8(1); i < reservoir.size; i++ {
|
||||
if reservoir.keys[i] > reservoir.keys[max] {
|
||||
max = i
|
||||
}
|
||||
}
|
||||
if k < reservoir.Keys[max] {
|
||||
reservoir.Segments[max] = *segment
|
||||
reservoir.Keys[max] = k
|
||||
if k < reservoir.keys[max] {
|
||||
reservoir.segments[max] = *segment
|
||||
reservoir.keys[max] = k
|
||||
}
|
||||
}
|
||||
reservoir.index++
|
||||
}
|
||||
|
||||
// Segment is a segment to audit.
|
||||
|
@ -5,6 +5,7 @@ package audit
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"testing"
|
||||
@ -20,17 +21,26 @@ import (
|
||||
)
|
||||
|
||||
func TestReservoir(t *testing.T) {
|
||||
rng := rand.New(rand.NewSource(0))
|
||||
r := NewReservoir(3)
|
||||
rng := rand.New(rand.NewSource(time.Now().Unix()))
|
||||
|
||||
seg := func(n byte) *segmentloop.Segment { return &segmentloop.Segment{StreamID: uuid.UUID{0: n}} }
|
||||
seg := func(n int) segmentloop.Segment { return segmentloop.Segment{StreamID: uuid.UUID{0: byte(n)}} }
|
||||
|
||||
// if we sample 3 segments, we should record all 3
|
||||
r.Sample(rng, seg(1))
|
||||
r.Sample(rng, seg(2))
|
||||
r.Sample(rng, seg(3))
|
||||
for size := 0; size < maxReservoirSize; size++ {
|
||||
t.Run(fmt.Sprintf("size %d", size), func(t *testing.T) {
|
||||
samples := []segmentloop.Segment{}
|
||||
for i := 0; i < size; i++ {
|
||||
samples = append(samples, seg(i))
|
||||
}
|
||||
|
||||
require.Equal(t, r.Segments[:], []segmentloop.Segment{*seg(1), *seg(2), *seg(3)})
|
||||
// If we sample N segments, less than the max, we should record all N
|
||||
r := NewReservoir(size)
|
||||
for _, sample := range samples {
|
||||
r.Sample(rng, &sample)
|
||||
}
|
||||
require.Equal(t, samples, r.Segments())
|
||||
require.Len(t, r.Keys(), len(samples))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReservoirWeights(t *testing.T) {
|
||||
@ -81,7 +91,7 @@ func TestReservoirWeights(t *testing.T) {
|
||||
r.Sample(rng, segment)
|
||||
}
|
||||
|
||||
for _, segment := range r.Segments {
|
||||
for _, segment := range r.Segments() {
|
||||
streamIDCountsMap[segment.StreamID]++
|
||||
}
|
||||
|
||||
@ -121,7 +131,7 @@ func TestReservoirBias(t *testing.T) {
|
||||
binary.BigEndian.PutUint64(seg.StreamID[0:8], uint64(n)<<(64-useBits))
|
||||
res.Sample(rng, &seg)
|
||||
}
|
||||
for i, seg := range res.Segments {
|
||||
for i, seg := range res.Segments() {
|
||||
num := binary.BigEndian.Uint64(seg.StreamID[0:8]) >> (64 - useBits)
|
||||
numsSelected[r*reservoirSize+i] = num
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user