merge develop into master #21
4
Makefile
4
Makefile
@ -1,7 +1,7 @@
|
|||||||
manual:
|
manual:
|
||||||
docker run --rm -v /tmp:/tmp -v ${PWD}:/app -w /app golang:1.15-buster go build -o /tmp/mpbl3p
|
docker run --rm -v /tmp:/tmp -v ${PWD}:/app -w /app golang:1.15-buster go build -o /tmp/mpbl3p
|
||||||
rsync -p /tmp/mpbl3p 10.21.10.3:
|
rsync -p /tmp/mpbl3p 10.21.12.101:
|
||||||
rsync -p /tmp/mpbl3p 10.21.10.4:
|
rsync -p /tmp/mpbl3p 10.21.12.102:
|
||||||
|
|
||||||
manual-bsd:
|
manual-bsd:
|
||||||
GOOS=freebsd go build -o /tmp/mpbl3p
|
GOOS=freebsd go build -o /tmp/mpbl3p
|
||||||
|
@ -30,6 +30,7 @@ component parts, or incorporated into the main application.
|
|||||||
|
|
||||||
# IPv4 Forwarding
|
# IPv4 Forwarding
|
||||||
sysctl -w net.ipv4.ip_forward=1
|
sysctl -w net.ipv4.ip_forward=1
|
||||||
|
sysctl -w net.ipv4.conf.eth0.proxy_arp=1
|
||||||
|
|
||||||
# Tunnel addr/up
|
# Tunnel addr/up
|
||||||
ip addr add 172.19.152.2/31 dev nc0
|
ip addr add 172.19.152.2/31 dev nc0
|
||||||
@ -84,4 +85,4 @@ component parts, or incorporated into the main application.
|
|||||||
|
|
||||||
#### Client
|
#### Client
|
||||||
|
|
||||||
No configuration needed. Simply set the IP to that of the remote server/32 with no gateway.
|
No configuration needed. Simply set the IP to that of the remote server/32 with a gateway of 192.168.1.1.
|
||||||
|
@ -5,6 +5,9 @@ import (
|
|||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
"mpbl3p/tcp"
|
"mpbl3p/tcp"
|
||||||
"mpbl3p/tun"
|
"mpbl3p/tun"
|
||||||
|
"mpbl3p/udp"
|
||||||
|
"mpbl3p/udp/congestion"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: Delete this code as soon as an alternative is available
|
// TODO: Delete this code as soon as an alternative is available
|
||||||
@ -45,6 +48,11 @@ func (c Configuration) Build() (*proxy.Proxy, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
case "UDP":
|
||||||
|
err := buildUdp(p, peer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,10 +66,14 @@ func buildTcp(p *proxy.Proxy, peer Peer) error {
|
|||||||
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
|
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
p.AddConsumer(f)
|
p.AddConsumer(f)
|
||||||
p.AddProducer(f, UselessMac{})
|
p.AddProducer(f, UselessMac{})
|
||||||
|
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := tcp.NewListener(p, fmt.Sprintf("%s:%d", peer.LocalHost, peer.LocalPort), UselessMac{})
|
err := tcp.NewListener(p, fmt.Sprintf("%s:%d", peer.LocalHost, peer.LocalPort), UselessMac{})
|
||||||
@ -71,3 +83,48 @@ func buildTcp(p *proxy.Proxy, peer Peer) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildUdp(p *proxy.Proxy, peer Peer) error {
|
||||||
|
var c func() udp.Congestion
|
||||||
|
switch peer.Congestion {
|
||||||
|
case "None":
|
||||||
|
c = func() udp.Congestion {return congestion.NewNone()}
|
||||||
|
default:
|
||||||
|
fallthrough
|
||||||
|
case "NewReno":
|
||||||
|
c = func() udp.Congestion {return congestion.NewNewReno()}
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.RemoteHost != "" {
|
||||||
|
f, err := udp.InitiateFlow(
|
||||||
|
fmt.Sprintf("%s:", peer.LocalHost),
|
||||||
|
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
|
||||||
|
UselessMac{},
|
||||||
|
UselessMac{},
|
||||||
|
c(),
|
||||||
|
time.Duration(peer.KeepAlive)*time.Second,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.AddConsumer(f)
|
||||||
|
p.AddProducer(f, UselessMac{})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := udp.NewListener(
|
||||||
|
p,
|
||||||
|
fmt.Sprintf("%s:%d", peer.LocalHost, peer.LocalPort),
|
||||||
|
UselessMac{},
|
||||||
|
UselessMac{},
|
||||||
|
c,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -16,7 +16,7 @@ type Host struct {
|
|||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
PublicKey string `validate:"required"`
|
PublicKey string `validate:"required"`
|
||||||
Method string `validate:"oneof=TCP"`
|
Method string `validate:"oneof=TCP UDP"`
|
||||||
|
|
||||||
LocalHost string `validate:"omitempty,ip"`
|
LocalHost string `validate:"omitempty,ip"`
|
||||||
LocalPort uint `validate:"max=65535"`
|
LocalPort uint `validate:"max=65535"`
|
||||||
@ -24,6 +24,8 @@ type Peer struct {
|
|||||||
RemoteHost string `validate:"required_with=RemotePort,omitempty,fqdn|ip"`
|
RemoteHost string `validate:"required_with=RemotePort,omitempty,fqdn|ip"`
|
||||||
RemotePort uint `validate:"required_with=RemoteHost,omitempty,max=65535"`
|
RemotePort uint `validate:"required_with=RemoteHost,omitempty,max=65535"`
|
||||||
|
|
||||||
|
Congestion string `validate:"oneof=NewReno None"`
|
||||||
|
|
||||||
KeepAlive uint
|
KeepAlive uint
|
||||||
Timeout uint
|
Timeout uint
|
||||||
RetryWait uint
|
RetryWait uint
|
||||||
|
1
go.mod
1
go.mod
@ -3,7 +3,6 @@ module mpbl3p
|
|||||||
go 1.15
|
go 1.15
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-playground/assert/v2 v2.0.1
|
|
||||||
github.com/go-playground/validator/v10 v10.4.1
|
github.com/go-playground/validator/v10 v10.4.1
|
||||||
github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab
|
github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab
|
||||||
github.com/smartystreets/goconvey v1.6.4 // indirect
|
github.com/smartystreets/goconvey v1.6.4 // indirect
|
||||||
|
9
main.go
9
main.go
@ -13,7 +13,14 @@ func main() {
|
|||||||
|
|
||||||
log.Println("loading config...")
|
log.Println("loading config...")
|
||||||
|
|
||||||
c, err := config.LoadConfig("config.ini")
|
var configLoc string
|
||||||
|
if v, ok := os.LookupEnv("CONFIG_LOC"); ok {
|
||||||
|
configLoc = v
|
||||||
|
} else {
|
||||||
|
configLoc = "config.ini"
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := config.LoadConfig(configLoc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -1,67 +0,0 @@
|
|||||||
package mocks
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
type MockPerfectBiConn struct {
|
|
||||||
directionA chan byte
|
|
||||||
directionB chan byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMockPerfectBiConn(bufSize int) MockPerfectBiConn {
|
|
||||||
return MockPerfectBiConn{
|
|
||||||
directionA: make(chan byte, bufSize),
|
|
||||||
directionB: make(chan byte, bufSize),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bc MockPerfectBiConn) SideA() MockPerfectConn {
|
|
||||||
return MockPerfectConn{inbound: bc.directionA, outbound: bc.directionB}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bc MockPerfectBiConn) SideB() MockPerfectConn {
|
|
||||||
return MockPerfectConn{inbound: bc.directionB, outbound: bc.directionA}
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockPerfectConn struct {
|
|
||||||
inbound chan byte
|
|
||||||
outbound chan byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c MockPerfectConn) SetWriteDeadline(time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c MockPerfectConn) Read(p []byte) (n int, err error) {
|
|
||||||
for i := range p {
|
|
||||||
if i == 0 {
|
|
||||||
p[i] = <-c.inbound
|
|
||||||
} else {
|
|
||||||
select {
|
|
||||||
case b := <-c.inbound:
|
|
||||||
p[i] = b
|
|
||||||
default:
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c MockPerfectConn) Write(p []byte) (n int, err error) {
|
|
||||||
for _, b := range p {
|
|
||||||
c.outbound <- b
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c MockPerfectConn) NonBlockingRead(p []byte) (n int, err error) {
|
|
||||||
for i := range p {
|
|
||||||
select {
|
|
||||||
case b := <-c.inbound:
|
|
||||||
p[i] = b
|
|
||||||
default:
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
@ -1,6 +1,8 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import "mpbl3p/shared"
|
import (
|
||||||
|
"mpbl3p/shared"
|
||||||
|
)
|
||||||
|
|
||||||
type AlmostUselessMac struct{}
|
type AlmostUselessMac struct{}
|
||||||
|
|
||||||
|
62
mocks/packetconn.go
Normal file
62
mocks/packetconn.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package mocks
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
type MockPerfectBiPacketConn struct {
|
||||||
|
directionA chan []byte
|
||||||
|
directionB chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockPerfectBiPacketConn(bufSize int) MockPerfectBiPacketConn {
|
||||||
|
return MockPerfectBiPacketConn{
|
||||||
|
directionA: make(chan []byte, bufSize),
|
||||||
|
directionB: make(chan []byte, bufSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc MockPerfectBiPacketConn) SideA() MockPerfectPacketConn {
|
||||||
|
return MockPerfectPacketConn{inbound: bc.directionA, outbound: bc.directionB}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc MockPerfectBiPacketConn) SideB() MockPerfectPacketConn {
|
||||||
|
return MockPerfectPacketConn{inbound: bc.directionB, outbound: bc.directionA}
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockPerfectPacketConn struct {
|
||||||
|
inbound chan []byte
|
||||||
|
outbound chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) Write(b []byte) (int, error) {
|
||||||
|
c.outbound <- b
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||||
|
c.outbound <- b
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) LocalAddr() net.Addr {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
|
||||||
|
p := <-c.inbound
|
||||||
|
return copy(b, p), &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) NonBlockingRead(p []byte) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case b := <-c.inbound:
|
||||||
|
return copy(p, b), nil
|
||||||
|
default:
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
}
|
95
mocks/streamconn.go
Normal file
95
mocks/streamconn.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockPerfectBiStreamConn struct {
|
||||||
|
directionA chan byte
|
||||||
|
directionB chan byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockPerfectBiStreamConn(bufSize int) MockPerfectBiStreamConn {
|
||||||
|
return MockPerfectBiStreamConn{
|
||||||
|
directionA: make(chan byte, bufSize),
|
||||||
|
directionB: make(chan byte, bufSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc MockPerfectBiStreamConn) SideA() MockPerfectStreamConn {
|
||||||
|
return MockPerfectStreamConn{inbound: bc.directionA, outbound: bc.directionB}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc MockPerfectBiStreamConn) SideB() MockPerfectStreamConn {
|
||||||
|
return MockPerfectStreamConn{inbound: bc.directionB, outbound: bc.directionA}
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockPerfectStreamConn struct {
|
||||||
|
inbound chan byte
|
||||||
|
outbound chan byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn interface {
|
||||||
|
Read(b []byte) (n int, err error)
|
||||||
|
Write(b []byte) (n int, err error)
|
||||||
|
SetWriteDeadline(time.Time) error
|
||||||
|
|
||||||
|
// For printing
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
RemoteAddr() net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectStreamConn) Read(p []byte) (n int, err error) {
|
||||||
|
for i := range p {
|
||||||
|
if i == 0 {
|
||||||
|
p[i] = <-c.inbound
|
||||||
|
} else {
|
||||||
|
select {
|
||||||
|
case b := <-c.inbound:
|
||||||
|
p[i] = b
|
||||||
|
default:
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectStreamConn) Write(p []byte) (n int, err error) {
|
||||||
|
for _, b := range p {
|
||||||
|
c.outbound <- b
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectStreamConn) SetWriteDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only used for printing flow information
|
||||||
|
func (c MockPerfectStreamConn) LocalAddr() net.Addr {
|
||||||
|
return &net.TCPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 499,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectStreamConn) RemoteAddr() net.Addr {
|
||||||
|
return &net.TCPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 500,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectStreamConn) NonBlockingRead(p []byte) (n int, err error) {
|
||||||
|
for i := range p {
|
||||||
|
select {
|
||||||
|
case b := <-c.inbound:
|
||||||
|
p[i] = b
|
||||||
|
default:
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
@ -5,30 +5,27 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Packet struct {
|
type Packet interface {
|
||||||
|
Marshal() []byte
|
||||||
|
Contents() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type SimplePacket struct {
|
||||||
Data []byte
|
Data []byte
|
||||||
timestamp time.Time
|
timestamp time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a packet from the raw data of an IP packet
|
// create a packet from the raw data of an IP packet
|
||||||
func NewPacket(data []byte) Packet {
|
func NewSimplePacket(data []byte) Packet {
|
||||||
return Packet{
|
return SimplePacket{
|
||||||
Data: data,
|
Data: data,
|
||||||
timestamp: time.Now(),
|
timestamp: time.Now(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// rebuild a packet from the wrapped format
|
// rebuild a packet from the wrapped format
|
||||||
func UnmarshalPacket(raw []byte, verifier MacVerifier) (Packet, error) {
|
func UnmarshalSimplePacket(data []byte) (SimplePacket, error) {
|
||||||
// the MAC is the last N bytes
|
p := SimplePacket{
|
||||||
data := raw[:len(raw)-verifier.CodeLength()]
|
|
||||||
sum := raw[len(raw)-verifier.CodeLength():]
|
|
||||||
|
|
||||||
if err := verifier.Verify(data, sum); err != nil {
|
|
||||||
return Packet{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
p := Packet{
|
|
||||||
Data: data[:len(data)-8],
|
Data: data[:len(data)-8],
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,22 +36,32 @@ func UnmarshalPacket(raw []byte, verifier MacVerifier) (Packet, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get the raw data of the IP packet
|
// get the raw data of the IP packet
|
||||||
func (p Packet) Raw() []byte {
|
func (p SimplePacket) Marshal() []byte {
|
||||||
|
footer := make([]byte, 8)
|
||||||
|
|
||||||
|
unixTime := uint64(p.timestamp.Unix())
|
||||||
|
binary.LittleEndian.PutUint64(footer, unixTime)
|
||||||
|
|
||||||
|
return append(p.Data, footer...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p SimplePacket) Contents() []byte {
|
||||||
return p.Data
|
return p.Data
|
||||||
}
|
}
|
||||||
|
|
||||||
// produce the wrapped format of a packet
|
func AppendMac(b []byte, g MacGenerator) []byte {
|
||||||
func (p Packet) Marshal(generator MacGenerator) []byte {
|
mac := g.Generate(b)
|
||||||
// length of data + length of timestamp (8 byte) + length of checksum
|
b = append(b, mac...)
|
||||||
slice := make([]byte, len(p.Data)+8+generator.CodeLength())
|
return b
|
||||||
|
}
|
||||||
copy(slice, p.Data)
|
|
||||||
|
func StripMac(b []byte, v MacVerifier) ([]byte, error) {
|
||||||
unixTime := uint64(p.timestamp.Unix())
|
data := b[:len(b)-v.CodeLength()]
|
||||||
binary.LittleEndian.PutUint64(slice[len(p.Data):], unixTime)
|
sum := b[len(b)-v.CodeLength():]
|
||||||
|
|
||||||
mac := generator.Generate(slice)
|
if err := v.Verify(data, sum); err != nil {
|
||||||
copy(slice[len(p.Data)+8:], mac)
|
return nil, err
|
||||||
|
}
|
||||||
return slice
|
|
||||||
|
return data, nil
|
||||||
}
|
}
|
||||||
|
@ -4,31 +4,90 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"mpbl3p/mocks"
|
"mpbl3p/mocks"
|
||||||
|
"mpbl3p/shared"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPacket_Marshal(t *testing.T) {
|
func TestPacket_Marshal(t *testing.T) {
|
||||||
testContent := []byte("A test string is the content of this packet.")
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
testPacket := NewPacket(testContent)
|
testPacket := NewSimplePacket(testContent)
|
||||||
testMac := mocks.AlmostUselessMac{}
|
|
||||||
|
|
||||||
t.Run("Length", func(t *testing.T) {
|
t.Run("Length", func(t *testing.T) {
|
||||||
marshalled := testPacket.Marshal(testMac)
|
marshalled := testPacket.Marshal()
|
||||||
|
|
||||||
assert.Len(t, marshalled, len(testContent)+8+4)
|
assert.Len(t, marshalled, len(testContent)+8)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalPacket(t *testing.T) {
|
func TestUnmarshalPacket(t *testing.T) {
|
||||||
testContent := []byte("A test string is the content of this packet.")
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
testPacket := NewPacket(testContent)
|
testPacket := NewSimplePacket(testContent)
|
||||||
testMac := mocks.AlmostUselessMac{}
|
testMarshalled := testPacket.Marshal()
|
||||||
testMarshalled := testPacket.Marshal(testMac)
|
|
||||||
|
|
||||||
t.Run("Length", func(t *testing.T) {
|
t.Run("Length", func(t *testing.T) {
|
||||||
p, err := UnmarshalPacket(testMarshalled, testMac)
|
p, err := UnmarshalSimplePacket(testMarshalled)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Len(t, p.Raw(), len(testContent))
|
assert.Len(t, p.Contents(), len(testContent))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Contents", func(t *testing.T) {
|
||||||
|
p, err := UnmarshalSimplePacket(testMarshalled)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Equal(t, p.Contents(), testContent)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendMac(t *testing.T) {
|
||||||
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
testPacket := NewSimplePacket(testContent)
|
||||||
|
testMarshalled := testPacket.Marshal()
|
||||||
|
|
||||||
|
appended := AppendMac(testMarshalled, testMac)
|
||||||
|
|
||||||
|
t.Run("Length", func(t *testing.T) {
|
||||||
|
assert.Len(t, appended, len(testMarshalled)+4)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Mac", func(t *testing.T) {
|
||||||
|
assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled):])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Original", func(t *testing.T) {
|
||||||
|
assert.Equal(t, testMarshalled, appended[:len(testMarshalled)])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripMac(t *testing.T) {
|
||||||
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
testPacket := NewSimplePacket(testContent)
|
||||||
|
testMarshalled := testPacket.Marshal()
|
||||||
|
|
||||||
|
appended := AppendMac(testMarshalled, testMac)
|
||||||
|
|
||||||
|
t.Run("Length", func(t *testing.T) {
|
||||||
|
cut, err := StripMac(appended, testMac)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Len(t, cut, len(testMarshalled))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IncorrectMac", func(t *testing.T) {
|
||||||
|
badMac := make([]byte, len(testMarshalled)+4)
|
||||||
|
copy(badMac, testMarshalled)
|
||||||
|
copy(badMac[:len(testMarshalled)], "dcba")
|
||||||
|
_, err := StripMac(badMac, testMac)
|
||||||
|
|
||||||
|
assert.Error(t, err, shared.ErrBadChecksum)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Original", func(t *testing.T) {
|
||||||
|
cut, err := StripMac(appended, testMac)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Equal(t, testMarshalled, cut)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -75,13 +75,13 @@ func (p Proxy) AddConsumer(c Consumer) {
|
|||||||
if reconnectable {
|
if reconnectable {
|
||||||
var err error
|
var err error
|
||||||
for once := true; err != nil || once; once = false {
|
for once := true; err != nil || once; once = false {
|
||||||
log.Printf("attempting to connect `%v`\n", c)
|
log.Printf("attempting to connect consumer `%v`\n", c)
|
||||||
err = c.(Reconnectable).Reconnect()
|
err = c.(Reconnectable).Reconnect()
|
||||||
if !once {
|
if !once {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("connected `%v`\n", c)
|
log.Printf("connected consumer `%v`\n", c)
|
||||||
}
|
}
|
||||||
|
|
||||||
for c.IsAlive() {
|
for c.IsAlive() {
|
||||||
@ -92,7 +92,7 @@ func (p Proxy) AddConsumer(c Consumer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("closed connection `%v`\n", c)
|
log.Printf("closed consumer `%v`\n", c)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,13 +104,13 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
|
|||||||
if reconnectable {
|
if reconnectable {
|
||||||
var err error
|
var err error
|
||||||
for once := true; err != nil || once; once = false {
|
for once := true; err != nil || once; once = false {
|
||||||
log.Printf("attempting to connect `%v`\n", pr)
|
log.Printf("attempting to connect producer `%v`\n", pr)
|
||||||
err = pr.(Reconnectable).Reconnect()
|
err = pr.(Reconnectable).Reconnect()
|
||||||
if !once {
|
if !once {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("connected `%v`\n", pr)
|
log.Printf("connected producer `%v`\n", pr)
|
||||||
}
|
}
|
||||||
|
|
||||||
for pr.IsAlive() {
|
for pr.IsAlive() {
|
||||||
@ -123,6 +123,6 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("closed connection `%v`\n", pr)
|
log.Printf("closed producer `%v`\n", pr)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
50
tcp/flow.go
50
tcp/flow.go
@ -3,6 +3,7 @@ package tcp
|
|||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
"mpbl3p/shared"
|
"mpbl3p/shared"
|
||||||
@ -17,6 +18,10 @@ type Conn interface {
|
|||||||
Read(b []byte) (n int, err error)
|
Read(b []byte) (n int, err error)
|
||||||
Write(b []byte) (n int, err error)
|
Write(b []byte) (n int, err error)
|
||||||
SetWriteDeadline(time.Time) error
|
SetWriteDeadline(time.Time) error
|
||||||
|
|
||||||
|
// For printing
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
RemoteAddr() net.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
type InitiatedFlow struct {
|
type InitiatedFlow struct {
|
||||||
@ -28,11 +33,19 @@ type InitiatedFlow struct {
|
|||||||
Flow
|
Flow
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) String() string {
|
||||||
|
return fmt.Sprintf("TcpOutbound{%v -> %v}", f.Local, f.Remote)
|
||||||
|
}
|
||||||
|
|
||||||
type Flow struct {
|
type Flow struct {
|
||||||
conn Conn
|
conn Conn
|
||||||
isAlive bool
|
isAlive bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f Flow) String() string {
|
||||||
|
return fmt.Sprintf("TcpInbound{%v -> %v}", f.conn.RemoteAddr(), f.conn.LocalAddr())
|
||||||
|
}
|
||||||
|
|
||||||
func InitiateFlow(local, remote string) (*InitiatedFlow, error) {
|
func InitiateFlow(local, remote string) (*InitiatedFlow, error) {
|
||||||
f := InitiatedFlow{
|
f := InitiatedFlow{
|
||||||
Local: local,
|
Local: local,
|
||||||
@ -75,10 +88,6 @@ func (f *InitiatedFlow) Reconnect() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) IsAlive() bool {
|
|
||||||
return f.isAlive
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
||||||
f.mu.RLock()
|
f.mu.RLock()
|
||||||
defer f.mu.RUnlock()
|
defer f.mu.RUnlock()
|
||||||
@ -86,12 +95,25 @@ func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
|||||||
return f.Flow.Consume(p, g)
|
return f.Flow.Consume(p, g)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
return f.Flow.Produce(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) IsAlive() bool {
|
||||||
|
return f.isAlive
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) (err error) {
|
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) (err error) {
|
||||||
if !f.isAlive {
|
if !f.isAlive {
|
||||||
return shared.ErrDeadConnection
|
return shared.ErrDeadConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
data := p.Marshal(g)
|
marshalled := p.Marshal()
|
||||||
|
data := proxy.AppendMac(marshalled, g)
|
||||||
|
|
||||||
err = f.consumeMarshalled(data)
|
err = f.consumeMarshalled(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.isAlive = false
|
f.isAlive = false
|
||||||
@ -112,25 +134,23 @@ func (f *Flow) consumeMarshalled(data []byte) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
|
||||||
f.mu.RLock()
|
|
||||||
defer f.mu.RUnlock()
|
|
||||||
|
|
||||||
return f.Flow.Produce(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
if !f.isAlive {
|
if !f.isAlive {
|
||||||
return proxy.Packet{}, shared.ErrDeadConnection
|
return nil, shared.ErrDeadConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := f.produceMarshalled()
|
data, err := f.produceMarshalled()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.isAlive = false
|
f.isAlive = false
|
||||||
return proxy.Packet{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return proxy.UnmarshalPacket(data, v)
|
b, err := proxy.StripMac(data, v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return proxy.UnmarshalSimplePacket(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) produceMarshalled() ([]byte, error) {
|
func (f *Flow) produceMarshalled() ([]byte, error) {
|
||||||
|
@ -2,7 +2,7 @@ package tcp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"github.com/go-playground/assert/v2"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"mpbl3p/mocks"
|
"mpbl3p/mocks"
|
||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
@ -11,11 +11,11 @@ import (
|
|||||||
|
|
||||||
func TestFlow_Consume(t *testing.T) {
|
func TestFlow_Consume(t *testing.T) {
|
||||||
testContent := []byte("A test string is the content of this packet.")
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
testPacket := proxy.NewPacket(testContent)
|
testPacket := proxy.NewSimplePacket(testContent)
|
||||||
testMac := mocks.AlmostUselessMac{}
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
|
||||||
t.Run("Length", func(t *testing.T) {
|
t.Run("Length", func(t *testing.T) {
|
||||||
testConn := mocks.NewMockPerfectBiConn(100)
|
testConn := mocks.NewMockPerfectBiStreamConn(100)
|
||||||
|
|
||||||
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ func TestFlow_Produce(t *testing.T) {
|
|||||||
testMac := mocks.AlmostUselessMac{}
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
|
||||||
t.Run("Length", func(t *testing.T) {
|
t.Run("Length", func(t *testing.T) {
|
||||||
testConn := mocks.NewMockPerfectBiConn(100)
|
testConn := mocks.NewMockPerfectBiStreamConn(100)
|
||||||
|
|
||||||
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
||||||
|
|
||||||
@ -48,11 +48,11 @@ func TestFlow_Produce(t *testing.T) {
|
|||||||
|
|
||||||
p, err := flowA.Produce(testMac)
|
p, err := flowA.Produce(testMac)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, len(testContent), len(p.Raw()))
|
assert.Equal(t, len(testContent), len(p.Contents()))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Value", func(t *testing.T) {
|
t.Run("Value", func(t *testing.T) {
|
||||||
testConn := mocks.NewMockPerfectBiConn(100)
|
testConn := mocks.NewMockPerfectBiStreamConn(100)
|
||||||
|
|
||||||
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
||||||
|
|
||||||
@ -61,6 +61,6 @@ func TestFlow_Produce(t *testing.T) {
|
|||||||
|
|
||||||
p, err := flowA.Produce(testMac)
|
p, err := flowA.Produce(testMac)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, testContent, string(p.Raw()))
|
assert.Equal(t, testContent, string(p.Contents()))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error {
|
|||||||
|
|
||||||
f := Flow{conn: conn, isAlive: true}
|
f := Flow{conn: conn, isAlive: true}
|
||||||
|
|
||||||
log.Printf("received new connection: %v\n", f)
|
log.Printf("received new tcp connection: %v\n", f)
|
||||||
|
|
||||||
p.AddConsumer(&f)
|
p.AddConsumer(&f)
|
||||||
p.AddProducer(&f, v)
|
p.AddProducer(&f, v)
|
||||||
|
@ -61,14 +61,14 @@ func (t *SourceSink) Source() (proxy.Packet, error) {
|
|||||||
|
|
||||||
read, err := t.tun.Read(buf)
|
read, err := t.tun.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return proxy.Packet{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if read == 0 {
|
if read == 0 {
|
||||||
return proxy.Packet{}, io.EOF
|
return nil, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
return proxy.NewPacket(buf[:read]), nil
|
return proxy.NewSimplePacket(buf[:read]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var good, bad float64
|
var good, bad float64
|
||||||
@ -79,7 +79,7 @@ func (t *SourceSink) Sink(packet proxy.Packet) error {
|
|||||||
t.upMu.Unlock()
|
t.upMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := t.tun.Write(packet.Raw())
|
_, err := t.tun.Write(packet.Contents())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case *os.PathError:
|
case *os.PathError:
|
||||||
|
17
udp/congestion.go
Normal file
17
udp/congestion.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type Congestion interface {
|
||||||
|
Sequence() uint32
|
||||||
|
ReceivedPacket(seq uint32)
|
||||||
|
|
||||||
|
ReceivedAck(uint32)
|
||||||
|
NextAck() uint32
|
||||||
|
|
||||||
|
ReceivedNack(uint32)
|
||||||
|
NextNack() uint32
|
||||||
|
|
||||||
|
AwaitEarlyUpdate(keepalive time.Duration) uint32
|
||||||
|
Reset()
|
||||||
|
}
|
226
udp/congestion/newreno.go
Normal file
226
udp/congestion/newreno.go
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"mpbl3p/utils"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const RttExponentialFactor = 0.1
|
||||||
|
|
||||||
|
type NewReno struct {
|
||||||
|
sequence chan uint32
|
||||||
|
keepalive chan bool
|
||||||
|
|
||||||
|
outboundTimes, inboundTimes map[uint32]time.Time
|
||||||
|
outboundTimesLock sync.Mutex
|
||||||
|
inboundTimesLock sync.RWMutex
|
||||||
|
|
||||||
|
ack, lastAck uint32
|
||||||
|
nack, lastNack uint32
|
||||||
|
|
||||||
|
slowStart bool
|
||||||
|
rtt float64
|
||||||
|
windowSize int32
|
||||||
|
windowCount int32
|
||||||
|
inFlight int32
|
||||||
|
|
||||||
|
ackNotifier chan struct{}
|
||||||
|
|
||||||
|
lastSent time.Time
|
||||||
|
hasAcked bool
|
||||||
|
|
||||||
|
acksToSend utils.Uint32Heap
|
||||||
|
acksToSendLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) String() string {
|
||||||
|
return fmt.Sprintf("{NewReno %t %d %d %d %d}", c.slowStart, c.windowSize, c.inFlight, c.lastAck, c.lastNack)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNewReno() *NewReno {
|
||||||
|
c := NewReno{
|
||||||
|
sequence: make(chan uint32),
|
||||||
|
ackNotifier: make(chan struct{}),
|
||||||
|
|
||||||
|
outboundTimes: make(map[uint32]time.Time),
|
||||||
|
inboundTimes: make(map[uint32]time.Time),
|
||||||
|
|
||||||
|
windowSize: 8,
|
||||||
|
rtt: (1 * time.Millisecond).Seconds(),
|
||||||
|
slowStart: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var s uint32
|
||||||
|
for {
|
||||||
|
if s == 0 {
|
||||||
|
s++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.sequence <- s
|
||||||
|
s++
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) Reset() {
|
||||||
|
c.outboundTimes = make(map[uint32]time.Time)
|
||||||
|
c.inboundTimes = make(map[uint32]time.Time)
|
||||||
|
c.windowSize = 8
|
||||||
|
c.rtt = (1 * time.Millisecond).Seconds()
|
||||||
|
c.slowStart = true
|
||||||
|
c.hasAcked = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// It is assumed that ReceivedAck will only be called by one thread
|
||||||
|
func (c *NewReno) ReceivedAck(ack uint32) {
|
||||||
|
c.outboundTimesLock.Lock()
|
||||||
|
defer c.outboundTimesLock.Unlock()
|
||||||
|
|
||||||
|
log.Printf("ack received for %d", ack)
|
||||||
|
c.hasAcked = true
|
||||||
|
|
||||||
|
// RTT
|
||||||
|
// Update using an exponential average
|
||||||
|
rtt := time.Now().Sub(c.outboundTimes[ack]).Seconds()
|
||||||
|
|
||||||
|
delete(c.outboundTimes, ack)
|
||||||
|
c.rtt = c.rtt*(1-RttExponentialFactor) + rtt*RttExponentialFactor
|
||||||
|
|
||||||
|
// Free Window
|
||||||
|
atomic.AddInt32(&c.inFlight, -1)
|
||||||
|
select {
|
||||||
|
case c.ackNotifier <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// GROW
|
||||||
|
// CASE: exponential. increase window size by one per ack
|
||||||
|
// CASE: standard. increase window size by one per window of acks
|
||||||
|
if c.slowStart {
|
||||||
|
atomic.AddInt32(&c.windowSize, 1)
|
||||||
|
} else {
|
||||||
|
c.windowCount++
|
||||||
|
if c.windowCount == c.windowSize {
|
||||||
|
c.windowCount = 0
|
||||||
|
atomic.AddInt32(&c.windowSize, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// It is assumed that ReceivedNack will only be called by one thread
|
||||||
|
func (c *NewReno) ReceivedNack(nack uint32) {
|
||||||
|
log.Printf("nack received for %d", nack)
|
||||||
|
|
||||||
|
// End slow start
|
||||||
|
c.slowStart = false
|
||||||
|
if s := c.windowSize; s > 1 {
|
||||||
|
atomic.StoreInt32(&c.windowSize, s/2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) ReceivedPacket(seq uint32) {
|
||||||
|
log.Printf("seq received for %d", seq)
|
||||||
|
|
||||||
|
c.inboundTimes[seq] = time.Now()
|
||||||
|
|
||||||
|
c.acksToSendLock.Lock()
|
||||||
|
c.acksToSend.Insert(seq)
|
||||||
|
c.acksToSendLock.Unlock()
|
||||||
|
|
||||||
|
c.updateAckNack()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) updateAckNack() {
|
||||||
|
c.acksToSendLock.Lock()
|
||||||
|
defer c.acksToSendLock.Unlock()
|
||||||
|
|
||||||
|
c.inboundTimesLock.Lock()
|
||||||
|
defer c.inboundTimesLock.Unlock()
|
||||||
|
|
||||||
|
findAck := func(start uint32) uint32 {
|
||||||
|
ack := start
|
||||||
|
for len(c.acksToSend) > 0 {
|
||||||
|
if a, _ := c.acksToSend.Peek(); a == ack+1 {
|
||||||
|
ack, _ = c.acksToSend.Extract()
|
||||||
|
delete(c.inboundTimes, ack)
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ack
|
||||||
|
}
|
||||||
|
|
||||||
|
ack := findAck(c.ack)
|
||||||
|
if ack == c.ack {
|
||||||
|
// check if there is a nack to send
|
||||||
|
// decide this based on whether there have been 3RTTs between the offset packet
|
||||||
|
if len(c.acksToSend) > 0 {
|
||||||
|
nextAck, _ := c.acksToSend.Peek()
|
||||||
|
if time.Now().Sub(c.inboundTimes[nextAck]).Seconds() > c.rtt*3 {
|
||||||
|
atomic.StoreUint32(&c.nack, nextAck-1)
|
||||||
|
ack, _ = c.acksToSend.Extract()
|
||||||
|
ack = findAck(ack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.StoreUint32(&c.ack, ack)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) Sequence() uint32 {
|
||||||
|
c.outboundTimesLock.Lock()
|
||||||
|
defer c.outboundTimesLock.Unlock()
|
||||||
|
|
||||||
|
for c.inFlight >= c.windowSize {
|
||||||
|
<-c.ackNotifier
|
||||||
|
}
|
||||||
|
atomic.AddInt32(&c.inFlight, 1)
|
||||||
|
|
||||||
|
s := <-c.sequence
|
||||||
|
|
||||||
|
n := time.Now()
|
||||||
|
c.lastSent = n
|
||||||
|
c.outboundTimes[s] = n
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) NextAck() uint32 {
|
||||||
|
a := c.ack
|
||||||
|
c.lastAck = a
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) NextNack() uint32 {
|
||||||
|
n := c.nack
|
||||||
|
c.lastNack = n
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 {
|
||||||
|
for {
|
||||||
|
rtt := time.Duration(math.Round(c.rtt * float64(time.Second)))
|
||||||
|
time.Sleep(rtt)
|
||||||
|
|
||||||
|
c.updateAckNack()
|
||||||
|
|
||||||
|
// CASE 1: waiting ACKs or NACKs and no message sent in the last RTT
|
||||||
|
if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt)) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CASE 3: No message sent within the keepalive time
|
||||||
|
if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) {
|
||||||
|
return c.Sequence()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
44
udp/congestion/none.go
Normal file
44
udp/congestion/none.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type None struct {
|
||||||
|
sequence chan uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNone() *None {
|
||||||
|
c := None{
|
||||||
|
sequence: make(chan uint32),
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var s uint32
|
||||||
|
for {
|
||||||
|
if s == 0 {
|
||||||
|
s++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.sequence <- s
|
||||||
|
s++
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *None) String() string {
|
||||||
|
return fmt.Sprintf("{None}")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *None) ReceivedPacket(uint32) {}
|
||||||
|
func (c *None) ReceivedAck(uint32) {}
|
||||||
|
func (c *None) ReceivedNack(uint32) {}
|
||||||
|
func (c *None) Reset() {}
|
||||||
|
func (c *None) NextNack() uint32 { return 0 }
|
||||||
|
func (c *None) NextAck() uint32 { return 0 }
|
||||||
|
func (c *None) AwaitEarlyUpdate(time.Duration) uint32 { select {} }
|
||||||
|
func (c *None) Sequence() uint32 { return <-c.sequence }
|
285
udp/flow.go
Normal file
285
udp/flow.go
Normal file
@ -0,0 +1,285 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"mpbl3p/shared"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PacketWriter interface {
|
||||||
|
Write(b []byte) (int, error)
|
||||||
|
WriteToUDP(b []byte, addr *net.UDPAddr) (int, error)
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketConn interface {
|
||||||
|
PacketWriter
|
||||||
|
ReadFromUDP(b []byte) (int, *net.UDPAddr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type InitiatedFlow struct {
|
||||||
|
Local string
|
||||||
|
Remote string
|
||||||
|
|
||||||
|
g proxy.MacGenerator
|
||||||
|
keepalive time.Duration
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
Flow
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) String() string {
|
||||||
|
return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local, f.Remote)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Flow struct {
|
||||||
|
writer PacketWriter
|
||||||
|
raddr *net.UDPAddr
|
||||||
|
|
||||||
|
isAlive bool
|
||||||
|
startup bool
|
||||||
|
congestion Congestion
|
||||||
|
|
||||||
|
v proxy.MacVerifier
|
||||||
|
|
||||||
|
inboundDatagrams chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Flow) String() string {
|
||||||
|
return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitiateFlow(
|
||||||
|
local, remote string,
|
||||||
|
v proxy.MacVerifier,
|
||||||
|
g proxy.MacGenerator,
|
||||||
|
c Congestion,
|
||||||
|
keepalive time.Duration,
|
||||||
|
) (*InitiatedFlow, error) {
|
||||||
|
f := InitiatedFlow{
|
||||||
|
Local: local,
|
||||||
|
Remote: remote,
|
||||||
|
Flow: newFlow(c, v),
|
||||||
|
g: g,
|
||||||
|
keepalive: keepalive,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFlow(c Congestion, v proxy.MacVerifier) Flow {
|
||||||
|
return Flow{
|
||||||
|
inboundDatagrams: make(chan []byte),
|
||||||
|
congestion: c,
|
||||||
|
v: v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) Reconnect() error {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
|
if f.isAlive {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr, err := net.ResolveUDPAddr("udp", f.Local)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f.writer = conn
|
||||||
|
f.startup = true
|
||||||
|
|
||||||
|
// prod the connection once a second until we get an ack, then consider it alive
|
||||||
|
go func() {
|
||||||
|
seq := f.congestion.Sequence()
|
||||||
|
|
||||||
|
for !f.isAlive {
|
||||||
|
p := Packet{
|
||||||
|
ack: 0,
|
||||||
|
nack: 0,
|
||||||
|
seq: seq,
|
||||||
|
data: proxy.NewSimplePacket(nil),
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = f.sendPacket(p, f.g)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, _ = f.produceInternal(f.v, false)
|
||||||
|
}()
|
||||||
|
go f.earlyUpdateLoop(f.g, f.keepalive)
|
||||||
|
|
||||||
|
if err := f.acceptPacket(conn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f.isAlive = true
|
||||||
|
f.startup = false
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
lockedAccept := func() {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
if err := f.acceptPacket(conn); err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for f.isAlive {
|
||||||
|
log.Println("alive and listening for packets")
|
||||||
|
lockedAccept()
|
||||||
|
}
|
||||||
|
log.Println("no longer alive")
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
return f.Flow.Consume(p, g)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
return f.Flow.Produce(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) IsAlive() bool {
|
||||||
|
return f.isAlive
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error {
|
||||||
|
if !f.isAlive {
|
||||||
|
return shared.ErrDeadConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(f.congestion)
|
||||||
|
|
||||||
|
// Sequence is the congestion controllers opportunity to block
|
||||||
|
log.Println("awaiting sequence")
|
||||||
|
p := Packet{
|
||||||
|
seq: f.congestion.Sequence(),
|
||||||
|
data: pp,
|
||||||
|
}
|
||||||
|
log.Println("received sequence")
|
||||||
|
|
||||||
|
// Choose up to date ACK/NACK even after blocking
|
||||||
|
p.ack = f.congestion.NextAck()
|
||||||
|
p.nack = f.congestion.NextNack()
|
||||||
|
|
||||||
|
return f.sendPacket(p, g)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
|
if !f.isAlive {
|
||||||
|
return nil, shared.ErrDeadConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.produceInternal(v, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
|
||||||
|
for once := true; mustReturn || once; once = false {
|
||||||
|
log.Println(f.congestion)
|
||||||
|
|
||||||
|
b, err := proxy.StripMac(<-f.inboundDatagrams, v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := UnmarshalPacket(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// schedule an ack for this sequence number
|
||||||
|
if p.seq != 0 {
|
||||||
|
f.congestion.ReceivedPacket(p.seq)
|
||||||
|
}
|
||||||
|
// adjust our sending congestion control based on their acks
|
||||||
|
if p.ack != 0 {
|
||||||
|
f.congestion.ReceivedAck(p.ack)
|
||||||
|
}
|
||||||
|
// adjust our sending congestion control based on their nacks
|
||||||
|
if p.nack != 0 {
|
||||||
|
f.congestion.ReceivedNack(p.nack)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 12 bytes for header + the MAC + a timestamp
|
||||||
|
if len(b) == 12+f.v.CodeLength()+8 {
|
||||||
|
log.Println("handled keepalive/ack only packet")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) handleDatagram(p []byte) {
|
||||||
|
f.inboundDatagrams <- p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
|
||||||
|
b := p.Marshal()
|
||||||
|
b = proxy.AppendMac(b, g)
|
||||||
|
|
||||||
|
if f.raddr == nil {
|
||||||
|
_, err := f.writer.Write(b)
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
_, err := f.writer.WriteToUDP(b, f.raddr)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) {
|
||||||
|
var err error
|
||||||
|
for !errors.Is(err, shared.ErrDeadConnection) {
|
||||||
|
seq := f.congestion.AwaitEarlyUpdate(keepalive)
|
||||||
|
p := Packet{
|
||||||
|
ack: f.congestion.NextAck(),
|
||||||
|
nack: f.congestion.NextNack(),
|
||||||
|
seq: seq,
|
||||||
|
data: proxy.NewSimplePacket(nil),
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = f.sendPacket(p, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) acceptPacket(c PacketConn) error {
|
||||||
|
buf := make([]byte, 6000)
|
||||||
|
n, _, err := c.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f.handleDatagram(buf[:n])
|
||||||
|
return nil
|
||||||
|
}
|
85
udp/flow_test.go
Normal file
85
udp/flow_test.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"mpbl3p/mocks"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"mpbl3p/udp/congestion"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFlow_Consume(t *testing.T) {
|
||||||
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
|
testPacket := proxy.NewSimplePacket(testContent)
|
||||||
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
|
||||||
|
t.Run("Length", func(t *testing.T) {
|
||||||
|
testConn := mocks.NewMockPerfectBiPacketConn(10)
|
||||||
|
|
||||||
|
flowA := newFlow(congestion.NewNone(), testMac)
|
||||||
|
|
||||||
|
flowA.writer = testConn.SideB()
|
||||||
|
flowA.isAlive = true
|
||||||
|
|
||||||
|
err := flowA.Consume(testPacket, testMac)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
n, _, err := testConn.SideA().ReadFromUDP(buf)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// 12 header, 8 timestamp, 4 MAC
|
||||||
|
assert.Equal(t, len(testContent)+12+8+4, n)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlow_Produce(t *testing.T) {
|
||||||
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
|
testPacket := Packet{
|
||||||
|
ack: 42,
|
||||||
|
nack: 26,
|
||||||
|
seq: 128,
|
||||||
|
data: proxy.NewSimplePacket(testContent),
|
||||||
|
}
|
||||||
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
|
||||||
|
testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac)
|
||||||
|
|
||||||
|
t.Run("Length", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
testConn := mocks.NewMockPerfectBiPacketConn(10)
|
||||||
|
|
||||||
|
_, err := testConn.SideA().Write(testMarshalled)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
flowA := newFlow(congestion.NewNone(), testMac)
|
||||||
|
|
||||||
|
flowA.writer = testConn.SideB()
|
||||||
|
flowA.isAlive = true
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := flowA.acceptPacket(testConn.SideB())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}()
|
||||||
|
p, err := flowA.Produce(testMac)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Len(t, p.Contents(), len(testContent))
|
||||||
|
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
timer := time.NewTimer(500 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-timer.C:
|
||||||
|
fmt.Println("timed out")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
88
udp/listener.go
Normal file
88
udp/listener.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ComparableUdpAddress struct {
|
||||||
|
IP [16]byte
|
||||||
|
Port int
|
||||||
|
Zone string
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
|
||||||
|
var ip [16]byte
|
||||||
|
for i, b := range []byte(address.IP) {
|
||||||
|
ip[i] = b
|
||||||
|
}
|
||||||
|
|
||||||
|
return ComparableUdpAddress{
|
||||||
|
IP: ip,
|
||||||
|
Port: address.Port,
|
||||||
|
Zone: address.Zone,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacGenerator, c func() Congestion) error {
|
||||||
|
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pconn, err := net.ListenUDP("udp", laddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pconn.SetWriteBuffer(0)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedConnections := make(map[ComparableUdpAddress]*Flow)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
buf := make([]byte, 6000)
|
||||||
|
|
||||||
|
log.Println("listening...")
|
||||||
|
n, addr, err := pconn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
log.Println("listened")
|
||||||
|
|
||||||
|
raddr := fromUdpAddress(*addr)
|
||||||
|
if f, exists := receivedConnections[raddr]; exists {
|
||||||
|
log.Println("existing flow")
|
||||||
|
log.Println("handling...")
|
||||||
|
f.handleDatagram(buf[:n])
|
||||||
|
log.Println("handled")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
f := newFlow(c(), v)
|
||||||
|
|
||||||
|
f.writer = pconn
|
||||||
|
f.raddr = addr
|
||||||
|
f.isAlive = true
|
||||||
|
|
||||||
|
log.Printf("received new udp connection: %v\n", f)
|
||||||
|
|
||||||
|
go f.earlyUpdateLoop(g, 0)
|
||||||
|
|
||||||
|
receivedConnections[raddr] = &f
|
||||||
|
|
||||||
|
p.AddConsumer(&f)
|
||||||
|
p.AddProducer(&f, v)
|
||||||
|
|
||||||
|
log.Println("handling...")
|
||||||
|
f.handleDatagram(buf[:n])
|
||||||
|
log.Println("handled")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
41
udp/packet.go
Normal file
41
udp/packet.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Packet struct {
|
||||||
|
ack uint32
|
||||||
|
nack uint32
|
||||||
|
seq uint32
|
||||||
|
|
||||||
|
data proxy.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalPacket(b []byte) (p Packet, err error) {
|
||||||
|
p.ack = binary.LittleEndian.Uint32(b[0:4])
|
||||||
|
p.nack = binary.LittleEndian.Uint32(b[4:8])
|
||||||
|
p.seq = binary.LittleEndian.Uint32(b[8:12])
|
||||||
|
|
||||||
|
p.data, err = proxy.UnmarshalSimplePacket(b[12:])
|
||||||
|
if err != nil {
|
||||||
|
return Packet{}, err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Packet) Marshal() []byte {
|
||||||
|
data := p.data.Marshal()
|
||||||
|
header := make([]byte, 12)
|
||||||
|
|
||||||
|
binary.LittleEndian.PutUint32(header[0:4], p.ack)
|
||||||
|
binary.LittleEndian.PutUint32(header[4:8], p.nack)
|
||||||
|
binary.LittleEndian.PutUint32(header[8:12], p.seq)
|
||||||
|
|
||||||
|
return append(header, data...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Packet) Contents() []byte {
|
||||||
|
return p.data.Contents()
|
||||||
|
}
|
59
udp/packet_test.go
Normal file
59
udp/packet_test.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPacket_Marshal(t *testing.T) {
|
||||||
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
|
testPacket := Packet{
|
||||||
|
ack: 18,
|
||||||
|
nack: 29,
|
||||||
|
seq: 431,
|
||||||
|
data: proxy.NewSimplePacket(testContent),
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Length", func(t *testing.T) {
|
||||||
|
marshalled := testPacket.Marshal()
|
||||||
|
|
||||||
|
// 12 header + 8 timestamp
|
||||||
|
assert.Len(t, marshalled, len(testContent)+12+8)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalPacket(t *testing.T) {
|
||||||
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
|
testPacket := Packet{
|
||||||
|
ack: 18,
|
||||||
|
nack: 29,
|
||||||
|
seq: 431,
|
||||||
|
data: proxy.NewSimplePacket(testContent),
|
||||||
|
}
|
||||||
|
testMarshalled := testPacket.Marshal()
|
||||||
|
|
||||||
|
t.Run("Length", func(t *testing.T) {
|
||||||
|
p, err := UnmarshalPacket(testMarshalled)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Len(t, p.Contents(), len(testContent))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Contents", func(t *testing.T) {
|
||||||
|
p, err := UnmarshalPacket(testMarshalled)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Equal(t, p.Contents(), testContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Header", func(t *testing.T) {
|
||||||
|
p, err := UnmarshalPacket(testMarshalled)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, p.ack, uint32(18))
|
||||||
|
assert.Equal(t, p.nack, uint32(29))
|
||||||
|
assert.Equal(t, p.seq, uint32(431))
|
||||||
|
})
|
||||||
|
}
|
35
udp/wireshark_dissector.lua
Normal file
35
udp/wireshark_dissector.lua
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
mpbl3p_udp = Proto("mpbl3p_udp", "Multi Path Proxy Custom UDP")
|
||||||
|
|
||||||
|
ack_F = ProtoField.uint32("mpbl3p_udp.ack", "Acknowledgement")
|
||||||
|
nack_F = ProtoField.uint32("mpbl3p_udp.nack", "Negative Acknowledgement")
|
||||||
|
seq_F = ProtoField.uint32("mpbl3p_udp.seq", "Sequence Number")
|
||||||
|
time_F = ProtoField.absolute_time("mpbl3p_udp.time", "Timestamp")
|
||||||
|
proxied_F = ProtoField.bytes("mpbl3p_udp.data", "Proxied Data")
|
||||||
|
|
||||||
|
mpbl3p_udp.fields = { ack_F, nack_F, seq_F, time_F, proxied_F }
|
||||||
|
|
||||||
|
function mpbl3p_udp.dissector(buffer, pinfo, tree)
|
||||||
|
if buffer:len() < 20 then
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
pinfo.cols.protocol = "MPBL3P_UDP"
|
||||||
|
|
||||||
|
local ack = buffer(0, 4):le_uint()
|
||||||
|
local nack = buffer(4, 4):le_uint()
|
||||||
|
local seq = buffer(8, 4):le_uint()
|
||||||
|
|
||||||
|
local unix_time = buffer(buffer:len() - 8, 8):le_uint64()
|
||||||
|
|
||||||
|
local subtree = tree:add(mpbl3p_udp, buffer(), "Multi Path Proxy Header, SEQ: " .. seq .. " ACK: " .. ack .. " NACK: " .. nack)
|
||||||
|
|
||||||
|
subtree:add(ack_F, ack)
|
||||||
|
subtree:add(nack_F, nack)
|
||||||
|
subtree:add(seq_F, seq)
|
||||||
|
subtree:add(time_F, NSTime.new(unix_time:tonumber()))
|
||||||
|
if buffer:len() > 20 then
|
||||||
|
subtree:add(proxied_F, buffer(12, buffer:len() - 12 - 8))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
DissectorTable.get("udp.port"):add(1234, mpbl3p_udp)
|
65
utils/heap.go
Normal file
65
utils/heap.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var ErrorEmptyHeap = errors.New("attempted to extract from empty heap")
|
||||||
|
|
||||||
|
// A MinHeap for Uint64
|
||||||
|
type Uint32Heap []uint32
|
||||||
|
|
||||||
|
func (h *Uint32Heap) swap(x, y int) {
|
||||||
|
(*h)[x] = (*h)[x] ^ (*h)[y]
|
||||||
|
(*h)[y] = (*h)[y] ^ (*h)[x]
|
||||||
|
(*h)[x] = (*h)[x] ^ (*h)[y]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Uint32Heap) Insert(new uint32) uint32 {
|
||||||
|
*h = append(*h, new)
|
||||||
|
|
||||||
|
child := len(*h) - 1
|
||||||
|
for child != 0 {
|
||||||
|
parent := (child - 1) / 2
|
||||||
|
if (*h)[parent] > (*h)[child] {
|
||||||
|
h.swap(parent, child)
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
child = parent
|
||||||
|
}
|
||||||
|
|
||||||
|
return (*h)[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Uint32Heap) Extract() (uint32, error) {
|
||||||
|
if len(*h) == 0 {
|
||||||
|
return 0, ErrorEmptyHeap
|
||||||
|
}
|
||||||
|
min := (*h)[0]
|
||||||
|
|
||||||
|
(*h)[0] = (*h)[len(*h)-1]
|
||||||
|
*h = (*h)[:len(*h)-1]
|
||||||
|
|
||||||
|
parent := 0
|
||||||
|
for {
|
||||||
|
left, right := parent*2+1, parent*2+2
|
||||||
|
|
||||||
|
if (left < len(*h) && (*h)[parent] > (*h)[left]) || (right < len(*h) && (*h)[parent] > (*h)[right]) {
|
||||||
|
if right < len(*h) && (*h)[left] > (*h)[right] {
|
||||||
|
h.swap(parent, right)
|
||||||
|
parent = right
|
||||||
|
} else {
|
||||||
|
h.swap(parent, left)
|
||||||
|
parent = left
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return min, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Uint32Heap) Peek() (uint32, error) {
|
||||||
|
if len(*h) == 0 {
|
||||||
|
return 0, ErrorEmptyHeap
|
||||||
|
}
|
||||||
|
return (*h)[0], nil
|
||||||
|
}
|
54
utils/heap_test.go
Normal file
54
utils/heap_test.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SlowHeapSort(in []uint32) []uint32 {
|
||||||
|
out := make([]uint32, len(in))
|
||||||
|
|
||||||
|
var heap Uint32Heap
|
||||||
|
|
||||||
|
for _, x := range in {
|
||||||
|
heap.Insert(x)
|
||||||
|
}
|
||||||
|
for i := range out {
|
||||||
|
var err error
|
||||||
|
out[i], err = heap.Extract()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUint32Heap(t *testing.T) {
|
||||||
|
t.Run("EquivalentToMerge", func(t *testing.T) {
|
||||||
|
const ArrayLength = 50
|
||||||
|
|
||||||
|
sortedArray := make([]uint32, ArrayLength)
|
||||||
|
array := make([]uint32, ArrayLength)
|
||||||
|
|
||||||
|
for i := range array {
|
||||||
|
sortedArray[i] = uint32(i)
|
||||||
|
array[i] = uint32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
rand.Seed(time.Now().Unix())
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
|
||||||
|
rand.Shuffle(50, func(i, j int) { array[i], array[j] = array[j], array[i] })
|
||||||
|
|
||||||
|
heapSorted := SlowHeapSort(array)
|
||||||
|
|
||||||
|
assert.Equal(t, sortedArray, heapSorted)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
@ -1,13 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
var NextId = make(chan int)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
go func() {
|
|
||||||
i := 0
|
|
||||||
for {
|
|
||||||
NextId <- i
|
|
||||||
i += 1
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user