diff --git a/replay/replay.go b/replay/replay.go index 9ad37be..c6cf4f7 100644 --- a/replay/replay.go +++ b/replay/replay.go @@ -23,7 +23,7 @@ type AntiReplay struct { // for inbound packets mu sync.Mutex last uint64 - circle [1 << 7]uint64 + circle [NumBlocks]uint64 } func NewAntiReplay() *AntiReplay { diff --git a/replay/replay_test.go b/replay/replay_test.go new file mode 100644 index 0000000..0254d5c --- /dev/null +++ b/replay/replay_test.go @@ -0,0 +1,77 @@ +package replay + +import ( + "encoding/binary" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "math/rand" + "testing" +) + +func TestAntiReplay_Verify(t *testing.T) { + t.Run("NoReplays", func(t *testing.T) { + // ASSIGN + r := NewAntiReplay() + + start := rand.Uint64() + sum := make([]byte, 8) + + // ACT + ASSERT + for i := start; i < start+128; i++ { + // ACT + binary.LittleEndian.PutUint64(sum, i) + err := r.Verify(nil, sum) + + // ASSERT + require.Nil(t, err) + } + }) + + t.Run("ImmediateReplay", func(t *testing.T) { + // ASSIGN + r := NewAntiReplay() + + start := rand.Uint64() + sum := make([]byte, 8) + + // ACT + binary.LittleEndian.PutUint64(sum, start) + + err1 := r.Verify(nil, sum) + err2 := r.Verify(nil, sum) + + // ASSERT + require.Nil(t, err1) + require.Equal(t, ErrReplayedPacket, err2) + }) + + t.Run("RandomReplays", func(t *testing.T) { + // ASSIGN + r := NewAntiReplay() + + start := rand.Uint64() + sum := make([]byte, 8) + + random := make([]byte, 512/8) + rand.Read(random) + + // ACT + ASSERT + for i := 0; i < 512; i++ { + // ACT + replay := (random[i/8] & (1 << (i % 8)) > 0) && i != 0 + if !replay { + start++ + } + + binary.LittleEndian.PutUint64(sum, start) + err := r.Verify(nil, sum) + + // ASSERT + if replay { + assert.Equal(t, ErrReplayedPacket, err) + } else { + assert.Nil(t, err) + } + } + }) +}