repeats #22

Merged
JakeHillion merged 7 commits from repeats into develop 2021-05-13 18:34:35 +01:00
2 changed files with 155 additions and 0 deletions

78
replay/replay.go Normal file
View File

@ -0,0 +1,78 @@
package replay
import (
"encoding/binary"
"errors"
"math/rand"
"sync"
"sync/atomic"
)
var ErrReplayedPacket = errors.New("replayed packet")
const (
BlockBits = 64
NumBlocks = 128
WindowSize = (NumBlocks - 1) * BlockBits
)
type AntiReplay struct {
// for outbound packets
next uint64
// for inbound packets
mu sync.Mutex
last uint64
circle [NumBlocks]uint64
}
func NewAntiReplay() *AntiReplay {
return &AntiReplay{next: rand.Uint64()}
}
func (a *AntiReplay) CodeLength() int {
return 8
}
func (a *AntiReplay) Generate([]byte) (out []byte) {
out = make([]byte, a.CodeLength())
s := atomic.AddUint64(&a.next, 1)
binary.LittleEndian.PutUint64(out, s)
return
}
func (a *AntiReplay) Verify(_, sum []byte) error {
a.mu.Lock()
defer a.mu.Unlock()
s := binary.LittleEndian.Uint64(sum)
indexBlock := s >> 6
if s > a.last {
current := a.last >> 6
diff := indexBlock - current
if diff > NumBlocks {
diff = NumBlocks // number so far away the whole ring is fresh
}
for i := current; i <= current+diff; i++ {
a.circle[i&(NumBlocks-1)] = 0 // clear the ones skipped over
}
a.last = s
} else if a.last-s > WindowSize {
return ErrReplayedPacket
}
indexBlock &= NumBlocks - 1
indexBit := s & (BlockBits - 1)
prev := a.circle[indexBlock]
a.circle[indexBlock] = prev | 1<<indexBit
if prev == a.circle[indexBlock] {
return ErrReplayedPacket
}
return nil
}

77
replay/replay_test.go Normal file
View File

@ -0,0 +1,77 @@
package replay
import (
"encoding/binary"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"math/rand"
"testing"
)
func TestAntiReplay_Verify(t *testing.T) {
t.Run("NoReplays", func(t *testing.T) {
// ASSIGN
r := NewAntiReplay()
start := rand.Uint64()
sum := make([]byte, 8)
// ACT + ASSERT
for i := start; i < start+128; i++ {
// ACT
binary.LittleEndian.PutUint64(sum, i)
err := r.Verify(nil, sum)
// ASSERT
require.Nil(t, err)
}
})
t.Run("ImmediateReplay", func(t *testing.T) {
// ASSIGN
r := NewAntiReplay()
start := rand.Uint64()
sum := make([]byte, 8)
// ACT
binary.LittleEndian.PutUint64(sum, start)
err1 := r.Verify(nil, sum)
err2 := r.Verify(nil, sum)
// ASSERT
require.Nil(t, err1)
require.Equal(t, ErrReplayedPacket, err2)
})
t.Run("RandomReplays", func(t *testing.T) {
// ASSIGN
r := NewAntiReplay()
start := rand.Uint64()
sum := make([]byte, 8)
random := make([]byte, 512/8)
rand.Read(random)
// ACT + ASSERT
for i := 0; i < 512; i++ {
// ACT
replay := (random[i/8]&(1<<(i%8)) > 0) && i != 0
if !replay {
start++
}
binary.LittleEndian.PutUint64(sum, start)
err := r.Verify(nil, sum)
// ASSERT
if replay {
assert.Equal(t, ErrReplayedPacket, err)
} else {
assert.Nil(t, err)
}
}
})
}