repeats #22
78
replay/replay.go
Normal file
78
replay/replay.go
Normal file
@ -0,0 +1,78 @@
|
||||
package replay
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var ErrReplayedPacket = errors.New("replayed packet")
|
||||
|
||||
const (
|
||||
BlockBits = 64
|
||||
NumBlocks = 128
|
||||
WindowSize = (NumBlocks - 1) * BlockBits
|
||||
)
|
||||
|
||||
type AntiReplay struct {
|
||||
// for outbound packets
|
||||
next uint64
|
||||
|
||||
// for inbound packets
|
||||
mu sync.Mutex
|
||||
last uint64
|
||||
circle [NumBlocks]uint64
|
||||
}
|
||||
|
||||
func NewAntiReplay() *AntiReplay {
|
||||
return &AntiReplay{next: rand.Uint64()}
|
||||
}
|
||||
|
||||
func (a *AntiReplay) CodeLength() int {
|
||||
return 8
|
||||
}
|
||||
|
||||
func (a *AntiReplay) Generate([]byte) (out []byte) {
|
||||
out = make([]byte, a.CodeLength())
|
||||
|
||||
s := atomic.AddUint64(&a.next, 1)
|
||||
binary.LittleEndian.PutUint64(out, s)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (a *AntiReplay) Verify(_, sum []byte) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
s := binary.LittleEndian.Uint64(sum)
|
||||
|
||||
indexBlock := s >> 6
|
||||
if s > a.last {
|
||||
current := a.last >> 6
|
||||
diff := indexBlock - current
|
||||
if diff > NumBlocks {
|
||||
diff = NumBlocks // number so far away the whole ring is fresh
|
||||
}
|
||||
for i := current; i <= current+diff; i++ {
|
||||
a.circle[i&(NumBlocks-1)] = 0 // clear the ones skipped over
|
||||
}
|
||||
a.last = s
|
||||
} else if a.last-s > WindowSize {
|
||||
return ErrReplayedPacket
|
||||
}
|
||||
|
||||
indexBlock &= NumBlocks - 1
|
||||
indexBit := s & (BlockBits - 1)
|
||||
|
||||
prev := a.circle[indexBlock]
|
||||
a.circle[indexBlock] = prev | 1<<indexBit
|
||||
|
||||
if prev == a.circle[indexBlock] {
|
||||
return ErrReplayedPacket
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
77
replay/replay_test.go
Normal file
77
replay/replay_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package replay
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAntiReplay_Verify(t *testing.T) {
|
||||
t.Run("NoReplays", func(t *testing.T) {
|
||||
// ASSIGN
|
||||
r := NewAntiReplay()
|
||||
|
||||
start := rand.Uint64()
|
||||
sum := make([]byte, 8)
|
||||
|
||||
// ACT + ASSERT
|
||||
for i := start; i < start+128; i++ {
|
||||
// ACT
|
||||
binary.LittleEndian.PutUint64(sum, i)
|
||||
err := r.Verify(nil, sum)
|
||||
|
||||
// ASSERT
|
||||
require.Nil(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImmediateReplay", func(t *testing.T) {
|
||||
// ASSIGN
|
||||
r := NewAntiReplay()
|
||||
|
||||
start := rand.Uint64()
|
||||
sum := make([]byte, 8)
|
||||
|
||||
// ACT
|
||||
binary.LittleEndian.PutUint64(sum, start)
|
||||
|
||||
err1 := r.Verify(nil, sum)
|
||||
err2 := r.Verify(nil, sum)
|
||||
|
||||
// ASSERT
|
||||
require.Nil(t, err1)
|
||||
require.Equal(t, ErrReplayedPacket, err2)
|
||||
})
|
||||
|
||||
t.Run("RandomReplays", func(t *testing.T) {
|
||||
// ASSIGN
|
||||
r := NewAntiReplay()
|
||||
|
||||
start := rand.Uint64()
|
||||
sum := make([]byte, 8)
|
||||
|
||||
random := make([]byte, 512/8)
|
||||
rand.Read(random)
|
||||
|
||||
// ACT + ASSERT
|
||||
for i := 0; i < 512; i++ {
|
||||
// ACT
|
||||
replay := (random[i/8]&(1<<(i%8)) > 0) && i != 0
|
||||
if !replay {
|
||||
start++
|
||||
}
|
||||
|
||||
binary.LittleEndian.PutUint64(sum, start)
|
||||
err := r.Verify(nil, sum)
|
||||
|
||||
// ASSERT
|
||||
if replay {
|
||||
assert.Equal(t, ErrReplayedPacket, err)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue
Block a user