2021-05-13 18:15:43 +01:00
|
|
|
package replay
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/binary"
|
|
|
|
"errors"
|
|
|
|
"math/rand"
|
|
|
|
"sync"
|
|
|
|
"sync/atomic"
|
|
|
|
)
|
|
|
|
|
|
|
|
// ErrReplayedPacket is returned by Verify when a sequence number has already
// been recorded as seen, or is too far behind the newest accepted sequence
// number to still be tracked by the window.
var ErrReplayedPacket = errors.New("replayed packet")
|
|
|
|
|
|
|
|
// Dimensions of the inbound replay window.
const (
	// NumBlocks is the number of bitmap blocks in the ring. The ring index
	// is computed by masking with NumBlocks-1, so it must be a power of two.
	NumBlocks = 128

	// BlockBits is the number of sequence numbers tracked by one bitmap
	// block; each block is stored in a uint64.
	BlockBits = 64

	// WindowSize is how far behind the newest accepted sequence number a
	// packet may be and still be considered. It spans one block less than
	// the full ring.
	WindowSize = BlockBits * (NumBlocks - 1)
)
|
|
|
|
|
|
|
|
type AntiReplay struct {
|
|
|
|
// for outbound packets
|
|
|
|
next uint64
|
|
|
|
|
|
|
|
// for inbound packets
|
|
|
|
mu sync.Mutex
|
|
|
|
last uint64
|
2021-05-13 18:28:05 +01:00
|
|
|
circle [NumBlocks]uint64
|
2021-05-13 18:15:43 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func NewAntiReplay() *AntiReplay {
|
|
|
|
return &AntiReplay{next: rand.Uint64()}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (a *AntiReplay) CodeLength() int {
|
|
|
|
return 8
|
|
|
|
}
|
|
|
|
|
|
|
|
func (a *AntiReplay) Generate([]byte) (out []byte) {
|
|
|
|
out = make([]byte, a.CodeLength())
|
|
|
|
|
|
|
|
s := atomic.AddUint64(&a.next, 1)
|
|
|
|
binary.LittleEndian.PutUint64(out, s)
|
|
|
|
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
func (a *AntiReplay) Verify(_, sum []byte) error {
|
|
|
|
a.mu.Lock()
|
|
|
|
defer a.mu.Unlock()
|
|
|
|
|
|
|
|
s := binary.LittleEndian.Uint64(sum)
|
|
|
|
|
|
|
|
indexBlock := s >> 6
|
|
|
|
if s > a.last {
|
|
|
|
current := a.last >> 6
|
|
|
|
diff := indexBlock - current
|
|
|
|
if diff > NumBlocks {
|
|
|
|
diff = NumBlocks // number so far away the whole ring is fresh
|
|
|
|
}
|
|
|
|
for i := current; i <= current+diff; i++ {
|
|
|
|
a.circle[i&(NumBlocks-1)] = 0 // clear the ones skipped over
|
|
|
|
}
|
|
|
|
a.last = s
|
|
|
|
} else if a.last-s > WindowSize {
|
|
|
|
return ErrReplayedPacket
|
|
|
|
}
|
|
|
|
|
|
|
|
indexBlock &= NumBlocks - 1
|
|
|
|
indexBit := s & (BlockBits - 1)
|
|
|
|
|
|
|
|
prev := a.circle[indexBlock]
|
|
|
|
a.circle[indexBlock] = prev | 1<<indexBit
|
|
|
|
|
|
|
|
if prev == a.circle[indexBlock] {
|
|
|
|
return ErrReplayedPacket
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|