udp #5
@ -85,14 +85,14 @@ func buildTcp(p *proxy.Proxy, peer Peer) error {
|
||||
}
|
||||
|
||||
func buildUdp(p *proxy.Proxy, peer Peer) error {
|
||||
var c udp.Congestion
|
||||
var c func() udp.Congestion
|
||||
switch peer.Congestion {
|
||||
case "None":
|
||||
c = congestion.NewNone()
|
||||
c = func() udp.Congestion {return congestion.NewNone()}
|
||||
default:
|
||||
fallthrough
|
||||
case "NewReno":
|
||||
c = congestion.NewNewReno()
|
||||
c = func() udp.Congestion {return congestion.NewNewReno()}
|
||||
}
|
||||
|
||||
if peer.RemoteHost != "" {
|
||||
@ -101,7 +101,7 @@ func buildUdp(p *proxy.Proxy, peer Peer) error {
|
||||
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
|
||||
UselessMac{},
|
||||
UselessMac{},
|
||||
c,
|
||||
c(),
|
||||
time.Duration(peer.KeepAlive)*time.Second,
|
||||
)
|
||||
|
||||
|
@ -1,67 +0,0 @@
|
||||
package mocks
|
||||
|
||||
import "time"
|
||||
|
||||
type MockPerfectBiConn struct {
|
||||
directionA chan byte
|
||||
directionB chan byte
|
||||
}
|
||||
|
||||
func NewMockPerfectBiConn(bufSize int) MockPerfectBiConn {
|
||||
return MockPerfectBiConn{
|
||||
directionA: make(chan byte, bufSize),
|
||||
directionB: make(chan byte, bufSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (bc MockPerfectBiConn) SideA() MockPerfectConn {
|
||||
return MockPerfectConn{inbound: bc.directionA, outbound: bc.directionB}
|
||||
}
|
||||
|
||||
func (bc MockPerfectBiConn) SideB() MockPerfectConn {
|
||||
return MockPerfectConn{inbound: bc.directionB, outbound: bc.directionA}
|
||||
}
|
||||
|
||||
type MockPerfectConn struct {
|
||||
inbound chan byte
|
||||
outbound chan byte
|
||||
}
|
||||
|
||||
func (c MockPerfectConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c MockPerfectConn) Read(p []byte) (n int, err error) {
|
||||
for i := range p {
|
||||
if i == 0 {
|
||||
p[i] = <-c.inbound
|
||||
} else {
|
||||
select {
|
||||
case b := <-c.inbound:
|
||||
p[i] = b
|
||||
default:
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c MockPerfectConn) Write(p []byte) (n int, err error) {
|
||||
for _, b := range p {
|
||||
c.outbound <- b
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c MockPerfectConn) NonBlockingRead(p []byte) (n int, err error) {
|
||||
for i := range p {
|
||||
select {
|
||||
case b := <-c.inbound:
|
||||
p[i] = b
|
||||
default:
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
@ -1,6 +1,8 @@
|
||||
package mocks
|
||||
|
||||
import "mpbl3p/shared"
|
||||
import (
|
||||
"mpbl3p/shared"
|
||||
)
|
||||
|
||||
type AlmostUselessMac struct{}
|
||||
|
||||
|
62
mocks/packetconn.go
Normal file
62
mocks/packetconn.go
Normal file
@ -0,0 +1,62 @@
|
||||
package mocks
|
||||
|
||||
import "net"
|
||||
|
||||
type MockPerfectBiPacketConn struct {
|
||||
directionA chan []byte
|
||||
directionB chan []byte
|
||||
}
|
||||
|
||||
func NewMockPerfectBiPacketConn(bufSize int) MockPerfectBiPacketConn {
|
||||
return MockPerfectBiPacketConn{
|
||||
directionA: make(chan []byte, bufSize),
|
||||
directionB: make(chan []byte, bufSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (bc MockPerfectBiPacketConn) SideA() MockPerfectPacketConn {
|
||||
return MockPerfectPacketConn{inbound: bc.directionA, outbound: bc.directionB}
|
||||
}
|
||||
|
||||
func (bc MockPerfectBiPacketConn) SideB() MockPerfectPacketConn {
|
||||
return MockPerfectPacketConn{inbound: bc.directionB, outbound: bc.directionA}
|
||||
}
|
||||
|
||||
type MockPerfectPacketConn struct {
|
||||
inbound chan []byte
|
||||
outbound chan []byte
|
||||
}
|
||||
|
||||
func (c MockPerfectPacketConn) Write(b []byte) (int, error) {
|
||||
c.outbound <- b
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c MockPerfectPacketConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||
c.outbound <- b
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c MockPerfectPacketConn) LocalAddr() net.Addr {
|
||||
return &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 1234,
|
||||
}
|
||||
}
|
||||
|
||||
func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
|
||||
p := <-c.inbound
|
||||
return copy(b, p), &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 1234,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c MockPerfectPacketConn) NonBlockingRead(p []byte) (n int, err error) {
|
||||
select {
|
||||
case b := <-c.inbound:
|
||||
return copy(p, b), nil
|
||||
default:
|
||||
return 0, nil
|
||||
}
|
||||
}
|
95
mocks/streamconn.go
Normal file
95
mocks/streamconn.go
Normal file
@ -0,0 +1,95 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MockPerfectBiStreamConn struct {
|
||||
directionA chan byte
|
||||
directionB chan byte
|
||||
}
|
||||
|
||||
func NewMockPerfectBiStreamConn(bufSize int) MockPerfectBiStreamConn {
|
||||
return MockPerfectBiStreamConn{
|
||||
directionA: make(chan byte, bufSize),
|
||||
directionB: make(chan byte, bufSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (bc MockPerfectBiStreamConn) SideA() MockPerfectStreamConn {
|
||||
return MockPerfectStreamConn{inbound: bc.directionA, outbound: bc.directionB}
|
||||
}
|
||||
|
||||
func (bc MockPerfectBiStreamConn) SideB() MockPerfectStreamConn {
|
||||
return MockPerfectStreamConn{inbound: bc.directionB, outbound: bc.directionA}
|
||||
}
|
||||
|
||||
type MockPerfectStreamConn struct {
|
||||
inbound chan byte
|
||||
outbound chan byte
|
||||
}
|
||||
|
||||
type Conn interface {
|
||||
Read(b []byte) (n int, err error)
|
||||
Write(b []byte) (n int, err error)
|
||||
SetWriteDeadline(time.Time) error
|
||||
|
||||
// For printing
|
||||
LocalAddr() net.Addr
|
||||
RemoteAddr() net.Addr
|
||||
}
|
||||
|
||||
func (c MockPerfectStreamConn) Read(p []byte) (n int, err error) {
|
||||
for i := range p {
|
||||
if i == 0 {
|
||||
p[i] = <-c.inbound
|
||||
} else {
|
||||
select {
|
||||
case b := <-c.inbound:
|
||||
p[i] = b
|
||||
default:
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c MockPerfectStreamConn) Write(p []byte) (n int, err error) {
|
||||
for _, b := range p {
|
||||
c.outbound <- b
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c MockPerfectStreamConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only used for printing flow information
|
||||
func (c MockPerfectStreamConn) LocalAddr() net.Addr {
|
||||
return &net.TCPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 499,
|
||||
}
|
||||
}
|
||||
|
||||
func (c MockPerfectStreamConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 500,
|
||||
}
|
||||
}
|
||||
|
||||
func (c MockPerfectStreamConn) NonBlockingRead(p []byte) (n int, err error) {
|
||||
for i := range p {
|
||||
select {
|
||||
case b := <-c.inbound:
|
||||
p[i] = b
|
||||
default:
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
@ -4,31 +4,90 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"mpbl3p/mocks"
|
||||
"mpbl3p/shared"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPacket_Marshal(t *testing.T) {
|
||||
testContent := []byte("A test string is the content of this packet.")
|
||||
testPacket := NewPacket(testContent)
|
||||
testMac := mocks.AlmostUselessMac{}
|
||||
testPacket := NewSimplePacket(testContent)
|
||||
|
||||
t.Run("Length", func(t *testing.T) {
|
||||
marshalled := testPacket.Marshal(testMac)
|
||||
marshalled := testPacket.Marshal()
|
||||
|
||||
assert.Len(t, marshalled, len(testContent)+8+4)
|
||||
assert.Len(t, marshalled, len(testContent)+8)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalPacket(t *testing.T) {
|
||||
testContent := []byte("A test string is the content of this packet.")
|
||||
testPacket := NewPacket(testContent)
|
||||
testMac := mocks.AlmostUselessMac{}
|
||||
testMarshalled := testPacket.Marshal(testMac)
|
||||
testPacket := NewSimplePacket(testContent)
|
||||
testMarshalled := testPacket.Marshal()
|
||||
|
||||
t.Run("Length", func(t *testing.T) {
|
||||
p, err := UnmarshalSimplePacket(testMarshalled, testMac)
|
||||
p, err := UnmarshalSimplePacket(testMarshalled)
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Len(t, p.Marshal(), len(testContent))
|
||||
assert.Len(t, p.Contents(), len(testContent))
|
||||
})
|
||||
|
||||
t.Run("Contents", func(t *testing.T) {
|
||||
p, err := UnmarshalSimplePacket(testMarshalled)
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, p.Contents(), testContent)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAppendMac(t *testing.T) {
|
||||
testContent := []byte("A test string is the content of this packet.")
|
||||
testMac := mocks.AlmostUselessMac{}
|
||||
testPacket := NewSimplePacket(testContent)
|
||||
testMarshalled := testPacket.Marshal()
|
||||
|
||||
appended := AppendMac(testMarshalled, testMac)
|
||||
|
||||
t.Run("Length", func(t *testing.T) {
|
||||
assert.Len(t, appended, len(testMarshalled)+4)
|
||||
})
|
||||
|
||||
t.Run("Mac", func(t *testing.T) {
|
||||
assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled):])
|
||||
})
|
||||
|
||||
t.Run("Original", func(t *testing.T) {
|
||||
assert.Equal(t, testMarshalled, appended[:len(testMarshalled)])
|
||||
})
|
||||
}
|
||||
|
||||
func TestStripMac(t *testing.T) {
|
||||
testContent := []byte("A test string is the content of this packet.")
|
||||
testMac := mocks.AlmostUselessMac{}
|
||||
testPacket := NewSimplePacket(testContent)
|
||||
testMarshalled := testPacket.Marshal()
|
||||
|
||||
appended := AppendMac(testMarshalled, testMac)
|
||||
|
||||
t.Run("Length", func(t *testing.T) {
|
||||
cut, err := StripMac(appended, testMac)
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Len(t, cut, len(testMarshalled))
|
||||
})
|
||||
|
||||
t.Run("IncorrectMac", func(t *testing.T) {
|
||||
badMac := make([]byte, len(testMarshalled)+4)
|
||||
copy(badMac, testMarshalled)
|
||||
copy(badMac[:len(testMarshalled)], "dcba")
|
||||
_, err := StripMac(badMac, testMac)
|
||||
|
||||
assert.Error(t, err, shared.ErrBadChecksum)
|
||||
})
|
||||
|
||||
t.Run("Original", func(t *testing.T) {
|
||||
cut, err := StripMac(appended, testMac)
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, testMarshalled, cut)
|
||||
})
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ package tcp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"github.com/go-playground/assert/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"mpbl3p/mocks"
|
||||
"mpbl3p/proxy"
|
||||
@ -11,11 +11,11 @@ import (
|
||||
|
||||
func TestFlow_Consume(t *testing.T) {
|
||||
testContent := []byte("A test string is the content of this packet.")
|
||||
testPacket := proxy.NewPacket(testContent)
|
||||
testPacket := proxy.NewSimplePacket(testContent)
|
||||
testMac := mocks.AlmostUselessMac{}
|
||||
|
||||
t.Run("Length", func(t *testing.T) {
|
||||
testConn := mocks.NewMockPerfectBiConn(100)
|
||||
testConn := mocks.NewMockPerfectBiStreamConn(100)
|
||||
|
||||
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
||||
|
||||
@ -39,7 +39,7 @@ func TestFlow_Produce(t *testing.T) {
|
||||
testMac := mocks.AlmostUselessMac{}
|
||||
|
||||
t.Run("Length", func(t *testing.T) {
|
||||
testConn := mocks.NewMockPerfectBiConn(100)
|
||||
testConn := mocks.NewMockPerfectBiStreamConn(100)
|
||||
|
||||
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
||||
|
||||
@ -48,11 +48,11 @@ func TestFlow_Produce(t *testing.T) {
|
||||
|
||||
p, err := flowA.Produce(testMac)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, len(testContent), len(p.Marshal()))
|
||||
assert.Equal(t, len(testContent), len(p.Contents()))
|
||||
})
|
||||
|
||||
t.Run("Value", func(t *testing.T) {
|
||||
testConn := mocks.NewMockPerfectBiConn(100)
|
||||
testConn := mocks.NewMockPerfectBiStreamConn(100)
|
||||
|
||||
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
||||
|
||||
@ -61,6 +61,6 @@ func TestFlow_Produce(t *testing.T) {
|
||||
|
||||
p, err := flowA.Produce(testMac)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, testContent, string(p.Marshal()))
|
||||
assert.Equal(t, testContent, string(p.Contents()))
|
||||
})
|
||||
}
|
||||
|
@ -13,6 +13,5 @@ type Congestion interface {
|
||||
NextNack() uint32
|
||||
|
||||
AwaitEarlyUpdate(keepalive time.Duration) uint32
|
||||
AwaitAck(timeout time.Duration) bool
|
||||
Reset()
|
||||
}
|
||||
|
@ -16,8 +16,9 @@ type NewReno struct {
|
||||
sequence chan uint32
|
||||
keepalive chan bool
|
||||
|
||||
outboundTimes map[uint32]time.Time
|
||||
inboundTimes map[uint32]time.Time
|
||||
outboundTimes, inboundTimes map[uint32]time.Time
|
||||
outboundTimesLock sync.Mutex
|
||||
inboundTimesLock sync.RWMutex
|
||||
|
||||
ack, lastAck uint32
|
||||
nack, lastNack uint32
|
||||
@ -34,8 +35,7 @@ type NewReno struct {
|
||||
hasAcked bool
|
||||
|
||||
acksToSend utils.Uint32Heap
|
||||
|
||||
mu sync.Mutex
|
||||
acksToSendLock sync.Mutex
|
||||
}
|
||||
|
||||
func (c *NewReno) String() string {
|
||||
@ -82,12 +82,16 @@ func (c *NewReno) Reset() {
|
||||
|
||||
// It is assumed that ReceivedAck will only be called by one thread
|
||||
func (c *NewReno) ReceivedAck(ack uint32) {
|
||||
c.outboundTimesLock.Lock()
|
||||
defer c.outboundTimesLock.Unlock()
|
||||
|
||||
log.Printf("ack received for %d", ack)
|
||||
c.hasAcked = true
|
||||
|
||||
// RTT
|
||||
// Update using an exponential average
|
||||
rtt := time.Now().Sub(c.outboundTimes[ack]).Seconds()
|
||||
|
||||
delete(c.outboundTimes, ack)
|
||||
c.rtt = c.rtt*(1-RttExponentialFactor) + rtt*RttExponentialFactor
|
||||
|
||||
@ -128,22 +132,26 @@ func (c *NewReno) ReceivedPacket(seq uint32) {
|
||||
|
||||
c.inboundTimes[seq] = time.Now()
|
||||
|
||||
c.mu.Lock()
|
||||
c.acksToSendLock.Lock()
|
||||
c.acksToSend.Insert(seq)
|
||||
c.mu.Unlock()
|
||||
c.acksToSendLock.Unlock()
|
||||
|
||||
c.updateAckNack()
|
||||
}
|
||||
|
||||
func (c *NewReno) updateAckNack() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.acksToSendLock.Lock()
|
||||
defer c.acksToSendLock.Unlock()
|
||||
|
||||
c.inboundTimesLock.Lock()
|
||||
defer c.inboundTimesLock.Unlock()
|
||||
|
||||
findAck := func(start uint32) uint32 {
|
||||
ack := start
|
||||
for len(c.acksToSend) > 0 {
|
||||
if a, _ := c.acksToSend.Peek(); a == ack+1 {
|
||||
ack, _ = c.acksToSend.Extract()
|
||||
delete(c.inboundTimes, ack)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
@ -169,6 +177,9 @@ func (c *NewReno) updateAckNack() {
|
||||
}
|
||||
|
||||
func (c *NewReno) Sequence() uint32 {
|
||||
c.outboundTimesLock.Lock()
|
||||
defer c.outboundTimesLock.Unlock()
|
||||
|
||||
for c.inFlight >= c.windowSize {
|
||||
<-c.ackNotifier
|
||||
}
|
||||
@ -213,11 +224,3 @@ func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NewReno) AwaitAck(timeout time.Duration) bool {
|
||||
if c.hasAcked {
|
||||
return true
|
||||
}
|
||||
time.Sleep(timeout)
|
||||
return c.hasAcked
|
||||
}
|
||||
|
@ -38,7 +38,6 @@ func (c *None) ReceivedPacket(uint32) {}
|
||||
func (c *None) ReceivedAck(uint32) {}
|
||||
func (c *None) ReceivedNack(uint32) {}
|
||||
func (c *None) Reset() {}
|
||||
func (c *None) AwaitAck(time.Duration) bool { return true }
|
||||
func (c *None) NextNack() uint32 { return 0 }
|
||||
func (c *None) NextAck() uint32 { return 0 }
|
||||
func (c *None) AwaitEarlyUpdate(time.Duration) uint32 { select {} }
|
||||
|
62
udp/flow.go
62
udp/flow.go
@ -82,6 +82,7 @@ func newFlow(c Congestion, v proxy.MacVerifier) Flow {
|
||||
|
||||
func (f *InitiatedFlow) Reconnect() error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.isAlive {
|
||||
return nil
|
||||
@ -109,7 +110,6 @@ func (f *InitiatedFlow) Reconnect() error {
|
||||
go func() {
|
||||
seq := f.congestion.Sequence()
|
||||
|
||||
defer f.mu.Unlock()
|
||||
for !f.isAlive {
|
||||
p := Packet{
|
||||
ack: 0,
|
||||
@ -119,31 +119,6 @@ func (f *InitiatedFlow) Reconnect() error {
|
||||
}
|
||||
|
||||
_ = f.sendPacket(p, f.g)
|
||||
|
||||
if f.congestion.AwaitAck(1 * time.Second) {
|
||||
f.isAlive = true
|
||||
f.startup = false
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for f.startup || f.isAlive {
|
||||
func() {
|
||||
if f.isAlive {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
}
|
||||
|
||||
buf := make([]byte, 6000)
|
||||
n, _, err := conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
time.Sleep(1)
|
||||
} else {
|
||||
f.handleDatagram(buf[:n])
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
@ -152,6 +127,30 @@ func (f *InitiatedFlow) Reconnect() error {
|
||||
}()
|
||||
go f.earlyUpdateLoop(f.g, f.keepalive)
|
||||
|
||||
if err := f.acceptPacket(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f.isAlive = true
|
||||
f.startup = false
|
||||
|
||||
go func() {
|
||||
lockedAccept := func() {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
if err := f.acceptPacket(conn); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
for f.isAlive {
|
||||
log.Println("alive and listening for packets")
|
||||
lockedAccept()
|
||||
}
|
||||
log.Println("no longer alive")
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -273,3 +272,14 @@ func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) {
|
||||
_ = f.sendPacket(p, g)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Flow) acceptPacket(c PacketConn) error {
|
||||
buf := make([]byte, 6000)
|
||||
n, _, err := c.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f.handleDatagram(buf[:n])
|
||||
return nil
|
||||
}
|
||||
|
85
udp/flow_test.go
Normal file
85
udp/flow_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"mpbl3p/mocks"
|
||||
"mpbl3p/proxy"
|
||||
"mpbl3p/udp/congestion"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFlow_Consume(t *testing.T) {
|
||||
testContent := []byte("A test string is the content of this packet.")
|
||||
testPacket := proxy.NewSimplePacket(testContent)
|
||||
testMac := mocks.AlmostUselessMac{}
|
||||
|
||||
t.Run("Length", func(t *testing.T) {
|
||||
testConn := mocks.NewMockPerfectBiPacketConn(10)
|
||||
|
||||
flowA := newFlow(congestion.NewNone(), testMac)
|
||||
|
||||
flowA.writer = testConn.SideB()
|
||||
flowA.isAlive = true
|
||||
|
||||
err := flowA.Consume(testPacket, testMac)
|
||||
require.Nil(t, err)
|
||||
|
||||
buf := make([]byte, 100)
|
||||
n, _, err := testConn.SideA().ReadFromUDP(buf)
|
||||
require.Nil(t, err)
|
||||
|
||||
// 12 header, 8 timestamp, 4 MAC
|
||||
assert.Equal(t, len(testContent)+12+8+4, n)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFlow_Produce(t *testing.T) {
|
||||
testContent := []byte("A test string is the content of this packet.")
|
||||
testPacket := Packet{
|
||||
ack: 42,
|
||||
nack: 26,
|
||||
seq: 128,
|
||||
data: proxy.NewSimplePacket(testContent),
|
||||
}
|
||||
testMac := mocks.AlmostUselessMac{}
|
||||
|
||||
testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac)
|
||||
|
||||
t.Run("Length", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
testConn := mocks.NewMockPerfectBiPacketConn(10)
|
||||
|
||||
_, err := testConn.SideA().Write(testMarshalled)
|
||||
require.Nil(t, err)
|
||||
|
||||
flowA := newFlow(congestion.NewNone(), testMac)
|
||||
|
||||
flowA.writer = testConn.SideB()
|
||||
flowA.isAlive = true
|
||||
|
||||
go func() {
|
||||
err := flowA.acceptPacket(testConn.SideB())
|
||||
assert.Nil(t, err)
|
||||
}()
|
||||
p, err := flowA.Produce(testMac)
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Len(t, p.Contents(), len(testContent))
|
||||
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
timer := time.NewTimer(500 * time.Millisecond)
|
||||
select {
|
||||
case <-done:
|
||||
case <-timer.C:
|
||||
fmt.Println("timed out")
|
||||
t.FailNow()
|
||||
}
|
||||
})
|
||||
}
|
@ -25,7 +25,7 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
|
||||
}
|
||||
}
|
||||
|
||||
func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacGenerator, c Congestion) error {
|
||||
func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacGenerator, c func() Congestion) error {
|
||||
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -63,7 +63,7 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacG
|
||||
continue
|
||||
}
|
||||
|
||||
f := newFlow(c, v)
|
||||
f := newFlow(c(), v)
|
||||
|
||||
f.writer = pconn
|
||||
f.raddr = addr
|
||||
|
59
udp/packet_test.go
Normal file
59
udp/packet_test.go
Normal file
@ -0,0 +1,59 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"mpbl3p/proxy"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPacket_Marshal(t *testing.T) {
|
||||
testContent := []byte("A test string is the content of this packet.")
|
||||
testPacket := Packet{
|
||||
ack: 18,
|
||||
nack: 29,
|
||||
seq: 431,
|
||||
data: proxy.NewSimplePacket(testContent),
|
||||
}
|
||||
|
||||
t.Run("Length", func(t *testing.T) {
|
||||
marshalled := testPacket.Marshal()
|
||||
|
||||
// 12 header + 8 timestamp
|
||||
assert.Len(t, marshalled, len(testContent)+12+8)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalPacket(t *testing.T) {
|
||||
testContent := []byte("A test string is the content of this packet.")
|
||||
testPacket := Packet{
|
||||
ack: 18,
|
||||
nack: 29,
|
||||
seq: 431,
|
||||
data: proxy.NewSimplePacket(testContent),
|
||||
}
|
||||
testMarshalled := testPacket.Marshal()
|
||||
|
||||
t.Run("Length", func(t *testing.T) {
|
||||
p, err := UnmarshalPacket(testMarshalled)
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Len(t, p.Contents(), len(testContent))
|
||||
})
|
||||
|
||||
t.Run("Contents", func(t *testing.T) {
|
||||
p, err := UnmarshalPacket(testMarshalled)
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, p.Contents(), testContent)
|
||||
})
|
||||
|
||||
t.Run("Header", func(t *testing.T) {
|
||||
p, err := UnmarshalPacket(testMarshalled)
|
||||
require.Nil(t, err)
|
||||
|
||||
assert.Equal(t, p.ack, uint32(18))
|
||||
assert.Equal(t, p.nack, uint32(29))
|
||||
assert.Equal(t, p.seq, uint32(431))
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue
Block a user