diff --git a/replay/replay.go b/replay/replay.go new file mode 100644 index 0000000..9ad37be --- /dev/null +++ b/replay/replay.go @@ -0,0 +1,78 @@ +package replay + +import ( + "encoding/binary" + "errors" + "math/rand" + "sync" + "sync/atomic" +) + +var ErrReplayedPacket = errors.New("replayed packet") + +const ( + BlockBits = 64 + NumBlocks = 128 + WindowSize = (NumBlocks - 1) * BlockBits +) + +type AntiReplay struct { + // for outbound packets + next uint64 + + // for inbound packets + mu sync.Mutex + last uint64 + circle [1 << 7]uint64 +} + +func NewAntiReplay() *AntiReplay { + return &AntiReplay{next: rand.Uint64()} +} + +func (a *AntiReplay) CodeLength() int { + return 8 +} + +func (a *AntiReplay) Generate([]byte) (out []byte) { + out = make([]byte, a.CodeLength()) + + s := atomic.AddUint64(&a.next, 1) + binary.LittleEndian.PutUint64(out, s) + + return +} + +func (a *AntiReplay) Verify(_, sum []byte) error { + a.mu.Lock() + defer a.mu.Unlock() + + s := binary.LittleEndian.Uint64(sum) + + indexBlock := s >> 6 + if s > a.last { + current := a.last >> 6 + diff := indexBlock - current + if diff > NumBlocks { + diff = NumBlocks // number so far away the whole ring is fresh + } + for i := current; i <= current+diff; i++ { + a.circle[i&(NumBlocks-1)] = 0 // clear the ones skipped over + } + a.last = s + } else if a.last-s > WindowSize { + return ErrReplayedPacket + } + + indexBlock &= NumBlocks - 1 + indexBit := s & (BlockBits - 1) + + prev := a.circle[indexBlock] + a.circle[indexBlock] = prev | 1<