merge develop into master #21

Merged
JakeHillion merged 149 commits from develop into master 2021-05-12 00:22:59 +01:00
37 changed files with 2434 additions and 355 deletions

View File

@ -1,12 +1,18 @@
---
kind: pipeline kind: pipeline
type: docker type: docker
name: default name: default
steps: steps:
- name: format
image: golang:1.16
commands:
- bash -c "gofmt -l . | wc -l | cmp -s <(echo 0) || (gofmt -l . && exit 1)"
- name: install - name: install
image: golang:1.15 image: golang:1.16
environment: environment:
GOPROXY: http://10.20.0.25:3142|direct GOPROXY: http://containers.internal.hillion.co.uk:3142,direct
volumes: volumes:
- name: cache - name: cache
path: /go path: /go
@ -14,7 +20,7 @@ steps:
- go test -i ./... - go test -i ./...
- name: test - name: test
image: golang:1.15 image: golang:1.16
volumes: volumes:
- name: cache - name: cache
path: /go path: /go
@ -22,7 +28,7 @@ steps:
- go test ./... - go test ./...
- name: build (debian) - name: build (debian)
image: golang:1.15-buster image: golang:1.16-buster
when: when:
event: event:
- push - push
@ -30,7 +36,10 @@ steps:
- name: cache - name: cache
path: /go path: /go
commands: commands:
- go build - GOOS=linux GOARCH=amd64 go build -o linux_amd64
- GOOS=linux GOARCH=arm GOARM=7 go build -o linux_arm_v7
- GOOS=freebsd GOARCH=amd64 go build -o freebsd_amd64
- GOOS=freebsd GOARCH=arm64 go build -o freebsd_arm64_v8a
- name: upload - name: upload
image: minio/mc image: minio/mc
@ -42,9 +51,18 @@ steps:
SECRET_KEY: SECRET_KEY:
from_secret: s3_secret_key from_secret: s3_secret_key
commands: commands:
- mc alias set s3 http://10.20.0.25:3900 $${ACCESS_KEY} $${SECRET_KEY} - mc alias set s3 https://s3.us-west-001.backblazeb2.com $${ACCESS_KEY} $${SECRET_KEY}
- mc cp mpbl3p s3/dissertation/binaries/debian/${DRONE_BRANCH} - mc cp linux_amd64 s3/dissertation/binaries/debian/${DRONE_BRANCH}_linux_amd64
- mc cp linux_arm_v7 s3/dissertation/binaries/debian/${DRONE_BRANCH}_linux_arm_v7
- mc cp freebsd_amd64 s3/dissertation/binaries/debian/${DRONE_BRANCH}_freebsd_amd64
- mc cp freebsd_arm64_v8a s3/dissertation/binaries/debian/${DRONE_BRANCH}_freebsd_arm64_v8a
volumes: volumes:
- name: cache - name: cache
temp: { } temp: { }
---
kind: signature
hmac: 7960420c7d02f9bce56d6429b612676d24cbe1d1608cf44a77da9afc411eccb8
...

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
config.ini *.conf
logs/ logs/
# Created by https://www.toptal.com/developers/gitignore/api/intellij+all,go # Created by https://www.toptal.com/developers/gitignore/api/intellij+all,go

129
README.md
View File

@ -4,15 +4,130 @@
### Linux ### Linux
#### Policy Based Routing #### Policy Based Routing
ip route flush 11 ip route flush table 10
ip route add table 11 to 1.1.1.0/24 dev eth1 ip route add table 10 to 1.1.1.0/24 dev eth1
ip rule add from 1.1.1.4 table 11 priority 11 ip rule add from 1.1.1.4 table 10 priority 10
ip route flush 10 ip route flush table 11
ip route add table 10 to 1.1.1.0/24 dev eth2 ip route add table 11 to 1.1.1.0/24 dev eth2
ip rule add from 1.1.1.5 table 10 priority 10 ip rule add from 1.1.1.5 table 11 priority 11
#### ARP Flux #### ARP Flux
sysctl -w net.ipv4.conf.all.arp_announce=1 sysctl -w net.ipv4.conf.all.arp_announce=1
sysctl -w net.ipv4.conf.all.arp_ignore=2 sysctl -w net.ipv4.conf.all.arp_ignore=1
See http://kb.linuxvirtualserver.org/wiki/Using_arp_announce/arp_ignore_to_disable_ARP
### Setup Scripts
These are functional setup scripts that make the application run as intended on Linux.
### Remote Portal
#### Pre-Start
#!/bin/bash
set -e
## Set up variables
REMOTE_PORTAL_ADDRESS=A.B.C.D
## IPv4 Forwarding
sysctl -w net.ipv4.ip_forward=1
sysctl -w net.ipv4.conf.eth0.proxy_arp=1
## Transfer the local routing table to a much lower priority
(ip rule show | grep '20:') > /dev/null || ip rule add from all table local priority 20
ip rule del priority 0 2> /dev/null || true
## Ports to route locally
### MPBL3P
ip rule del priority 1 2> /dev/null || true
ip rule add to "$REMOTE_PORTAL_ADDRESS" dport 1234 table local priority 1
### SSH
ip rule del priority 2 2> /dev/null || true
ip rule add to "$REMOTE_PORTAL_ADDRESS" dport 22 table local priority 2
#### Post-Start
#!/bin/bash
set -e
## Set up variables
REMOTE_PORTAL_ADDRESS=A.B.C.D
## Tunnel addr/up
ip addr add 172.19.152.2/31 dev nc0
ip link set up nc0
# Route packets to the interface but not for nc via the tunnel
ip route flush table 19
ip route add table 19 to "$REMOTE_PORTAL_ADDRESS" via 172.19.152.3 dev nc0
ip rule del priority 19 2> /dev/null || true
ip rule add to "$REMOTE_PORTAL_ADDRESS" table 19 priority 19
### Local Portal
#### Pre-Start
#!/bin/bash
set -e
## Set up variables
GATEWAY_INTERFACE=eth0
GATEWAY_ADDRESS=10.36.12.1
## Fix ARP
sysctl -w net.ipv4.conf.all.arp_announce=1
sysctl -w net.ipv4.conf.all.arp_ignore=1
## IPv4 Forwarding
sysctl -w net.ipv4.ip_forward=1
## Gateway Interface Setup
ip addr flush dev "$GATEWAY_INTERFACE"
ip addr add "$GATEWAY_ADDRESS"/32 dev "$GATEWAY_INTERFACE"
ip link set up "$GATEWAY_INTERFACE"
## Per-Interface Routing Tables
### 10.10.0.0/24
ip route flush table 10
ip route add table 10 default via 10.10.0.1
ip rule del priority 10 2> /dev/null || true
ip rule add from 10.10.0.0/24 table 10 priority 10
### 192.168.0.0/24
ip route flush table 11
ip route add table 11 default via 192.168.0.1
ip rule del priority 11 2> /dev/null || true
ip rule add from 192.168.0.0/24 table 11 priority 11
#### Post-Start
#!/bin/bash
set -e
## Set up variables
REMOTE_PORTAL_ADDRESS=A.B.C.D
GATEWAY_INTERFACE=eth0
## Tunnel Address and Enable
ip addr add 172.19.152.3/31 dev nc0
ip link set up nc0
## Route Outbound Packets Correctly
ip route flush table 20
ip route add table 20 default via 172.19.152.2 dev nc0
ip rule del priority 20 2> /dev/null || true
ip rule add from "$REMOTE_PORTAL_ADDRESS" iif "$GATEWAY_INTERFACE" table 20 priority 20
## Route Inbound Packets Correctly
ip route flush table 21
ip route add table 21 to "$REMOTE_PORTAL_ADDRESS" dev "$GATEWAY_INTERFACE"
ip rule del priority 21 2> /dev/null || true
ip rule add to "$REMOTE_PORTAL_ADDRESS" table 21 priority 21
#### Client
Connect to `GATEWAY_INTERFACE` and set the IP to `REMOTE_PORTAL_ADDRESS`/32 with a gateway of `GATEWAY_ADDRESS`.

View File

@ -1,48 +1,57 @@
package config package config
import ( import (
"context"
"encoding/base64"
"fmt" "fmt"
"mpbl3p/crypto"
"mpbl3p/crypto/sharedkey"
"mpbl3p/proxy" "mpbl3p/proxy"
"mpbl3p/tcp" "mpbl3p/tcp"
"mpbl3p/tun" "mpbl3p/udp"
"mpbl3p/udp/congestion"
"time"
) )
// TODO: Delete this code as soon as an alternative is available func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
type UselessMac struct{}
func (UselessMac) CodeLength() int {
return 0
}
func (UselessMac) Generate([]byte) []byte {
return nil
}
func (u UselessMac) Verify([]byte, []byte) error {
return nil
}
func (c Configuration) Build() (*proxy.Proxy, error) {
p := proxy.NewProxy(0) p := proxy.NewProxy(0)
p.Generator = UselessMac{}
if c.Host.InterfaceName == "" { var g func() proxy.MacGenerator
c.Host.InterfaceName = "nc%d" var v func() proxy.MacVerifier
}
ss, err := tun.NewTun(c.Host.InterfaceName, 1500) switch c.Host.Crypto {
case "None":
g = func() proxy.MacGenerator { return crypto.None{} }
v = func() proxy.MacVerifier { return crypto.None{} }
case "Blake2s":
key, err := base64.StdEncoding.DecodeString(c.Host.SharedKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if _, err := sharedkey.NewBlake2s(key); err != nil {
return nil, err
}
g = func() proxy.MacGenerator {
g, _ := sharedkey.NewBlake2s(key)
return g
}
v = func() proxy.MacVerifier {
v, _ := sharedkey.NewBlake2s(key)
return v
}
}
p.Source = ss p.Source = source
p.Sink = ss p.Sink = sink
for _, peer := range c.Peers { for _, peer := range c.Peers {
switch peer.Method { switch peer.Method {
case "TCP": case "TCP":
err := buildTcp(p, peer) if err := buildTcp(ctx, p, peer, g, v); err != nil {
if err != nil { return nil, err
}
case "UDP":
if err := buildUdp(ctx, p, peer, g, v); err != nil {
return nil, err return nil, err
} }
} }
@ -51,20 +60,82 @@ func (c Configuration) Build() (*proxy.Proxy, error) {
return p, nil return p, nil
} }
func buildTcp(p *proxy.Proxy, peer Peer) error { func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
var laddr func() string
if peer.LocalPort == 0 {
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
} else {
laddr = func() string { return fmt.Sprintf("%s:%d", peer.GetLocalHost(), peer.LocalPort) }
}
if peer.RemoteHost != "" { if peer.RemoteHost != "" {
f, err := tcp.InitiateFlow( f, err := tcp.InitiateFlow(laddr, fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort))
fmt.Sprintf("%s:", peer.LocalHost),
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
)
p.AddConsumer(f)
p.AddProducer(f, UselessMac{})
if err != nil {
return err return err
} }
err := tcp.NewListener(p, fmt.Sprintf("%s:%d", peer.LocalHost, peer.LocalPort), UselessMac{}) if !peer.DisableConsumer {
p.AddConsumer(ctx, f, g())
}
if !peer.DisableProducer {
p.AddProducer(ctx, f, v())
}
return nil
}
err := tcp.NewListener(ctx, p, laddr(), v, g, !peer.DisableConsumer, !peer.DisableProducer)
if err != nil {
return err
}
return nil
}
func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
var laddr func() string
if peer.LocalPort == 0 {
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
} else {
laddr = func() string { return fmt.Sprintf("%s:%d", peer.GetLocalHost(), peer.LocalPort) }
}
var c func() udp.Congestion
switch peer.Congestion {
case "None":
c = func() udp.Congestion { return congestion.NewNone() }
default:
fallthrough
case "NewReno":
c = func() udp.Congestion { return congestion.NewNewReno() }
}
if peer.RemoteHost != "" {
f, err := udp.InitiateFlow(
laddr,
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
v(),
g(),
c(),
time.Duration(peer.KeepAlive)*time.Second,
)
if err != nil {
return err
}
if !peer.DisableConsumer {
p.AddConsumer(ctx, f, g())
}
if !peer.DisableProducer {
p.AddProducer(ctx, f, v())
}
return nil
}
err := udp.NewListener(ctx, p, laddr(), v, g, c, !peer.DisableConsumer, !peer.DisableProducer)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,32 +1,96 @@
package config package config
import "github.com/go-playground/validator/v10" import (
"github.com/go-playground/validator/v10"
"log"
"net"
"strings"
)
var v = validator.New() var v = validator.New()
func init() {
if err := v.RegisterValidation("iface", func(fl validator.FieldLevel) bool {
name, ok := fl.Field().Interface().(string)
if ok && name != "" {
ifaces, err := net.Interfaces()
if err != nil {
log.Printf("error getting interfaces: %v", err)
return false
}
for _, i := range ifaces {
if i.Name == name {
return true
}
}
}
return false
}); err != nil {
panic(err)
}
}
type Configuration struct { type Configuration struct {
Host Host Host Host
Peers []Peer `validate:"dive"` Peers []Peer `validate:"dive"`
} }
type Host struct { type Host struct {
PrivateKey string `validate:"required"` Crypto string `validate:"required,oneof=None Blake2s"`
InterfaceName string SharedKey string `validate:"required_if=Crypto Blake2s"`
MTU uint `validate:"required,min=576"`
} }
type Peer struct { type Peer struct {
PublicKey string `validate:"required"` Method string `validate:"oneof=TCP UDP"`
Method string `validate:"oneof=TCP"`
LocalHost string `validate:"omitempty,ip"` LocalHost string `validate:"omitempty,ip|iface"`
LocalPort uint `validate:"max=65535"` LocalPort uint `validate:"max=65535"`
RemoteHost string `validate:"required_with=RemotePort,omitempty,fqdn|ip"` RemoteHost string `validate:"required_with=RemotePort,omitempty,fqdn|ip"`
RemotePort uint `validate:"required_with=RemoteHost,omitempty,max=65535"` RemotePort uint `validate:"required_with=RemoteHost,omitempty,max=65535"`
Congestion string `validate:"required_unless=Method TCP,omitempty,oneof=NewReno None"`
KeepAlive uint KeepAlive uint
Timeout uint Timeout uint
RetryWait uint RetryWait uint
DisableConsumer bool `validate:"omitempty,nefield=DisableProducer"`
DisableProducer bool `validate:"omitempty,nefield=DisableConsumer"`
}
func (p Peer) GetLocalHost() string {
if p.LocalHost == "" {
return ""
}
if err := v.Var(p.LocalHost, "ip"); err == nil {
return p.LocalHost
}
iface, err := net.InterfaceByName(p.LocalHost)
if err != nil {
panic(err)
}
if iface != nil {
addrs, err := iface.Addrs()
if err != nil {
panic(err)
}
if len(addrs) > 0 {
addr := addrs[0].String()
addr = strings.Split(addr, "/")[0]
log.Printf("resolved interface `%s` to `%v`", p.LocalHost, addr)
return addr
}
}
return "invalid"
} }
func (c Configuration) Validate() error { func (c Configuration) Validate() error {

15
crypto/none.go Normal file
View File

@ -0,0 +1,15 @@
package crypto
type None struct{}
func (None) CodeLength() int {
return 0
}
func (None) Generate([]byte) []byte {
return nil
}
func (None) Verify([]byte, []byte) error {
return nil
}

View File

@ -0,0 +1,40 @@
package sharedkey
import (
"bytes"
"golang.org/x/crypto/blake2s"
"mpbl3p/shared"
)
type Blake2s struct {
key []byte
}
func NewBlake2s(key []byte) (*Blake2s, error) {
_, err := blake2s.New128(key)
if err != nil {
return nil, err
}
return &Blake2s{key: key}, nil
}
func (b Blake2s) CodeLength() int {
return blake2s.Size128
}
func (b Blake2s) Generate(d []byte) []byte {
h, _ := blake2s.New128(b.key)
h.Write(d)
return h.Sum([]byte{})
}
func (b Blake2s) Verify(d []byte, s []byte) error {
h, _ := blake2s.New128(b.key)
h.Write(d)
sum := h.Sum([]byte{})
if !bytes.Equal(sum, s) {
return shared.ErrBadChecksum
}
return nil
}

View File

@ -0,0 +1,45 @@
package sharedkey
import (
"github.com/stretchr/testify/assert"
"math/rand"
"mpbl3p/shared"
"testing"
)
func TestBlake2s_GenerateVerify(t *testing.T) {
t.Run("GeneratedVerifies", func(t *testing.T) {
// ASSIGN
key := make([]byte, 16)
rand.Read(key)
buf := make([]byte, 500)
rand.Read(buf)
// ACT
b := Blake2s{key}
code := b.Generate(buf)
// ASSERT
err := b.Verify(buf, code)
assert.Nil(t, err)
})
t.Run("FlippedBitFailsVerify", func(t *testing.T) {
// ASSIGN
key := make([]byte, 16)
rand.Read(key)
buf := make([]byte, 500)
rand.Read(buf)
// ACT
b := Blake2s{key}
code := b.Generate(buf)
offset := rand.Intn(len(buf) * 8)
buf[offset/8] ^= 1 << (offset % 8)
// ASSERT
err := b.Verify(buf, code)
assert.Equal(t, shared.ErrBadChecksum, err)
})
}

41
flags/flags.go Normal file
View File

@ -0,0 +1,41 @@
package flags
import (
"errors"
"fmt"
goflags "github.com/jessevdk/go-flags"
"os"
)
var PrintedHelpErr = goflags.ErrHelp
var NotEnoughArgs = errors.New("not enough arguments")
type Options struct {
Foreground bool `short:"f" long:"foreground" description:"Run in the foreground"`
ConfigFile string `short:"c" long:"config" description:"Configuration file location" value-name:"FILE"`
PidFile string `short:"p" long:"pid" description:"PID file location"`
Positional struct {
InterfaceName string `required:"yes" positional-arg-name:"INTERFACE-NAME" description:"Interface name"`
} `positional-args:"yes"`
}
func ParseFlags() (*Options, error) {
o := new(Options)
parser := goflags.NewParser(o, goflags.Default)
_, err := parser.Parse()
if err != nil {
parser.WriteHelp(os.Stdout)
return nil, err
}
if o.ConfigFile == "" {
o.ConfigFile = fmt.Sprintf(DefaultConfigFile, o.Positional.InterfaceName)
}
if o.PidFile == "" {
o.PidFile = fmt.Sprintf(DefaultPidFile, o.Positional.InterfaceName)
}
return o, nil
}

4
flags/locs_freebsd.go Normal file
View File

@ -0,0 +1,4 @@
package flags
const DefaultConfigFile = "/usr/local/etc/netcombiner/%s"
const DefaultPidFile = "/var/run/netcombiner/%s.pid"

4
flags/locs_linux.go Normal file
View File

@ -0,0 +1,4 @@
package flags
const DefaultConfigFile = "/etc/netcombiner/%s"
const DefaultPidFile = "/var/run/netcombiner/%s.pid"

8
go.mod
View File

@ -3,10 +3,12 @@ module mpbl3p
go 1.15 go 1.15
require ( require (
github.com/go-playground/assert/v2 v2.0.1 github.com/go-playground/validator/v10 v10.6.0
github.com/go-playground/validator/v10 v10.4.1 github.com/jessevdk/go-flags v1.5.0
github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab
github.com/smartystreets/goconvey v1.6.4 // indirect github.com/smartystreets/goconvey v1.6.4 // indirect
github.com/stretchr/testify v1.4.0 github.com/stretchr/testify v1.4.0
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
golang.org/x/net v0.0.0-20210326060303-6b1517762897 // indirect
golang.zx2c4.com/wireguard v0.0.0-20201118132417-da19db415a58
gopkg.in/ini.v1 v1.62.0 gopkg.in/ini.v1 v1.62.0
) )

28
go.sum
View File

@ -8,14 +8,16 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE=
github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4=
github.com/go-playground/validator/v10 v10.6.0 h1:UGIt4xR++fD9QrBOoo/ascJfGe3AGHEB9s6COnss4Rk=
github.com/go-playground/validator/v10 v10.6.0/go.mod h1:xm76BBt941f7yWdGnI2DVPFFg1UK3YY04qifoXU3lOk=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc=
github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4=
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab h1:dAXDRtXYxj4sTR5WeRuTFJGH18QMT6AUpUgRwedI6es=
github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab/go.mod h1:N5a/Ll2ZNk5wjiLNW9LIiNtO9RNYcaYmcXSYKMYrlDg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
@ -26,17 +28,35 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210326060303-6b1517762897 h1:KrsHThm5nFk34YtATK1LsThyGhGbGe1olrte/HInHvs=
golang.org/x/net v0.0.0-20210326060303-6b1517762897/go.mod h1:uSPa2vr4CLtc/ILN5odXGNXS6mhrKVzTaCXzk9m6W3k=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210324051608-47abb6519492 h1:Paq34FxTluEPvVyayQqMPgHm+vTOrIifmcYxFBx9TLg=
golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.zx2c4.com/wireguard v0.0.0-20201118132417-da19db415a58 h1:HiPOx0boQr3qv0HkZ4fGLtTXJ5tmkbv0d8UmkNcxdv0=
golang.zx2c4.com/wireguard v0.0.0-20201118132417-da19db415a58/go.mod h1:Dz+cq5bnrai9EpgYj4GDof/+qaGzbRWbeaAOs1bUYa0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU= gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU=

127
main.go
View File

@ -1,36 +1,155 @@
package main package main
import ( import (
"context"
"errors"
"fmt"
"log" "log"
"mpbl3p/config" "mpbl3p/config"
"mpbl3p/flags"
"mpbl3p/tun"
"os" "os"
"os/signal" "os/signal"
"strconv"
"syscall" "syscall"
) )
const (
ENV_NC_TUN_FD = "NC_TUN_FD"
ENV_NC_CONFIG_PATH = "NC_CONFIG_PATH"
)
func main() { func main() {
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile) log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
if _, exists := os.LookupEnv(ENV_NC_TUN_FD); !exists {
// we are the parent process
// 1) process arguments
// 2) validate config
// 2) create a tun adapter
// 3) spawn a child
// 4) exit
o, err := flags.ParseFlags()
if err != nil {
if errors.Is(err, flags.PrintedHelpErr) {
return
}
panic(err)
}
log.Println("loading config...") log.Println("loading config...")
c, err := config.LoadConfig("config.ini") c, err := config.LoadConfig(o.ConfigFile)
if err != nil {
log.Fatalf("error validating config: %s", err.Error())
return
}
log.Println("creating tun adapter...")
t, err := tun.NewTun(o.Positional.InterfaceName, int(c.Host.MTU))
if err != nil { if err != nil {
panic(err) panic(err)
} }
if o.Foreground {
if err := os.Setenv(ENV_NC_TUN_FD, fmt.Sprintf("%d", t.File().Fd())); err != nil {
panic(err)
}
if err := os.Setenv(ENV_NC_CONFIG_PATH, o.ConfigFile); err != nil {
panic(err)
}
log.Println("switching to foreground")
goto FOREGROUND
}
files := make([]*os.File, 4)
files[0], _ = os.Open(os.DevNull) // stdin
files[1], _ = os.Open(os.DevNull) // stderr
files[2], _ = os.Open(os.DevNull) // stdout
files[3] = t.File()
env := os.Environ()
env = append(env, fmt.Sprintf("%s=3", ENV_NC_TUN_FD))
env = append(env, fmt.Sprintf("%s=%s", ENV_NC_CONFIG_PATH, o.ConfigFile))
attr := os.ProcAttr{
Env: env,
Files: files,
}
path, err := os.Executable()
if err != nil {
panic(err)
}
process, err := os.StartProcess(
path,
os.Args,
&attr,
)
if err != nil {
panic(err)
}
pidFile, err := os.Create(o.PidFile)
if err != nil {
panic(err)
}
if _, err := fmt.Fprintf(pidFile, "%d", process.Pid); err != nil {
panic(err)
}
_ = process.Release()
return
}
// we are the child process
// 1) recreate tun adapter from file descriptor
// 2) launch proxy
FOREGROUND:
log.Println("loading config...")
c, err := config.LoadConfig(os.Getenv(ENV_NC_CONFIG_PATH))
if err != nil {
panic(err)
}
log.Println("connecting tun adapter...")
tunFd, err := strconv.ParseUint(os.Getenv(ENV_NC_TUN_FD), 10, 64)
if err != nil {
panic(err)
}
t, err := tun.NewFromFile(uintptr(tunFd), int(c.Host.MTU))
if err != nil {
panic(err)
}
defer func() {
if err := t.Close(); err != nil {
panic(err)
}
}()
log.Println("building config...") log.Println("building config...")
p, err := c.Build() ctx, cancel := context.WithCancel(context.Background())
defer cancel()
p, err := c.Build(ctx, t, t)
if err != nil { if err != nil {
panic(err) panic(err)
} }
log.Println("starting...") log.Println("starting proxy...")
p.Start() p.Start()
log.Println("running") log.Println("proxy started")
signals := make(chan os.Signal) signals := make(chan os.Signal)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
<-signals <-signals
log.Println("exiting...")
} }

View File

@ -1,67 +0,0 @@
package mocks
import "time"
type MockPerfectBiConn struct {
directionA chan byte
directionB chan byte
}
func NewMockPerfectBiConn(bufSize int) MockPerfectBiConn {
return MockPerfectBiConn{
directionA: make(chan byte, bufSize),
directionB: make(chan byte, bufSize),
}
}
func (bc MockPerfectBiConn) SideA() MockPerfectConn {
return MockPerfectConn{inbound: bc.directionA, outbound: bc.directionB}
}
func (bc MockPerfectBiConn) SideB() MockPerfectConn {
return MockPerfectConn{inbound: bc.directionB, outbound: bc.directionA}
}
type MockPerfectConn struct {
inbound chan byte
outbound chan byte
}
func (c MockPerfectConn) SetWriteDeadline(time.Time) error {
return nil
}
func (c MockPerfectConn) Read(p []byte) (n int, err error) {
for i := range p {
if i == 0 {
p[i] = <-c.inbound
} else {
select {
case b := <-c.inbound:
p[i] = b
default:
return i, nil
}
}
}
return len(p), nil
}
func (c MockPerfectConn) Write(p []byte) (n int, err error) {
for _, b := range p {
c.outbound <- b
}
return len(p), nil
}
func (c MockPerfectConn) NonBlockingRead(p []byte) (n int, err error) {
for i := range p {
select {
case b := <-c.inbound:
p[i] = b
default:
return i, nil
}
}
return len(p), nil
}

View File

@ -1,6 +1,8 @@
package mocks package mocks
import "mpbl3p/shared" import (
"mpbl3p/shared"
)
type AlmostUselessMac struct{} type AlmostUselessMac struct{}

62
mocks/packetconn.go Normal file
View File

@ -0,0 +1,62 @@
package mocks
import "net"
type MockPerfectBiPacketConn struct {
directionA chan []byte
directionB chan []byte
}
func NewMockPerfectBiPacketConn(bufSize int) MockPerfectBiPacketConn {
return MockPerfectBiPacketConn{
directionA: make(chan []byte, bufSize),
directionB: make(chan []byte, bufSize),
}
}
func (bc MockPerfectBiPacketConn) SideA() MockPerfectPacketConn {
return MockPerfectPacketConn{inbound: bc.directionA, outbound: bc.directionB}
}
func (bc MockPerfectBiPacketConn) SideB() MockPerfectPacketConn {
return MockPerfectPacketConn{inbound: bc.directionB, outbound: bc.directionA}
}
type MockPerfectPacketConn struct {
inbound chan []byte
outbound chan []byte
}
func (c MockPerfectPacketConn) Write(b []byte) (int, error) {
c.outbound <- b
return len(b), nil
}
func (c MockPerfectPacketConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
c.outbound <- b
return len(b), nil
}
func (c MockPerfectPacketConn) LocalAddr() net.Addr {
return &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
}
}
func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
p := <-c.inbound
return copy(b, p), &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
}, nil
}
func (c MockPerfectPacketConn) NonBlockingRead(p []byte) (n int, err error) {
select {
case b := <-c.inbound:
return copy(p, b), nil
default:
return 0, nil
}
}

95
mocks/streamconn.go Normal file
View File

@ -0,0 +1,95 @@
package mocks
import (
"net"
"time"
)
type MockPerfectBiStreamConn struct {
directionA chan byte
directionB chan byte
}
func NewMockPerfectBiStreamConn(bufSize int) MockPerfectBiStreamConn {
return MockPerfectBiStreamConn{
directionA: make(chan byte, bufSize),
directionB: make(chan byte, bufSize),
}
}
func (bc MockPerfectBiStreamConn) SideA() MockPerfectStreamConn {
return MockPerfectStreamConn{inbound: bc.directionA, outbound: bc.directionB}
}
func (bc MockPerfectBiStreamConn) SideB() MockPerfectStreamConn {
return MockPerfectStreamConn{inbound: bc.directionB, outbound: bc.directionA}
}
type MockPerfectStreamConn struct {
inbound chan byte
outbound chan byte
}
type Conn interface {
Read(b []byte) (n int, err error)
Write(b []byte) (n int, err error)
SetWriteDeadline(time.Time) error
// For printing
LocalAddr() net.Addr
RemoteAddr() net.Addr
}
func (c MockPerfectStreamConn) Read(p []byte) (n int, err error) {
for i := range p {
if i == 0 {
p[i] = <-c.inbound
} else {
select {
case b := <-c.inbound:
p[i] = b
default:
return i, nil
}
}
}
return len(p), nil
}
func (c MockPerfectStreamConn) Write(p []byte) (n int, err error) {
for _, b := range p {
c.outbound <- b
}
return len(p), nil
}
func (c MockPerfectStreamConn) SetWriteDeadline(time.Time) error {
return nil
}
// Only used for printing flow information
func (c MockPerfectStreamConn) LocalAddr() net.Addr {
return &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 499,
}
}
func (c MockPerfectStreamConn) RemoteAddr() net.Addr {
return &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 500,
}
}
func (c MockPerfectStreamConn) NonBlockingRead(p []byte) (n int, err error) {
for i := range p {
select {
case b := <-c.inbound:
p[i] = b
default:
return i, nil
}
}
return len(p), nil
}

View File

@ -5,56 +5,42 @@ import (
"time" "time"
) )
type Packet struct { type Packet interface {
Data []byte Marshal() []byte
timestamp time.Time Contents() []byte
} }
// create a packet from the raw data of an IP packet type SimplePacket []byte
func NewPacket(data []byte) Packet {
return Packet{
Data: data,
timestamp: time.Now(),
}
}
// rebuild a packet from the wrapped format
func UnmarshalPacket(raw []byte, verifier MacVerifier) (Packet, error) {
// the MAC is the last N bytes
data := raw[:len(raw)-verifier.CodeLength()]
sum := raw[len(raw)-verifier.CodeLength():]
if err := verifier.Verify(data, sum); err != nil {
return Packet{}, err
}
p := Packet{
Data: data[:len(data)-8],
}
unixTime := int64(binary.LittleEndian.Uint64(data[len(data)-8:]))
p.timestamp = time.Unix(unixTime, 0)
return p, nil
}
// get the raw data of the IP packet // get the raw data of the IP packet
func (p Packet) Raw() []byte { func (p SimplePacket) Marshal() []byte {
return p.Data return p
} }
// produce the wrapped format of a packet func (p SimplePacket) Contents() []byte {
func (p Packet) Marshal(generator MacGenerator) []byte { return p
// length of data + length of timestamp (8 byte) + length of checksum }
slice := make([]byte, len(p.Data)+8+generator.CodeLength())
func AppendMac(b []byte, g MacGenerator) []byte {
copy(slice, p.Data) footer := make([]byte, 8)
unixTime := uint64(time.Now().Unix())
unixTime := uint64(p.timestamp.Unix()) binary.LittleEndian.PutUint64(footer, unixTime)
binary.LittleEndian.PutUint64(slice[len(p.Data):], unixTime)
b = append(b, footer...)
mac := generator.Generate(slice)
copy(slice[len(p.Data)+8:], mac) mac := g.Generate(b)
return append(b, mac...)
return slice }
func StripMac(b []byte, v MacVerifier) ([]byte, error) {
data := b[:len(b)-v.CodeLength()]
sum := b[len(b)-v.CodeLength():]
if err := v.Verify(data, sum); err != nil {
return nil, err
}
// TODO: Verify timestamp
return data[:len(data)-8], nil
} }

View File

@ -4,31 +4,59 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"mpbl3p/mocks" "mpbl3p/mocks"
"mpbl3p/shared"
"testing" "testing"
) )
func TestPacket_Marshal(t *testing.T) { func TestAppendMac(t *testing.T) {
testContent := []byte("A test string is the content of this packet.") testContent := []byte("A test string is the content of this packet.")
testPacket := NewPacket(testContent)
testMac := mocks.AlmostUselessMac{} testMac := mocks.AlmostUselessMac{}
testPacket := SimplePacket(testContent)
testMarshalled := testPacket.Marshal()
appended := AppendMac(testMarshalled, testMac)
t.Run("Length", func(t *testing.T) { t.Run("Length", func(t *testing.T) {
marshalled := testPacket.Marshal(testMac) assert.Len(t, appended, len(testMarshalled)+8+4)
})
assert.Len(t, marshalled, len(testContent)+8+4) t.Run("Mac", func(t *testing.T) {
assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled)+8:])
})
t.Run("Original", func(t *testing.T) {
assert.Equal(t, testMarshalled, appended[:len(testMarshalled)])
}) })
} }
func TestUnmarshalPacket(t *testing.T) { func TestStripMac(t *testing.T) {
testContent := []byte("A test string is the content of this packet.") testContent := []byte("A test string is the content of this packet.")
testPacket := NewPacket(testContent)
testMac := mocks.AlmostUselessMac{} testMac := mocks.AlmostUselessMac{}
testMarshalled := testPacket.Marshal(testMac) testPacket := SimplePacket(testContent)
testMarshalled := testPacket.Marshal()
appended := AppendMac(testMarshalled, testMac)
t.Run("Length", func(t *testing.T) { t.Run("Length", func(t *testing.T) {
p, err := UnmarshalPacket(testMarshalled, testMac) cut, err := StripMac(appended, testMac)
require.Nil(t, err) require.Nil(t, err)
assert.Len(t, p.Raw(), len(testContent)) assert.Len(t, cut, len(testMarshalled))
})
t.Run("IncorrectMac", func(t *testing.T) {
badMac := make([]byte, len(testMarshalled)+4)
copy(badMac, testMarshalled)
copy(badMac[:len(testMarshalled)], "dcba")
_, err := StripMac(badMac, testMac)
assert.Error(t, err, shared.ErrBadChecksum)
})
t.Run("Original", func(t *testing.T) {
cut, err := StripMac(appended, testMac)
require.Nil(t, err)
assert.Equal(t, testMarshalled, cut)
}) })
} }

View File

@ -1,22 +1,24 @@
package proxy package proxy
import ( import (
"context"
"errors"
"log" "log"
"time" "time"
) )
type Producer interface { type Producer interface {
IsAlive() bool IsAlive() bool
Produce(MacVerifier) (Packet, error) Produce(context.Context, MacVerifier) (Packet, error)
} }
type Consumer interface { type Consumer interface {
IsAlive() bool IsAlive() bool
Consume(Packet, MacGenerator) error Consume(context.Context, Packet, MacGenerator) error
} }
type Reconnectable interface { type Reconnectable interface {
Reconnect() error Reconnect(context.Context) error
} }
type Source interface { type Source interface {
@ -31,8 +33,6 @@ type Proxy struct {
Source Source Source Source
Sink Sink Sink Sink
Generator MacGenerator
proxyChan chan Packet proxyChan chan Packet
sinkChan chan Packet sinkChan chan Packet
} }
@ -67,7 +67,7 @@ func (p Proxy) Start() {
}() }()
} }
func (p Proxy) AddConsumer(c Consumer) { func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) {
go func() { go func() {
_, reconnectable := c.(Reconnectable) _, reconnectable := c.(Reconnectable)
@ -75,28 +75,42 @@ func (p Proxy) AddConsumer(c Consumer) {
if reconnectable { if reconnectable {
var err error var err error
for once := true; err != nil || once; once = false { for once := true; err != nil || once; once = false {
log.Printf("attempting to connect `%v`\n", c) log.Printf("attempting to connect consumer `%v`\n", c)
err = c.(Reconnectable).Reconnect() err = c.(Reconnectable).Reconnect(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed consumer `%v` (context)\n", c)
return
}
if !once { if !once {
time.Sleep(time.Second) time.Sleep(time.Second)
} }
} }
log.Printf("connected `%v`\n", c) log.Printf("connected consumer `%v`\n", c)
} }
for c.IsAlive() { for c.IsAlive() {
if err := c.Consume(<-p.proxyChan, p.Generator); err != nil { select {
case <-ctx.Done():
log.Printf("closed consumer `%v` (context)\n", c)
return
case packet := <-p.proxyChan:
if err := c.Consume(ctx, packet, g); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed consumer `%v` (context)\n", c)
return
}
log.Println(err) log.Println(err)
break break
} }
} }
} }
}
log.Printf("closed connection `%v`\n", c) log.Printf("closed consumer `%v`\n", c)
}() }()
} }
func (p Proxy) AddProducer(pr Producer, v MacVerifier) { func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) {
go func() { go func() {
_, reconnectable := pr.(Reconnectable) _, reconnectable := pr.(Reconnectable)
@ -104,25 +118,42 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
if reconnectable { if reconnectable {
var err error var err error
for once := true; err != nil || once; once = false { for once := true; err != nil || once; once = false {
log.Printf("attempting to connect `%v`\n", pr) log.Printf("attempting to connect producer `%v`\n", pr)
err = pr.(Reconnectable).Reconnect() err = pr.(Reconnectable).Reconnect(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed producer `%v` (context)\n", pr)
return
}
if !once { if !once {
time.Sleep(time.Second) time.Sleep(time.Second)
} }
if ctx.Err() != nil {
return
} }
log.Printf("connected `%v`\n", pr)
}
log.Printf("connected producer `%v`\n", pr)
} }
for pr.IsAlive() { for pr.IsAlive() {
if packet, err := pr.Produce(v); err != nil { if packet, err := pr.Produce(ctx, v); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed producer `%v` (context)\n", pr)
return
}
log.Println(err) log.Println(err)
break break
} else { } else {
p.sinkChan <- packet select {
case <-ctx.Done():
log.Printf("closed producer `%v` (context)\n", pr)
return
case p.sinkChan <- packet:
}
} }
} }
} }
log.Printf("closed connection `%v`\n", pr) log.Printf("closed producer `%v`\n", pr)
}() }()
} }

View File

@ -4,3 +4,4 @@ import "errors"
var ErrBadChecksum = errors.New("the packet had a bad checksum") var ErrBadChecksum = errors.New("the packet had a bad checksum")
var ErrDeadConnection = errors.New("the connection is dead") var ErrDeadConnection = errors.New("the connection is dead")
var ErrNotEnoughBytes = errors.New("not enough bytes")

View File

@ -1,8 +1,10 @@
package tcp package tcp
import ( import (
"bufio"
"context"
"encoding/binary" "encoding/binary"
"errors" "fmt"
"io" "io"
"mpbl3p/proxy" "mpbl3p/proxy"
"mpbl3p/shared" "mpbl3p/shared"
@ -11,16 +13,18 @@ import (
"time" "time"
) )
var ErrNotEnoughBytes = errors.New("not enough bytes")
type Conn interface { type Conn interface {
Read(b []byte) (n int, err error) Read(b []byte) (n int, err error)
Write(b []byte) (n int, err error) Write(b []byte) (n int, err error)
SetWriteDeadline(time.Time) error SetWriteDeadline(time.Time) error
// For printing
LocalAddr() net.Addr
RemoteAddr() net.Addr
} }
type InitiatedFlow struct { type InitiatedFlow struct {
Local string Local func() string
Remote string Remote string
mu sync.RWMutex mu sync.RWMutex
@ -28,21 +32,64 @@ type InitiatedFlow struct {
Flow Flow
} }
func (f *InitiatedFlow) String() string {
return fmt.Sprintf("TcpOutbound{%v -> %v}", f.Local(), f.Remote)
}
type Flow struct { type Flow struct {
conn Conn conn Conn
isAlive bool isAlive bool
toConsume, produced chan []byte
consumeErrors, produceErrors chan error
} }
func InitiateFlow(local, remote string) (*InitiatedFlow, error) { func NewFlow() Flow {
return Flow{
toConsume: make(chan []byte),
produced: make(chan []byte),
consumeErrors: make(chan error),
produceErrors: make(chan error),
}
}
func NewFlowConn(ctx context.Context, conn Conn) Flow {
f := Flow{
conn: conn,
isAlive: true,
toConsume: make(chan []byte),
produced: make(chan []byte),
consumeErrors: make(chan error),
produceErrors: make(chan error),
}
go f.produceMarshalled(ctx)
go f.consumeMarshalled(ctx)
return f
}
func (f Flow) String() string {
return fmt.Sprintf("TcpInbound{%v -> %v}", f.conn.RemoteAddr(), f.conn.LocalAddr())
}
func (f *Flow) IsAlive() bool {
return f.isAlive
}
func InitiateFlow(local func() string, remote string) (*InitiatedFlow, error) {
f := InitiatedFlow{ f := InitiatedFlow{
Local: local, Local: local,
Remote: remote, Remote: remote,
Flow: NewFlow(),
} }
return &f, nil return &f, nil
} }
func (f *InitiatedFlow) Reconnect() error { func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
@ -50,7 +97,7 @@ func (f *InitiatedFlow) Reconnect() error {
return nil return nil
} }
localAddr, err := net.ResolveTCPAddr("tcp", f.Local) localAddr, err := net.ResolveTCPAddr("tcp", f.Local())
if err != nil { if err != nil {
return err return err
} }
@ -65,80 +112,113 @@ func (f *InitiatedFlow) Reconnect() error {
return err return err
} }
err = conn.SetWriteBuffer(0) if err := conn.SetWriteBuffer(0); err != nil {
if err != nil {
return err return err
} }
f.conn = conn f.conn = conn
f.isAlive = true f.isAlive = true
go f.produceMarshalled(ctx)
go f.consumeMarshalled(ctx)
return nil return nil
} }
func (f *Flow) IsAlive() bool { func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
return f.isAlive
}
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
return f.Flow.Consume(p, g) return f.Flow.Consume(ctx, p, g)
} }
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) (err error) { func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(ctx, v)
}
func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
if !f.isAlive { if !f.isAlive {
return shared.ErrDeadConnection return shared.ErrDeadConnection
} }
data := p.Marshal(g) select {
err = f.consumeMarshalled(data) case err := <-f.consumeErrors:
if err != nil {
f.isAlive = false f.isAlive = false
} return err
return default:
} }
func (f *Flow) consumeMarshalled(data []byte) error { marshalled := p.Marshal()
data := proxy.AppendMac(marshalled, g)
prefixedData := make([]byte, len(data)+4) prefixedData := make([]byte, len(data)+4)
binary.LittleEndian.PutUint32(prefixedData, uint32(len(data))) binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
copy(prefixedData[4:], data) copy(prefixedData[4:], data)
select {
case f.toConsume <- prefixedData:
case <-ctx.Done():
return ctx.Err()
}
return nil
}
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
if !f.isAlive {
return nil, shared.ErrDeadConnection
}
var data []byte
select {
case <-ctx.Done():
return nil, ctx.Err()
case data = <-f.produced:
case err := <-f.produceErrors:
f.isAlive = false
return nil, err
}
b, err := proxy.StripMac(data, v)
if err != nil {
return nil, err
}
return proxy.SimplePacket(b), nil
}
func (f *Flow) consumeMarshalled(ctx context.Context) {
for {
data := <-f.toConsume
err := f.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) err := f.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
if err != nil { if err != nil {
return err f.consumeErrors <- err
return
} }
_, err = f.conn.Write(prefixedData) _, err = f.conn.Write(data)
return err
}
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(v)
}
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
if !f.isAlive {
return proxy.Packet{}, shared.ErrDeadConnection
}
data, err := f.produceMarshalled()
if err != nil { if err != nil {
f.isAlive = false f.consumeErrors <- err
return proxy.Packet{}, err return
}
}
} }
return proxy.UnmarshalPacket(data, v) func (f *Flow) produceMarshalled(ctx context.Context) {
} buf := bufio.NewReader(f.conn)
func (f *Flow) produceMarshalled() ([]byte, error) { for {
lengthBytes := make([]byte, 4) lengthBytes := make([]byte, 4)
if n, err := io.LimitReader(f.conn, 4).Read(lengthBytes); err != nil { if n, err := io.LimitReader(buf, 4).Read(lengthBytes); err != nil {
return nil, err f.produceErrors <- err
return
} else if n != 4 { } else if n != 4 {
return nil, ErrNotEnoughBytes f.produceErrors <- shared.ErrNotEnoughBytes
return
} }
length := binary.LittleEndian.Uint32(lengthBytes) length := binary.LittleEndian.Uint32(lengthBytes)
@ -146,12 +226,14 @@ func (f *Flow) produceMarshalled() ([]byte, error) {
var read uint32 var read uint32
for read < length { for read < length {
if n, err := io.LimitReader(f.conn, int64(length-read)).Read(dataBytes[read:]); err != nil { if n, err := io.LimitReader(buf, int64(length-read)).Read(dataBytes[read:]); err != nil {
return nil, err f.produceErrors <- err
return
} else { } else {
read += uint32(n) read += uint32(n)
} }
} }
return dataBytes, nil f.produced <- dataBytes
}
} }

View File

@ -1,8 +1,9 @@
package tcp package tcp
import ( import (
"context"
"encoding/binary" "encoding/binary"
"github.com/go-playground/assert/v2" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"mpbl3p/mocks" "mpbl3p/mocks"
"mpbl3p/proxy" "mpbl3p/proxy"
@ -11,15 +12,15 @@ import (
func TestFlow_Consume(t *testing.T) { func TestFlow_Consume(t *testing.T) {
testContent := []byte("A test string is the content of this packet.") testContent := []byte("A test string is the content of this packet.")
testPacket := proxy.NewPacket(testContent) testPacket := proxy.SimplePacket(testContent)
testMac := mocks.AlmostUselessMac{} testMac := mocks.AlmostUselessMac{}
t.Run("Length", func(t *testing.T) { t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiConn(100) testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := Flow{conn: testConn.SideA(), isAlive: true} flowA := NewFlowConn(context.Background(), testConn.SideA())
err := flowA.Consume(testPacket, testMac) err := flowA.Consume(context.Background(), testPacket, testMac)
require.Nil(t, err) require.Nil(t, err)
buf := make([]byte, 100) buf := make([]byte, 100)
@ -39,28 +40,28 @@ func TestFlow_Produce(t *testing.T) {
testMac := mocks.AlmostUselessMac{} testMac := mocks.AlmostUselessMac{}
t.Run("Length", func(t *testing.T) { t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiConn(100) testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := Flow{conn: testConn.SideA(), isAlive: true} flowA := NewFlowConn(context.Background(), testConn.SideA())
_, err := testConn.SideB().Write(testMarshalled) _, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err) require.Nil(t, err)
p, err := flowA.Produce(testMac) p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, len(testContent), len(p.Raw())) assert.Equal(t, len(testContent), len(p.Contents()))
}) })
t.Run("Value", func(t *testing.T) { t.Run("Value", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiConn(100) testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := Flow{conn: testConn.SideA(), isAlive: true} flowA := NewFlowConn(context.Background(), testConn.SideA())
_, err := testConn.SideB().Write(testMarshalled) _, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err) require.Nil(t, err)
p, err := flowA.Produce(testMac) p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, testContent, string(p.Raw())) assert.Equal(t, testContent, string(p.Contents()))
}) })
} }

View File

@ -1,12 +1,13 @@
package tcp package tcp
import ( import (
"context"
"log" "log"
"mpbl3p/proxy" "mpbl3p/proxy"
"net" "net"
) )
func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error { func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, enableConsumers bool, enableProducers bool) error {
laddr, err := net.ResolveTCPAddr("tcp", local) laddr, err := net.ResolveTCPAddr("tcp", local)
if err != nil { if err != nil {
return err return err
@ -24,17 +25,20 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error {
panic(err) panic(err)
} }
err = conn.SetWriteBuffer(0) if err := conn.SetWriteBuffer(0); err != nil {
if err != nil {
panic(err) panic(err)
} }
f := Flow{conn: conn, isAlive: true} f := NewFlowConn(ctx, conn)
log.Printf("received new connection: %v\n", f) log.Printf("received new tcp connection: %v\n", f)
p.AddConsumer(&f) if enableConsumers {
p.AddProducer(&f, v) p.AddConsumer(ctx, &f, g())
}
if enableProducers {
p.AddProducer(ctx, &f, v())
}
} }
}() }()

View File

@ -1,85 +1,65 @@
package tun package tun
import ( import (
"github.com/pkg/taptun" wgtun "golang.zx2c4.com/wireguard/tun"
"io" "io"
"log" "log"
"mpbl3p/proxy" "mpbl3p/proxy"
"net"
"os" "os"
"strings"
"sync"
"time"
) )
type SourceSink struct { type SourceSink struct {
tun *taptun.Tun tun wgtun.Device
bufferSize int
up bool
upMu sync.Mutex
} }
func NewTun(namingScheme string, bufferSize int) (ss *SourceSink, err error) { func NewTun(name string, mtu int) (t wgtun.Device, err error) {
ss = &SourceSink{} return wgtun.CreateTUN(name, mtu)
}
ss.tun, err = taptun.NewTun(namingScheme) func NewFromFile(fd uintptr, mtu int) (ss *SourceSink, err error) {
ss = new(SourceSink)
file := os.NewFile(fd, "")
ss.tun, err = wgtun.CreateTUNFromFile(file, mtu)
if err != nil { if err != nil {
return return
} }
ss.bufferSize = bufferSize
ss.upMu.Lock()
go func() {
defer ss.upMu.Unlock()
for {
iface, err := net.InterfaceByName(ss.tun.String())
if err != nil {
panic(err)
}
if strings.Contains(iface.Flags.String(), "up") {
log.Println("tun is up")
ss.up = true
return return
} }
time.Sleep(100 * time.Millisecond)
}
}()
return func (t *SourceSink) Close() error {
return t.tun.Close()
} }
func (t *SourceSink) Source() (proxy.Packet, error) { func (t *SourceSink) Source() (proxy.Packet, error) {
if !t.up { mtu, err := t.tun.MTU()
t.upMu.Lock() if err != nil {
t.upMu.Unlock() return nil, err
} }
buf := make([]byte, t.bufferSize) buf := make([]byte, mtu+4)
read, err := t.tun.Read(buf) read, err := t.tun.Read(buf, 4)
if err != nil { if err != nil {
return proxy.Packet{}, err return nil, err
} }
if read == 0 { if read == 0 {
return proxy.Packet{}, io.EOF return nil, io.EOF
} }
return proxy.NewPacket(buf[:read]), nil return proxy.SimplePacket(buf[4 : read+4]), nil
} }
var good, bad float64 var good, bad float64
func (t *SourceSink) Sink(packet proxy.Packet) error { func (t *SourceSink) Sink(packet proxy.Packet) error {
if !t.up { // make space for tun header
t.upMu.Lock() content := make([]byte, len(packet.Contents())+4)
t.upMu.Unlock() copy(content[4:], packet.Contents())
}
_, err := t.tun.Write(packet.Raw()) _, err := t.tun.Write(content, 4)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case *os.PathError: case *os.PathError:

16
udp/congestion.go Normal file
View File

@ -0,0 +1,16 @@
package udp
import (
"context"
"time"
)
type Congestion interface {
Sequence(ctx context.Context) (uint32, error)
NextAck() uint32
NextNack() uint32
ReceivedPacket(seq, nack, ack uint32)
AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error)
}

284
udp/congestion/newreno.go Normal file
View File

@ -0,0 +1,284 @@
package congestion
import (
"context"
"fmt"
"math"
"sort"
"sync"
"sync/atomic"
"time"
)
const RttExponentialFactor = 0.1
const RttLossDelay = 1.5
type NewReno struct {
sequence chan uint32
inFlight []flightInfo
lastSent time.Time
inFlightMu sync.Mutex
awaitingAck sortableFlights
ack, nack uint32
lastAck, lastNack uint32
ackNackMu sync.Mutex
rttNanos float64
windowSize, windowCount uint32
slowStart bool
windowNotifier chan struct{}
}
type flightInfo struct {
time time.Time
sequence uint32
}
type sortableFlights []flightInfo
func (f sortableFlights) Len() int {
return len(f)
}
func (f sortableFlights) Swap(i, j int) {
f[i], f[j] = f[j], f[i]
}
func (f sortableFlights) Less(i, j int) bool {
return f[i].sequence < f[j].sequence
}
func (c *NewReno) String() string {
return fmt.Sprintf("{NewReno %t %d %d %d %d}", c.slowStart, c.windowSize, len(c.inFlight), c.lastAck, c.lastNack)
}
func NewNewReno() *NewReno {
c := NewReno{
sequence: make(chan uint32),
windowNotifier: make(chan struct{}),
windowSize: 1,
rttNanos: float64((10 * time.Millisecond).Nanoseconds()),
slowStart: true,
}
go func() {
var s uint32
for {
if s == 0 {
s++
continue
}
c.sequence <- s
s++
}
}()
return &c
}
func (c *NewReno) ReceivedPacket(seq, nack, ack uint32) {
// decide what acks and nacks to send
if seq != 0 {
c.receivedSequence(seq)
}
// decide how window size was affected
if nack != 0 {
c.receivedNack(nack)
}
if ack != 0 {
c.receivedAck(ack)
}
}
func (c *NewReno) Sequence(ctx context.Context) (uint32, error) {
for len(c.inFlight) >= int(c.windowSize) {
<-c.windowNotifier
}
c.inFlightMu.Lock()
defer c.inFlightMu.Unlock()
var s uint32
select {
case s = <-c.sequence:
case <-ctx.Done():
return 0, ctx.Err()
}
t := time.Now()
c.inFlight = append(c.inFlight, flightInfo{
time: t,
sequence: s,
})
c.lastSent = t
return s, nil
}
func (c *NewReno) NextAck() uint32 {
a := c.ack
atomic.StoreUint32(&c.lastAck, a)
return a
}
func (c *NewReno) NextNack() uint32 {
n := c.nack
atomic.StoreUint32(&c.lastNack, n)
return n
}
func (c *NewReno) AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error) {
for {
rtt := time.Duration(math.Round(c.rttNanos))
time.Sleep(rtt / 2)
c.checkNack()
// CASE 1: waiting ACKs or NACKs and no message sent in the last half-RTT
// this targets arrival in 0.5+0.5 ± 0.5 RTTs (1±0.5 RTTs)
if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt/2)) {
return 0, nil // no ack needed
}
// CASE 2: No message sent within the keepalive time
if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) {
return c.Sequence(ctx) // require an ack
}
}
}
func (c *NewReno) receivedSequence(seq uint32) {
c.ackNackMu.Lock()
defer c.ackNackMu.Unlock()
if seq < c.nack || seq < c.ack {
// packet received out of order has already been cumulatively NACKed
// or duplicate packet received and already ACKed
return
}
if seq != c.ack+1 && seq != c.nack+1 {
c.awaitingAck = append(c.awaitingAck, flightInfo{
time: time.Now(),
sequence: seq,
})
return // if this seq doesn't change the ack field, updateAck will still not do anything useful
}
sort.Sort(c.awaitingAck)
c.updateAck(seq)
}
func (c *NewReno) checkNack() {
c.ackNackMu.Lock()
defer c.ackNackMu.Unlock()
if len(c.awaitingAck) == 0 {
return
}
sort.Sort(c.awaitingAck)
lossThreshold := time.Duration(c.rttNanos * RttLossDelay)
if c.awaitingAck[0].time.Before(time.Now().Add(-lossThreshold)) {
// if the next packet sequence to ack was received more than an rttlossdelay ago
// mark the packet(s) blocking it as missing with a nack
// then update ack from the delayed packet
c.nack = c.awaitingAck[0].sequence - 1
c.updateAck(c.nack)
}
}
func (c *NewReno) updateAck(start uint32) {
a := start
var removed int
for _, e := range c.awaitingAck {
if e.sequence == a+1 {
a = e.sequence
removed++
} else {
break
}
}
c.ack = a
c.awaitingAck = c.awaitingAck[removed:]
}
func (c *NewReno) receivedNack(nack uint32) {
c.ackNackMu.Lock()
defer c.ackNackMu.Unlock()
c.inFlightMu.Lock()
defer c.inFlightMu.Unlock()
// as both ack and nack are cumulative, inflight will always be ordered by seq
var i int
for i < len(c.inFlight) && c.inFlight[i].sequence <= nack {
i++
}
if i == 0 {
return
}
c.slowStart = false
c.inFlight = c.inFlight[i:]
for {
s := c.windowSize
if s == 1 || atomic.CompareAndSwapUint32(&c.windowSize, s, s/2) {
break
}
}
select {
case c.windowNotifier <- struct{}{}:
default:
}
}
func (c *NewReno) receivedAck(ack uint32) {
c.ackNackMu.Lock()
defer c.ackNackMu.Unlock()
c.inFlightMu.Lock()
defer c.inFlightMu.Unlock()
// as both ack and nack are cumulative, inflight will always be ordered by seq
var i int
for i < len(c.inFlight) && c.inFlight[i].sequence <= ack {
rtt := float64(time.Now().Sub(c.inFlight[i].time).Nanoseconds())
c.rttNanos = c.rttNanos*(1-RttExponentialFactor) + rtt*RttExponentialFactor
i++
}
if i == 0 {
return
}
c.inFlight = c.inFlight[i:]
if c.slowStart {
atomic.AddUint32(&c.windowSize, uint32(i))
} else {
c.windowCount += uint32(i)
s := c.windowSize
if c.windowCount > s {
c.windowCount -= s
atomic.AddUint32(&c.windowSize, 1)
}
}
select {
case c.windowNotifier <- struct{}{}:
default:
}
}

View File

@ -0,0 +1,369 @@
package congestion
import (
"context"
"github.com/stretchr/testify/assert"
"sort"
"testing"
"time"
)
type congestionPacket struct {
seq, nack, ack uint32
}
type newRenoTest struct {
sideA, sideB *NewReno
aOutbound, bOutbound chan congestionPacket
aInbound, bInbound chan congestionPacket
halfRtt time.Duration
}
func newNewRenoTest(rtt time.Duration) *newRenoTest {
return &newRenoTest{
sideA: NewNewReno(),
sideB: NewNewReno(),
aOutbound: make(chan congestionPacket),
bOutbound: make(chan congestionPacket),
aInbound: make(chan congestionPacket),
bInbound: make(chan congestionPacket),
halfRtt: rtt / 2,
}
}
func (n *newRenoTest) Start(ctx context.Context) {
type packetWithTime struct {
t time.Time
p congestionPacket
}
go func() {
aOutboundDelayed := make(chan packetWithTime, 128)
bOutboundDelayed := make(chan packetWithTime, 128)
delayer := func(tp chan packetWithTime, cp chan congestionPacket) {
for {
select {
case p := <-tp:
s := p.t.Add(n.halfRtt).Sub(time.Now())
time.Sleep(s)
cp <- p.p
case <-ctx.Done():
return
}
}
}
go delayer(aOutboundDelayed, n.bInbound)
go delayer(bOutboundDelayed, n.aInbound)
for {
select {
case <-ctx.Done():
return
case p := <-n.aOutbound:
aOutboundDelayed <- packetWithTime{t: time.Now(), p: p}
case p := <-n.bOutbound:
bOutboundDelayed <- packetWithTime{t: time.Now(), p: p}
}
}
}()
}
func (n *newRenoTest) RunSideA(ctx context.Context) {
go func() {
for {
select {
case <-ctx.Done():
return
case p := <-n.aInbound:
n.sideA.ReceivedPacket(p.seq, p.nack, p.ack)
}
}
}()
go func() {
for {
seq, err := n.sideA.AwaitEarlyUpdate(ctx, 500*time.Millisecond)
if err != nil {
return
}
if seq != 0 {
// skip keepalive
// required to ensure AwaitEarlyUpdate terminates
continue
}
p := congestionPacket{
seq: seq,
nack: n.sideA.NextNack(),
ack: n.sideA.NextAck(),
}
n.aOutbound <- p
}
}()
}
func (n *newRenoTest) RunSideB(ctx context.Context) {
go func() {
for {
select {
case <-ctx.Done():
return
case p := <-n.bInbound:
n.sideB.ReceivedPacket(p.seq, p.nack, p.ack)
}
}
}()
go func() {
for {
seq, err := n.sideB.AwaitEarlyUpdate(ctx, 500*time.Millisecond)
if err != nil {
return
}
if seq != 0 {
// skip keepalive
// required to ensure AwaitEarlyUpdate terminates
continue
}
p := congestionPacket{
seq: seq,
nack: n.sideB.NextNack(),
ack: n.sideB.NextAck(),
}
n.bOutbound <- p
}
}()
}
func TestNewReno_Congestion(t *testing.T) {
t.Run("OneWay", func(t *testing.T) {
t.Run("Lossless", func(t *testing.T) {
// ASSIGN
rtt := 80 * time.Millisecond
numPackets := 50
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := newNewRenoTest(rtt)
c.Start(ctx)
c.RunSideA(ctx)
c.RunSideB(ctx)
// ACT
for i := 0; i < numPackets; i++ {
// sleep to simulate preparing packet
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideA.Sequence(ctx)
c.aOutbound <- congestionPacket{
seq: seq,
nack: c.sideA.NextNack(),
ack: c.sideA.NextAck(),
}
}
// allow the systems to catch up before asserting
time.Sleep(rtt + 30*time.Millisecond)
// ASSERT
assert.Equal(t, uint32(0), c.sideA.nack)
assert.Equal(t, uint32(0), c.sideA.ack)
assert.Equal(t, uint32(0), c.sideB.nack)
assert.Equal(t, uint32(numPackets), c.sideB.ack)
})
t.Run("SequenceLoss", func(t *testing.T) {
// ASSIGN
rtt := 80 * time.Millisecond
numPackets := 50
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := newNewRenoTest(rtt)
c.Start(ctx)
c.RunSideA(ctx)
c.RunSideB(ctx)
// ACT
for i := 0; i < numPackets; i++ {
// sleep to simulate preparing packet
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideA.Sequence(ctx)
if seq == 20 {
// Simulate packet loss of sequence 20
continue
}
c.aOutbound <- congestionPacket{
seq: seq,
nack: c.sideA.NextNack(),
ack: c.sideA.NextAck(),
}
}
time.Sleep(rtt + 30*time.Millisecond)
// ASSERT
assert.Equal(t, uint32(0), c.sideA.nack)
assert.Equal(t, uint32(0), c.sideA.ack)
assert.Equal(t, uint32(20), c.sideB.nack)
assert.Equal(t, uint32(numPackets), c.sideB.ack)
})
})
t.Run("TwoWay", func(t *testing.T) {
t.Run("Lossless", func(t *testing.T) {
// ASSIGN
rtt := 80 * time.Millisecond
numPackets := 50
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := newNewRenoTest(rtt)
c.Start(ctx)
c.RunSideA(ctx)
c.RunSideB(ctx)
// ACT
done := make(chan struct{})
go func() {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideA.Sequence(ctx)
c.aOutbound <- congestionPacket{
seq: seq,
nack: c.sideA.NextNack(),
ack: c.sideA.NextAck(),
}
}
done <- struct{}{}
}()
go func() {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideB.Sequence(ctx)
c.bOutbound <- congestionPacket{
seq: seq,
nack: c.sideB.NextNack(),
ack: c.sideB.NextAck(),
}
}
done <- struct{}{}
}()
<-done
<-done
time.Sleep(rtt + 30*time.Millisecond)
// ASSERT
assert.Equal(t, uint32(0), c.sideA.nack)
assert.Equal(t, uint32(numPackets), c.sideA.ack)
assert.Equal(t, uint32(0), c.sideB.nack)
assert.Equal(t, uint32(numPackets), c.sideB.ack)
})
t.Run("SequenceLoss", func(t *testing.T) {
// ASSIGN
rtt := 80 * time.Millisecond
numPackets := 100
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := newNewRenoTest(rtt)
c.Start(ctx)
c.RunSideA(ctx)
c.RunSideB(ctx)
// ACT
done := make(chan struct{})
go func() {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideA.Sequence(ctx)
if seq == 9 {
// Simulate packet loss of sequence 9
continue
}
c.aOutbound <- congestionPacket{
seq: seq,
nack: c.sideA.NextNack(),
ack: c.sideA.NextAck(),
}
}
done <- struct{}{}
}()
go func() {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideB.Sequence(ctx)
if seq == 13 {
// Simulate packet loss of sequence 13
continue
}
c.bOutbound <- congestionPacket{
seq: seq,
nack: c.sideB.NextNack(),
ack: c.sideB.NextAck(),
}
}
done <- struct{}{}
}()
<-done
<-done
time.Sleep(rtt + 30*time.Millisecond)
// ASSERT
assert.Equal(t, uint32(13), c.sideA.nack)
assert.Equal(t, uint32(numPackets), c.sideA.ack)
assert.Equal(t, uint32(9), c.sideB.nack)
assert.Equal(t, uint32(numPackets), c.sideB.ack)
})
})
}
func TestSortableFlights_Less(t *testing.T) {
// ASSIGN
a := []flightInfo{{sequence: 0}, {sequence: 6}, {sequence: 3}, {sequence: 2}}
// ACT
sort.Sort(sortableFlights(a))
// ASSERT
assert.Equal(t, []flightInfo{{sequence: 0}, {sequence: 2}, {sequence: 3}, {sequence: 6}}, a)
}

26
udp/congestion/none.go Normal file
View File

@ -0,0 +1,26 @@
package congestion
import (
"context"
"fmt"
"time"
)
type None struct{}
func NewNone() None {
return None{}
}
func (c None) String() string {
return fmt.Sprintf("{None}")
}
func (c None) ReceivedPacket(uint32, uint32, uint32) {}
func (c None) NextNack() uint32 { return 0 }
func (c None) NextAck() uint32 { return 0 }
func (c None) AwaitEarlyUpdate(ctx context.Context, _ time.Duration) (uint32, error) {
<-ctx.Done()
return 0, ctx.Err()
}
func (c None) Sequence(context.Context) (uint32, error) { return 0, nil }

306
udp/flow.go Normal file
View File

@ -0,0 +1,306 @@
package udp
import (
"context"
"fmt"
"log"
"mpbl3p/proxy"
"mpbl3p/shared"
"net"
"sync"
"time"
)
type PacketWriter interface {
Write(b []byte) (int, error)
WriteToUDP(b []byte, addr *net.UDPAddr) (int, error)
LocalAddr() net.Addr
}
type PacketConn interface {
PacketWriter
ReadFromUDP(b []byte) (int, *net.UDPAddr, error)
}
type InitiatedFlow struct {
Local func() string
Remote string
g proxy.MacGenerator
keepalive time.Duration
mu sync.RWMutex
Flow
}
func (f *InitiatedFlow) String() string {
return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote)
}
type Flow struct {
writer PacketWriter
raddr *net.UDPAddr
isAlive bool
startup bool
congestion Congestion
v proxy.MacVerifier
inboundDatagrams chan []byte
}
func (f Flow) String() string {
return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr())
}
func InitiateFlow(
local func() string,
remote string,
v proxy.MacVerifier,
g proxy.MacGenerator,
c Congestion,
keepalive time.Duration,
) (*InitiatedFlow, error) {
f := InitiatedFlow{
Local: local,
Remote: remote,
Flow: newFlow(c, v),
g: g,
keepalive: keepalive,
}
return &f, nil
}
func newFlow(c Congestion, v proxy.MacVerifier) Flow {
return Flow{
inboundDatagrams: make(chan []byte),
congestion: c,
v: v,
}
}
func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
f.mu.Lock()
defer f.mu.Unlock()
if f.isAlive {
return nil
}
localAddr, err := net.ResolveUDPAddr("udp", f.Local())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote)
if err != nil {
return err
}
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
if err != nil {
return err
}
f.writer = conn
f.startup = true
// prod the connection once a second until we get an ack, then consider it alive
go func() {
seq, err := f.congestion.Sequence(ctx)
if err != nil {
return
}
for !f.isAlive {
if ctx.Err() != nil {
return
}
p := Packet{
ack: 0,
nack: 0,
seq: seq,
data: proxy.SimplePacket(nil),
}
_ = f.sendPacket(p, f.g)
time.Sleep(1 * time.Second)
}
}()
go func() {
_, _ = f.produceInternal(ctx, f.v, false)
}()
go f.earlyUpdateLoop(ctx, f.g, f.keepalive)
if err := f.readQueuePacket(ctx, conn); err != nil {
return err
}
f.isAlive = true
f.startup = false
go func() {
lockedAccept := func() {
f.mu.RLock()
defer f.mu.RUnlock()
if err := f.readQueuePacket(ctx, conn); err != nil {
log.Println(err)
}
}
for f.isAlive {
log.Println("alive and listening for packets")
lockedAccept()
}
log.Println("no longer alive")
}()
return nil
}
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Consume(ctx, p, g)
}
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(ctx, v)
}
func (f *Flow) IsAlive() bool {
return f.isAlive
}
func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error {
if !f.isAlive {
return shared.ErrDeadConnection
}
log.Println(f.congestion)
// Sequence is the congestion controllers opportunity to block
log.Println("awaiting sequence")
seq, err := f.congestion.Sequence(ctx)
if err != nil {
return err
}
log.Println("received sequence")
// Choose up to date ACK/NACK even after blocking
p := Packet{
seq: seq,
data: pp,
ack: f.congestion.NextAck(),
nack: f.congestion.NextNack(),
}
return f.sendPacket(p, g)
}
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
if !f.isAlive {
return nil, shared.ErrDeadConnection
}
return f.produceInternal(ctx, v, true)
}
func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
for once := true; mustReturn || once; once = false {
log.Println(f.congestion)
var received []byte
select {
case received = <-f.inboundDatagrams:
case <-ctx.Done():
return nil, ctx.Err()
}
b, err := proxy.StripMac(received, v)
if err != nil {
return nil, err
}
p, err := UnmarshalPacket(b)
if err != nil {
return nil, err
}
// adjust congestion control based on this packet's congestion header
f.congestion.ReceivedPacket(p.seq, p.nack, p.ack)
// 12 bytes for header + the MAC + a timestamp
if len(b) == 12+f.v.CodeLength()+8 {
log.Println("handled keepalive/ack only packet")
continue
}
return p, nil
}
return nil, nil
}
func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
select {
case f.inboundDatagrams <- p:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
b := p.Marshal()
b = proxy.AppendMac(b, g)
if f.raddr == nil {
_, err := f.writer.Write(b)
return err
} else {
_, err := f.writer.WriteToUDP(b, f.raddr)
return err
}
}
func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) {
for f.isAlive {
seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive)
if err != nil {
fmt.Printf("terminating earlyupdateloop for `%v`\n", f)
return
}
p := Packet{
seq: seq,
data: proxy.SimplePacket(nil),
ack: f.congestion.NextAck(),
nack: f.congestion.NextNack(),
}
err = f.sendPacket(p, g)
if err != nil {
fmt.Printf("error sending early update packet: `%v`\n", err)
}
}
}
func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error {
// TODO: Replace 6000 with MTU+header size
buf := make([]byte, 6000)
n, _, err := c.ReadFromUDP(buf)
if err != nil {
return err
}
return f.queueDatagram(ctx, buf[:n])
}

86
udp/flow_test.go Normal file
View File

@ -0,0 +1,86 @@
package udp
import (
"context"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"mpbl3p/mocks"
"mpbl3p/proxy"
"mpbl3p/udp/congestion"
"testing"
"time"
)
func TestFlow_Consume(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := proxy.SimplePacket(testContent)
testMac := mocks.AlmostUselessMac{}
t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiPacketConn(10)
flowA := newFlow(congestion.NewNone(), testMac)
flowA.writer = testConn.SideB()
flowA.isAlive = true
err := flowA.Consume(context.Background(), testPacket, testMac)
require.Nil(t, err)
buf := make([]byte, 100)
n, _, err := testConn.SideA().ReadFromUDP(buf)
require.Nil(t, err)
// 12 header, 8 timestamp, 4 MAC
assert.Equal(t, len(testContent)+12+8+4, n)
})
}
func TestFlow_Produce(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := Packet{
ack: 42,
nack: 26,
seq: 128,
data: proxy.SimplePacket(testContent),
}
testMac := mocks.AlmostUselessMac{}
testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac)
t.Run("Length", func(t *testing.T) {
done := make(chan struct{})
go func() {
testConn := mocks.NewMockPerfectBiPacketConn(10)
_, err := testConn.SideA().Write(testMarshalled)
require.Nil(t, err)
flowA := newFlow(congestion.NewNone(), testMac)
flowA.writer = testConn.SideB()
flowA.isAlive = true
go func() {
err := flowA.readQueuePacket(context.Background(), testConn.SideB())
assert.Nil(t, err)
}()
p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err)
assert.Len(t, p.Contents(), len(testContent))
done <- struct{}{}
}()
timer := time.NewTimer(500 * time.Millisecond)
select {
case <-done:
case <-timer.C:
fmt.Println("timed out")
t.FailNow()
}
})
}

104
udp/listener.go Normal file
View File

@ -0,0 +1,104 @@
package udp
import (
"context"
"log"
"mpbl3p/proxy"
"net"
"time"
)
type ComparableUdpAddress struct {
IP [16]byte
Port int
Zone string
}
func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
var ip [16]byte
for i, b := range []byte(address.IP) {
ip[i] = b
}
return ComparableUdpAddress{
IP: ip,
Port: address.Port,
Zone: address.Zone,
}
}
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion, enableConsumers bool, enableProducers bool) error {
laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return err
}
pconn, err := net.ListenUDP("udp", laddr)
if err != nil {
return err
}
err = pconn.SetWriteBuffer(0)
if err != nil {
panic(err)
}
receivedConnections := make(map[ComparableUdpAddress]*Flow)
go func() {
for ctx.Err() == nil {
buf := make([]byte, 6000)
if err := pconn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
panic(err)
}
n, addr, err := pconn.ReadFromUDP(buf)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
panic(err)
}
raddr := fromUdpAddress(*addr)
if f, exists := receivedConnections[raddr]; exists {
log.Println("existing flow. queuing...")
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
}
log.Println("queued")
continue
}
v := v()
g := g()
f := newFlow(c(), v)
f.writer = pconn
f.raddr = addr
f.isAlive = true
log.Printf("received new udp connection: %v\n", f)
go f.earlyUpdateLoop(ctx, g, 0)
receivedConnections[raddr] = &f
if enableConsumers {
p.AddConsumer(ctx, &f, g)
}
if enableProducers {
p.AddProducer(ctx, &f, v)
}
log.Println("handling...")
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
return
}
log.Println("handled")
}
}()
return nil
}

43
udp/packet.go Normal file
View File

@ -0,0 +1,43 @@
package udp
import (
"encoding/binary"
"mpbl3p/proxy"
"mpbl3p/shared"
)
type Packet struct {
ack uint32
nack uint32
seq uint32
data proxy.Packet
}
func UnmarshalPacket(b []byte) (p Packet, err error) {
if len(b) < 12 {
return Packet{}, shared.ErrNotEnoughBytes
}
p.ack = binary.LittleEndian.Uint32(b[0:4])
p.nack = binary.LittleEndian.Uint32(b[4:8])
p.seq = binary.LittleEndian.Uint32(b[8:12])
p.data = proxy.SimplePacket(b[12:])
return p, nil
}
func (p Packet) Marshal() []byte {
data := p.data.Marshal()
header := make([]byte, 12)
binary.LittleEndian.PutUint32(header[0:4], p.ack)
binary.LittleEndian.PutUint32(header[4:8], p.nack)
binary.LittleEndian.PutUint32(header[8:12], p.seq)
return append(header, data...)
}
func (p Packet) Contents() []byte {
return p.data.Contents()
}

59
udp/packet_test.go Normal file
View File

@ -0,0 +1,59 @@
package udp
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"mpbl3p/proxy"
"testing"
)
func TestPacket_Marshal(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := Packet{
ack: 18,
nack: 29,
seq: 431,
data: proxy.SimplePacket(testContent),
}
t.Run("Length", func(t *testing.T) {
marshalled := testPacket.Marshal()
// 12 header + 8 timestamp
assert.Len(t, marshalled, len(testContent)+12)
})
}
func TestUnmarshalPacket(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := Packet{
ack: 18,
nack: 29,
seq: 431,
data: proxy.SimplePacket(testContent),
}
testMarshalled := testPacket.Marshal()
t.Run("Length", func(t *testing.T) {
p, err := UnmarshalPacket(testMarshalled)
require.Nil(t, err)
assert.Len(t, p.Contents(), len(testContent))
})
t.Run("Contents", func(t *testing.T) {
p, err := UnmarshalPacket(testMarshalled)
require.Nil(t, err)
assert.Equal(t, p.Contents(), testContent)
})
t.Run("Header", func(t *testing.T) {
p, err := UnmarshalPacket(testMarshalled)
require.Nil(t, err)
assert.Equal(t, p.ack, uint32(18))
assert.Equal(t, p.nack, uint32(29))
assert.Equal(t, p.seq, uint32(431))
})
}

View File

@ -0,0 +1,36 @@
mpbl3p_udp = Proto("mpbl3p_udp", "Multi Path Proxy Custom UDP")
ack_F = ProtoField.uint32("mpbl3p_udp.ack", "Acknowledgement")
nack_F = ProtoField.uint32("mpbl3p_udp.nack", "Negative Acknowledgement")
seq_F = ProtoField.uint32("mpbl3p_udp.seq", "Sequence Number")
time_F = ProtoField.absolute_time("mpbl3p_udp.time", "Timestamp")
proxied_F = ProtoField.bytes("mpbl3p_udp.data", "Proxied Data")
mpbl3p_udp.fields = { ack_F, nack_F, seq_F, time_F, proxied_F }
function mpbl3p_udp.dissector(buffer, pinfo, tree)
if buffer:len() < 20 then
return
end
pinfo.cols.protocol = "MPBL3P_UDP"
local ack = buffer(0, 4):le_uint()
local nack = buffer(4, 4):le_uint()
local seq = buffer(8, 4):le_uint()
local unix_time = buffer(buffer:len() - 8, 8):le_uint64()
local subtree = tree:add(mpbl3p_udp, buffer(), "Multi Path Proxy Header, SEQ: " .. seq .. " ACK: " .. ack .. " NACK: " .. nack)
subtree:add(ack_F, ack)
subtree:add(nack_F, nack)
subtree:add(seq_F, seq)
subtree:add(time_F, NSTime.new(unix_time:tonumber()))
if buffer:len() > 20 then
subtree:add(proxied_F, buffer(12, buffer:len() - 12 - 8))
end
end
DissectorTable.get("udp.port"):add(4724, mpbl3p_udp)
DissectorTable.get("udp.port"):add(4725, mpbl3p_udp)

View File

@ -1,13 +0,0 @@
package utils
var NextId = make(chan int)
func init() {
go func() {
i := 0
for {
NextId <- i
i += 1
}
}()
}