udp #5

Merged
JakeHillion merged 16 commits from udp into develop 2020-11-28 16:53:00 +00:00
14 changed files with 440 additions and 134 deletions
Showing only changes of commit ff4ce07b05 - Show all commits

View File

@ -85,14 +85,14 @@ func buildTcp(p *proxy.Proxy, peer Peer) error {
}
func buildUdp(p *proxy.Proxy, peer Peer) error {
var c udp.Congestion
var c func() udp.Congestion
switch peer.Congestion {
case "None":
c = congestion.NewNone()
c = func() udp.Congestion {return congestion.NewNone()}
default:
fallthrough
case "NewReno":
c = congestion.NewNewReno()
c = func() udp.Congestion {return congestion.NewNewReno()}
}
if peer.RemoteHost != "" {
@ -101,7 +101,7 @@ func buildUdp(p *proxy.Proxy, peer Peer) error {
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
UselessMac{},
UselessMac{},
c,
c(),
time.Duration(peer.KeepAlive)*time.Second,
)

View File

@ -1,67 +0,0 @@
package mocks
import "time"
type MockPerfectBiConn struct {
directionA chan byte
directionB chan byte
}
func NewMockPerfectBiConn(bufSize int) MockPerfectBiConn {
return MockPerfectBiConn{
directionA: make(chan byte, bufSize),
directionB: make(chan byte, bufSize),
}
}
func (bc MockPerfectBiConn) SideA() MockPerfectConn {
return MockPerfectConn{inbound: bc.directionA, outbound: bc.directionB}
}
func (bc MockPerfectBiConn) SideB() MockPerfectConn {
return MockPerfectConn{inbound: bc.directionB, outbound: bc.directionA}
}
type MockPerfectConn struct {
inbound chan byte
outbound chan byte
}
func (c MockPerfectConn) SetWriteDeadline(time.Time) error {
return nil
}
func (c MockPerfectConn) Read(p []byte) (n int, err error) {
for i := range p {
if i == 0 {
p[i] = <-c.inbound
} else {
select {
case b := <-c.inbound:
p[i] = b
default:
return i, nil
}
}
}
return len(p), nil
}
func (c MockPerfectConn) Write(p []byte) (n int, err error) {
for _, b := range p {
c.outbound <- b
}
return len(p), nil
}
func (c MockPerfectConn) NonBlockingRead(p []byte) (n int, err error) {
for i := range p {
select {
case b := <-c.inbound:
p[i] = b
default:
return i, nil
}
}
return len(p), nil
}

View File

@ -1,6 +1,8 @@
package mocks
import "mpbl3p/shared"
import (
"mpbl3p/shared"
)
type AlmostUselessMac struct{}

62
mocks/packetconn.go Normal file
View File

@ -0,0 +1,62 @@
package mocks
import "net"
type MockPerfectBiPacketConn struct {
directionA chan []byte
directionB chan []byte
}
func NewMockPerfectBiPacketConn(bufSize int) MockPerfectBiPacketConn {
return MockPerfectBiPacketConn{
directionA: make(chan []byte, bufSize),
directionB: make(chan []byte, bufSize),
}
}
func (bc MockPerfectBiPacketConn) SideA() MockPerfectPacketConn {
return MockPerfectPacketConn{inbound: bc.directionA, outbound: bc.directionB}
}
func (bc MockPerfectBiPacketConn) SideB() MockPerfectPacketConn {
return MockPerfectPacketConn{inbound: bc.directionB, outbound: bc.directionA}
}
type MockPerfectPacketConn struct {
inbound chan []byte
outbound chan []byte
}
func (c MockPerfectPacketConn) Write(b []byte) (int, error) {
c.outbound <- b
return len(b), nil
}
func (c MockPerfectPacketConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
c.outbound <- b
return len(b), nil
}
func (c MockPerfectPacketConn) LocalAddr() net.Addr {
return &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
}
}
func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
p := <-c.inbound
return copy(b, p), &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
}, nil
}
func (c MockPerfectPacketConn) NonBlockingRead(p []byte) (n int, err error) {
select {
case b := <-c.inbound:
return copy(p, b), nil
default:
return 0, nil
}
}

95
mocks/streamconn.go Normal file
View File

@ -0,0 +1,95 @@
package mocks
import (
"net"
"time"
)
type MockPerfectBiStreamConn struct {
directionA chan byte
directionB chan byte
}
func NewMockPerfectBiStreamConn(bufSize int) MockPerfectBiStreamConn {
return MockPerfectBiStreamConn{
directionA: make(chan byte, bufSize),
directionB: make(chan byte, bufSize),
}
}
func (bc MockPerfectBiStreamConn) SideA() MockPerfectStreamConn {
return MockPerfectStreamConn{inbound: bc.directionA, outbound: bc.directionB}
}
func (bc MockPerfectBiStreamConn) SideB() MockPerfectStreamConn {
return MockPerfectStreamConn{inbound: bc.directionB, outbound: bc.directionA}
}
type MockPerfectStreamConn struct {
inbound chan byte
outbound chan byte
}
type Conn interface {
Read(b []byte) (n int, err error)
Write(b []byte) (n int, err error)
SetWriteDeadline(time.Time) error
// For printing
LocalAddr() net.Addr
RemoteAddr() net.Addr
}
func (c MockPerfectStreamConn) Read(p []byte) (n int, err error) {
for i := range p {
if i == 0 {
p[i] = <-c.inbound
} else {
select {
case b := <-c.inbound:
p[i] = b
default:
return i, nil
}
}
}
return len(p), nil
}
func (c MockPerfectStreamConn) Write(p []byte) (n int, err error) {
for _, b := range p {
c.outbound <- b
}
return len(p), nil
}
func (c MockPerfectStreamConn) SetWriteDeadline(time.Time) error {
return nil
}
// Only used for printing flow information
func (c MockPerfectStreamConn) LocalAddr() net.Addr {
return &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 499,
}
}
func (c MockPerfectStreamConn) RemoteAddr() net.Addr {
return &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 500,
}
}
func (c MockPerfectStreamConn) NonBlockingRead(p []byte) (n int, err error) {
for i := range p {
select {
case b := <-c.inbound:
p[i] = b
default:
return i, nil
}
}
return len(p), nil
}

View File

@ -4,31 +4,90 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"mpbl3p/mocks"
"mpbl3p/shared"
"testing"
)
func TestPacket_Marshal(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := NewPacket(testContent)
testMac := mocks.AlmostUselessMac{}
testPacket := NewSimplePacket(testContent)
t.Run("Length", func(t *testing.T) {
marshalled := testPacket.Marshal(testMac)
marshalled := testPacket.Marshal()
assert.Len(t, marshalled, len(testContent)+8+4)
assert.Len(t, marshalled, len(testContent)+8)
})
}
func TestUnmarshalPacket(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := NewPacket(testContent)
testMac := mocks.AlmostUselessMac{}
testMarshalled := testPacket.Marshal(testMac)
testPacket := NewSimplePacket(testContent)
testMarshalled := testPacket.Marshal()
t.Run("Length", func(t *testing.T) {
p, err := UnmarshalSimplePacket(testMarshalled, testMac)
p, err := UnmarshalSimplePacket(testMarshalled)
require.Nil(t, err)
assert.Len(t, p.Marshal(), len(testContent))
assert.Len(t, p.Contents(), len(testContent))
})
t.Run("Contents", func(t *testing.T) {
p, err := UnmarshalSimplePacket(testMarshalled)
require.Nil(t, err)
assert.Equal(t, p.Contents(), testContent)
})
}
func TestAppendMac(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testMac := mocks.AlmostUselessMac{}
testPacket := NewSimplePacket(testContent)
testMarshalled := testPacket.Marshal()
appended := AppendMac(testMarshalled, testMac)
t.Run("Length", func(t *testing.T) {
assert.Len(t, appended, len(testMarshalled)+4)
})
t.Run("Mac", func(t *testing.T) {
assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled):])
})
t.Run("Original", func(t *testing.T) {
assert.Equal(t, testMarshalled, appended[:len(testMarshalled)])
})
}
func TestStripMac(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testMac := mocks.AlmostUselessMac{}
testPacket := NewSimplePacket(testContent)
testMarshalled := testPacket.Marshal()
appended := AppendMac(testMarshalled, testMac)
t.Run("Length", func(t *testing.T) {
cut, err := StripMac(appended, testMac)
require.Nil(t, err)
assert.Len(t, cut, len(testMarshalled))
})
t.Run("IncorrectMac", func(t *testing.T) {
badMac := make([]byte, len(testMarshalled)+4)
copy(badMac, testMarshalled)
copy(badMac[:len(testMarshalled)], "dcba")
_, err := StripMac(badMac, testMac)
assert.Error(t, err, shared.ErrBadChecksum)
})
t.Run("Original", func(t *testing.T) {
cut, err := StripMac(appended, testMac)
require.Nil(t, err)
assert.Equal(t, testMarshalled, cut)
})
}

View File

@ -2,7 +2,7 @@ package tcp
import (
"encoding/binary"
"github.com/go-playground/assert/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"mpbl3p/mocks"
"mpbl3p/proxy"
@ -11,11 +11,11 @@ import (
func TestFlow_Consume(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := proxy.NewPacket(testContent)
testPacket := proxy.NewSimplePacket(testContent)
testMac := mocks.AlmostUselessMac{}
t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiConn(100)
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := Flow{conn: testConn.SideA(), isAlive: true}
@ -39,7 +39,7 @@ func TestFlow_Produce(t *testing.T) {
testMac := mocks.AlmostUselessMac{}
t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiConn(100)
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := Flow{conn: testConn.SideA(), isAlive: true}
@ -48,11 +48,11 @@ func TestFlow_Produce(t *testing.T) {
p, err := flowA.Produce(testMac)
require.Nil(t, err)
assert.Equal(t, len(testContent), len(p.Marshal()))
assert.Equal(t, len(testContent), len(p.Contents()))
})
t.Run("Value", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiConn(100)
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := Flow{conn: testConn.SideA(), isAlive: true}
@ -61,6 +61,6 @@ func TestFlow_Produce(t *testing.T) {
p, err := flowA.Produce(testMac)
require.Nil(t, err)
assert.Equal(t, testContent, string(p.Marshal()))
assert.Equal(t, testContent, string(p.Contents()))
})
}

View File

@ -13,6 +13,5 @@ type Congestion interface {
NextNack() uint32
AwaitEarlyUpdate(keepalive time.Duration) uint32
AwaitAck(timeout time.Duration) bool
Reset()
}

View File

@ -16,8 +16,9 @@ type NewReno struct {
sequence chan uint32
keepalive chan bool
outboundTimes map[uint32]time.Time
inboundTimes map[uint32]time.Time
outboundTimes, inboundTimes map[uint32]time.Time
outboundTimesLock sync.Mutex
inboundTimesLock sync.RWMutex
ack, lastAck uint32
nack, lastNack uint32
@ -34,8 +35,7 @@ type NewReno struct {
hasAcked bool
acksToSend utils.Uint32Heap
mu sync.Mutex
acksToSendLock sync.Mutex
}
func (c *NewReno) String() string {
@ -82,12 +82,16 @@ func (c *NewReno) Reset() {
// It is assumed that ReceivedAck will only be called by one thread
func (c *NewReno) ReceivedAck(ack uint32) {
c.outboundTimesLock.Lock()
defer c.outboundTimesLock.Unlock()
log.Printf("ack received for %d", ack)
c.hasAcked = true
// RTT
// Update using an exponential average
rtt := time.Now().Sub(c.outboundTimes[ack]).Seconds()
delete(c.outboundTimes, ack)
c.rtt = c.rtt*(1-RttExponentialFactor) + rtt*RttExponentialFactor
@ -128,22 +132,26 @@ func (c *NewReno) ReceivedPacket(seq uint32) {
c.inboundTimes[seq] = time.Now()
c.mu.Lock()
c.acksToSendLock.Lock()
c.acksToSend.Insert(seq)
c.mu.Unlock()
c.acksToSendLock.Unlock()
c.updateAckNack()
}
func (c *NewReno) updateAckNack() {
c.mu.Lock()
defer c.mu.Unlock()
c.acksToSendLock.Lock()
defer c.acksToSendLock.Unlock()
c.inboundTimesLock.Lock()
defer c.inboundTimesLock.Unlock()
findAck := func(start uint32) uint32 {
ack := start
for len(c.acksToSend) > 0 {
if a, _ := c.acksToSend.Peek(); a == ack+1 {
ack, _ = c.acksToSend.Extract()
delete(c.inboundTimes, ack)
} else {
break
}
@ -169,6 +177,9 @@ func (c *NewReno) updateAckNack() {
}
func (c *NewReno) Sequence() uint32 {
c.outboundTimesLock.Lock()
defer c.outboundTimesLock.Unlock()
for c.inFlight >= c.windowSize {
<-c.ackNotifier
}
@ -213,11 +224,3 @@ func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 {
}
}
}
func (c *NewReno) AwaitAck(timeout time.Duration) bool {
if c.hasAcked {
return true
}
time.Sleep(timeout)
return c.hasAcked
}

View File

@ -38,7 +38,6 @@ func (c *None) ReceivedPacket(uint32) {}
func (c *None) ReceivedAck(uint32) {}
func (c *None) ReceivedNack(uint32) {}
func (c *None) Reset() {}
func (c *None) AwaitAck(time.Duration) bool { return true }
func (c *None) NextNack() uint32 { return 0 }
func (c *None) NextAck() uint32 { return 0 }
func (c *None) AwaitEarlyUpdate(time.Duration) uint32 { select {} }

View File

@ -82,6 +82,7 @@ func newFlow(c Congestion, v proxy.MacVerifier) Flow {
func (f *InitiatedFlow) Reconnect() error {
f.mu.Lock()
defer f.mu.Unlock()
if f.isAlive {
return nil
@ -109,7 +110,6 @@ func (f *InitiatedFlow) Reconnect() error {
go func() {
seq := f.congestion.Sequence()
defer f.mu.Unlock()
for !f.isAlive {
p := Packet{
ack: 0,
@ -119,31 +119,6 @@ func (f *InitiatedFlow) Reconnect() error {
}
_ = f.sendPacket(p, f.g)
if f.congestion.AwaitAck(1 * time.Second) {
f.isAlive = true
f.startup = false
}
}
}()
go func() {
for f.startup || f.isAlive {
func() {
if f.isAlive {
f.mu.RLock()
defer f.mu.RUnlock()
}
buf := make([]byte, 6000)
n, _, err := conn.ReadFromUDP(buf)
if err != nil {
log.Println(err)
time.Sleep(1)
} else {
f.handleDatagram(buf[:n])
}
}()
}
}()
@ -152,6 +127,30 @@ func (f *InitiatedFlow) Reconnect() error {
}()
go f.earlyUpdateLoop(f.g, f.keepalive)
if err := f.acceptPacket(conn); err != nil {
return err
}
f.isAlive = true
f.startup = false
go func() {
lockedAccept := func() {
f.mu.RLock()
defer f.mu.RUnlock()
if err := f.acceptPacket(conn); err != nil {
log.Println(err)
}
}
for f.isAlive {
log.Println("alive and listening for packets")
lockedAccept()
}
log.Println("no longer alive")
}()
return nil
}
@ -273,3 +272,14 @@ func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) {
_ = f.sendPacket(p, g)
}
}
func (f *Flow) acceptPacket(c PacketConn) error {
buf := make([]byte, 6000)
n, _, err := c.ReadFromUDP(buf)
if err != nil {
return err
}
f.handleDatagram(buf[:n])
return nil
}

85
udp/flow_test.go Normal file
View File

@ -0,0 +1,85 @@
package udp
import (
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"mpbl3p/mocks"
"mpbl3p/proxy"
"mpbl3p/udp/congestion"
"testing"
"time"
)
func TestFlow_Consume(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := proxy.NewSimplePacket(testContent)
testMac := mocks.AlmostUselessMac{}
t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiPacketConn(10)
flowA := newFlow(congestion.NewNone(), testMac)
flowA.writer = testConn.SideB()
flowA.isAlive = true
err := flowA.Consume(testPacket, testMac)
require.Nil(t, err)
buf := make([]byte, 100)
n, _, err := testConn.SideA().ReadFromUDP(buf)
require.Nil(t, err)
// 12 header, 8 timestamp, 4 MAC
assert.Equal(t, len(testContent)+12+8+4, n)
})
}
func TestFlow_Produce(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := Packet{
ack: 42,
nack: 26,
seq: 128,
data: proxy.NewSimplePacket(testContent),
}
testMac := mocks.AlmostUselessMac{}
testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac)
t.Run("Length", func(t *testing.T) {
done := make(chan struct{})
go func() {
testConn := mocks.NewMockPerfectBiPacketConn(10)
_, err := testConn.SideA().Write(testMarshalled)
require.Nil(t, err)
flowA := newFlow(congestion.NewNone(), testMac)
flowA.writer = testConn.SideB()
flowA.isAlive = true
go func() {
err := flowA.acceptPacket(testConn.SideB())
assert.Nil(t, err)
}()
p, err := flowA.Produce(testMac)
require.Nil(t, err)
assert.Len(t, p.Contents(), len(testContent))
done <- struct{}{}
}()
timer := time.NewTimer(500 * time.Millisecond)
select {
case <-done:
case <-timer.C:
fmt.Println("timed out")
t.FailNow()
}
})
}

View File

@ -25,7 +25,7 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
}
}
func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacGenerator, c Congestion) error {
func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacGenerator, c func() Congestion) error {
laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return err
@ -63,7 +63,7 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacG
continue
}
f := newFlow(c, v)
f := newFlow(c(), v)
f.writer = pconn
f.raddr = addr

59
udp/packet_test.go Normal file
View File

@ -0,0 +1,59 @@
package udp
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"mpbl3p/proxy"
"testing"
)
func TestPacket_Marshal(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := Packet{
ack: 18,
nack: 29,
seq: 431,
data: proxy.NewSimplePacket(testContent),
}
t.Run("Length", func(t *testing.T) {
marshalled := testPacket.Marshal()
// 12 header + 8 timestamp
assert.Len(t, marshalled, len(testContent)+12+8)
})
}
func TestUnmarshalPacket(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := Packet{
ack: 18,
nack: 29,
seq: 431,
data: proxy.NewSimplePacket(testContent),
}
testMarshalled := testPacket.Marshal()
t.Run("Length", func(t *testing.T) {
p, err := UnmarshalPacket(testMarshalled)
require.Nil(t, err)
assert.Len(t, p.Contents(), len(testContent))
})
t.Run("Contents", func(t *testing.T) {
p, err := UnmarshalPacket(testMarshalled)
require.Nil(t, err)
assert.Equal(t, p.Contents(), testContent)
})
t.Run("Header", func(t *testing.T) {
p, err := UnmarshalPacket(testMarshalled)
require.Nil(t, err)
assert.Equal(t, p.ack, uint32(18))
assert.Equal(t, p.nack, uint32(29))
assert.Equal(t, p.seq, uint32(431))
})
}