diff --git a/.drone.yml b/.drone.yml index c5d91e0..d5701b8 100644 --- a/.drone.yml +++ b/.drone.yml @@ -1,12 +1,18 @@ +--- kind: pipeline type: docker name: default steps: + - name: format + image: golang:1.16 + commands: + - bash -c "gofmt -l . | wc -l | cmp -s <(echo 0) || (gofmt -l . && exit 1)" + - name: install - image: golang:1.15 + image: golang:1.16 environment: - GOPROXY: http://10.20.0.25:3142|direct + GOPROXY: http://containers.internal.hillion.co.uk:3142,direct volumes: - name: cache path: /go @@ -14,7 +20,7 @@ steps: - go test -i ./... - name: test - image: golang:1.15 + image: golang:1.16 volumes: - name: cache path: /go @@ -22,7 +28,7 @@ steps: - go test ./... - name: build (debian) - image: golang:1.15-buster + image: golang:1.16-buster when: event: - push @@ -30,7 +36,10 @@ steps: - name: cache path: /go commands: - - go build + - GOOS=linux GOARCH=amd64 go build -o linux_amd64 + - GOOS=linux GOARCH=arm GOARM=7 go build -o linux_arm_v7 + - GOOS=freebsd GOARCH=amd64 go build -o freebsd_amd64 + - GOOS=freebsd GOARCH=arm64 go build -o freebsd_arm64_v8a - name: upload image: minio/mc @@ -42,9 +51,18 @@ steps: SECRET_KEY: from_secret: s3_secret_key commands: - - mc alias set s3 http://10.20.0.25:3900 $${ACCESS_KEY} $${SECRET_KEY} - - mc cp mpbl3p s3/dissertation/binaries/debian/${DRONE_BRANCH} + - mc alias set s3 https://s3.us-west-001.backblazeb2.com $${ACCESS_KEY} $${SECRET_KEY} + - mc cp linux_amd64 s3/dissertation/binaries/debian/${DRONE_BRANCH}_linux_amd64 + - mc cp linux_arm_v7 s3/dissertation/binaries/debian/${DRONE_BRANCH}_linux_arm_v7 + - mc cp freebsd_amd64 s3/dissertation/binaries/debian/${DRONE_BRANCH}_freebsd_amd64 + - mc cp freebsd_arm64_v8a s3/dissertation/binaries/debian/${DRONE_BRANCH}_freebsd_arm64_v8a volumes: - name: cache - temp: {} + temp: { } + +--- +kind: signature +hmac: 7960420c7d02f9bce56d6429b612676d24cbe1d1608cf44a77da9afc411eccb8 + +... diff --git a/.gitignore b/.gitignore index 7d235d7..42e21ef 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -config.ini +*.conf logs/ # Created by https://www.toptal.com/developers/gitignore/api/intellij+all,go diff --git a/README.md b/README.md index a4292f6..7a67939 100644 --- a/README.md +++ b/README.md @@ -4,15 +4,130 @@ ### Linux #### Policy Based Routing - ip route flush 11 - ip route add table 11 to 1.1.1.0/24 dev eth1 - ip rule add from 1.1.1.4 table 11 priority 11 - - ip route flush 10 - ip route add table 10 to 1.1.1.0/24 dev eth2 - ip rule add from 1.1.1.5 table 10 priority 10 + ip route flush table 10 + ip route add table 10 to 1.1.1.0/24 dev eth1 + ip rule add from 1.1.1.4 table 10 priority 10 + + ip route flush table 11 + ip route add table 11 to 1.1.1.0/24 dev eth2 + ip rule add from 1.1.1.5 table 11 priority 11 #### ARP Flux sysctl -w net.ipv4.conf.all.arp_announce=1 - sysctl -w net.ipv4.conf.all.arp_ignore=2 + sysctl -w net.ipv4.conf.all.arp_ignore=1 + +See http://kb.linuxvirtualserver.org/wiki/Using_arp_announce/arp_ignore_to_disable_ARP + +### Setup Scripts +These are functional setup scripts that make the application run as intended on Linux. + +### Remote Portal +#### Pre-Start + + #!/bin/bash + set -e + + ## Set up variables + REMOTE_PORTAL_ADDRESS=A.B.C.D + + ## IPv4 Forwarding + sysctl -w net.ipv4.ip_forward=1 + sysctl -w net.ipv4.conf.eth0.proxy_arp=1 + + ## Transfer the local routing table to a much lower priority + (ip rule show | grep '20:') > /dev/null || ip rule add from all table local priority 20 + ip rule del priority 0 2> /dev/null || true + + ## Ports to route locally + + ### MPBL3P + ip rule del priority 1 2> /dev/null || true + ip rule add to "$REMOTE_PORTAL_ADDRESS" dport 1234 table local priority 1 + + ### SSH + ip rule del priority 2 2> /dev/null || true + ip rule add to "$REMOTE_PORTAL_ADDRESS" dport 22 table local priority 2 + +#### Post-Start + + #!/bin/bash + set -e + + ## Set up variables + REMOTE_PORTAL_ADDRESS=A.B.C.D + + ## Tunnel addr/up + ip addr add 172.19.152.2/31 dev nc0 + ip link set up nc0 + + # Route packets to the interface but not for nc via the tunnel + ip route flush table 19 + ip route add table 19 to "$REMOTE_PORTAL_ADDRESS" via 172.19.152.3 dev nc0 + ip rule del priority 19 2> /dev/null || true + ip rule add to "$REMOTE_PORTAL_ADDRESS" table 19 priority 19 + +### Local Portal +#### Pre-Start + + #!/bin/bash + set -e + + ## Set up variables + GATEWAY_INTERFACE=eth0 + GATEWAY_ADDRESS=10.36.12.1 + + ## Fix ARP + sysctl -w net.ipv4.conf.all.arp_announce=1 + sysctl -w net.ipv4.conf.all.arp_ignore=1 + + ## IPv4 Forwarding + sysctl -w net.ipv4.ip_forward=1 + + ## Gateway Interface Setup + ip addr flush dev "$GATEWAY_INTERFACE" + ip addr add "$GATEWAY_ADDRESS"/32 dev "$GATEWAY_INTERFACE" + ip link set up "$GATEWAY_INTERFACE" + + ## Per-Interface Routing Tables + + ### 10.10.0.0/24 + ip route flush table 10 + ip route add table 10 default via 10.10.0.1 + ip rule del priority 10 2> /dev/null || true + ip rule add from 10.10.0.0/24 table 10 priority 10 + + ### 192.168.0.0/24 + ip route flush table 11 + ip route add table 11 default via 192.168.0.1 + ip rule del priority 11 2> /dev/null || true + ip rule add from 192.168.0.0/24 table 11 priority 11 + +#### Post-Start + + #!/bin/bash + set -e + + ## Set up variables + REMOTE_PORTAL_ADDRESS=A.B.C.D + GATEWAY_INTERFACE=eth0 + + ## Tunnel Address and Enable + ip addr add 172.19.152.3/31 dev nc0 + ip link set up nc0 + + ## Route Outbound Packets Correctly + ip route flush table 20 + ip route add table 20 default via 172.19.152.2 dev nc0 + ip rule del priority 20 2> /dev/null || true + ip rule add from "$REMOTE_PORTAL_ADDRESS" iif "$GATEWAY_INTERFACE" table 20 priority 20 + + ## Route Inbound Packets Correctly + ip route flush table 21 + ip route add table 21 to "$REMOTE_PORTAL_ADDRESS" dev "$GATEWAY_INTERFACE" + ip rule del priority 21 2> /dev/null || true + ip rule add to "$REMOTE_PORTAL_ADDRESS" table 21 priority 21 + +#### Client + +Connect to `GATEWAY_INTERFACE` and set the IP to `REMOTE_PORTAL_ADDRESS`/32 with a gateway of `GATEWAY_ADDRESS`. diff --git a/config/builder.go b/config/builder.go index 6f94d8b..c5bea4f 100644 --- a/config/builder.go +++ b/config/builder.go @@ -1,48 +1,57 @@ package config import ( + "context" + "encoding/base64" "fmt" + "mpbl3p/crypto" + "mpbl3p/crypto/sharedkey" "mpbl3p/proxy" "mpbl3p/tcp" - "mpbl3p/tun" + "mpbl3p/udp" + "mpbl3p/udp/congestion" + "time" ) -// TODO: Delete this code as soon as an alternative is available -type UselessMac struct{} - -func (UselessMac) CodeLength() int { - return 0 -} - -func (UselessMac) Generate([]byte) []byte { - return nil -} - -func (u UselessMac) Verify([]byte, []byte) error { - return nil -} - -func (c Configuration) Build() (*proxy.Proxy, error) { +func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) { p := proxy.NewProxy(0) - p.Generator = UselessMac{} - if c.Host.InterfaceName == "" { - c.Host.InterfaceName = "nc%d" + var g func() proxy.MacGenerator + var v func() proxy.MacVerifier + + switch c.Host.Crypto { + case "None": + g = func() proxy.MacGenerator { return crypto.None{} } + v = func() proxy.MacVerifier { return crypto.None{} } + case "Blake2s": + key, err := base64.StdEncoding.DecodeString(c.Host.SharedKey) + if err != nil { + return nil, err + } + if _, err := sharedkey.NewBlake2s(key); err != nil { + return nil, err + } + g = func() proxy.MacGenerator { + g, _ := sharedkey.NewBlake2s(key) + return g + } + v = func() proxy.MacVerifier { + v, _ := sharedkey.NewBlake2s(key) + return v + } } - ss, err := tun.NewTun(c.Host.InterfaceName, 1500) - if err != nil { - return nil, err - } - - p.Source = ss - p.Sink = ss + p.Source = source + p.Sink = sink for _, peer := range c.Peers { switch peer.Method { case "TCP": - err := buildTcp(p, peer) - if err != nil { + if err := buildTcp(ctx, p, peer, g, v); err != nil { + return nil, err + } + case "UDP": + if err := buildUdp(ctx, p, peer, g, v); err != nil { return nil, err } } @@ -51,20 +60,82 @@ func (c Configuration) Build() (*proxy.Proxy, error) { return p, nil } -func buildTcp(p *proxy.Proxy, peer Peer) error { - if peer.RemoteHost != "" { - f, err := tcp.InitiateFlow( - fmt.Sprintf("%s:", peer.LocalHost), - fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort), - ) - - p.AddConsumer(f) - p.AddProducer(f, UselessMac{}) - - return err +func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { + var laddr func() string + if peer.LocalPort == 0 { + laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) } + } else { + laddr = func() string { return fmt.Sprintf("%s:%d", peer.GetLocalHost(), peer.LocalPort) } } - err := tcp.NewListener(p, fmt.Sprintf("%s:%d", peer.LocalHost, peer.LocalPort), UselessMac{}) + if peer.RemoteHost != "" { + f, err := tcp.InitiateFlow(laddr, fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort)) + + if err != nil { + return err + } + + if !peer.DisableConsumer { + p.AddConsumer(ctx, f, g()) + } + if !peer.DisableProducer { + p.AddProducer(ctx, f, v()) + } + + return nil + } + + err := tcp.NewListener(ctx, p, laddr(), v, g, !peer.DisableConsumer, !peer.DisableProducer) + if err != nil { + return err + } + + return nil +} + +func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { + var laddr func() string + if peer.LocalPort == 0 { + laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) } + } else { + laddr = func() string { return fmt.Sprintf("%s:%d", peer.GetLocalHost(), peer.LocalPort) } + } + + var c func() udp.Congestion + switch peer.Congestion { + case "None": + c = func() udp.Congestion { return congestion.NewNone() } + default: + fallthrough + case "NewReno": + c = func() udp.Congestion { return congestion.NewNewReno() } + } + + if peer.RemoteHost != "" { + f, err := udp.InitiateFlow( + laddr, + fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort), + v(), + g(), + c(), + time.Duration(peer.KeepAlive)*time.Second, + ) + + if err != nil { + return err + } + + if !peer.DisableConsumer { + p.AddConsumer(ctx, f, g()) + } + if !peer.DisableProducer { + p.AddProducer(ctx, f, v()) + } + + return nil + } + + err := udp.NewListener(ctx, p, laddr(), v, g, c, !peer.DisableConsumer, !peer.DisableProducer) if err != nil { return err } diff --git a/config/config.go b/config/config.go index 76de54c..ed7bea5 100644 --- a/config/config.go +++ b/config/config.go @@ -1,32 +1,96 @@ package config -import "github.com/go-playground/validator/v10" +import ( + "github.com/go-playground/validator/v10" + "log" + "net" + "strings" +) var v = validator.New() +func init() { + if err := v.RegisterValidation("iface", func(fl validator.FieldLevel) bool { + name, ok := fl.Field().Interface().(string) + if ok && name != "" { + ifaces, err := net.Interfaces() + if err != nil { + log.Printf("error getting interfaces: %v", err) + return false + } + + for _, i := range ifaces { + if i.Name == name { + return true + } + } + } + + return false + }); err != nil { + panic(err) + } +} + type Configuration struct { Host Host Peers []Peer `validate:"dive"` } type Host struct { - PrivateKey string `validate:"required"` - InterfaceName string + Crypto string `validate:"required,oneof=None Blake2s"` + SharedKey string `validate:"required_if=Crypto Blake2s"` + MTU uint `validate:"required,min=576"` } type Peer struct { - PublicKey string `validate:"required"` - Method string `validate:"oneof=TCP"` + Method string `validate:"oneof=TCP UDP"` - LocalHost string `validate:"omitempty,ip"` + LocalHost string `validate:"omitempty,ip|iface"` LocalPort uint `validate:"max=65535"` RemoteHost string `validate:"required_with=RemotePort,omitempty,fqdn|ip"` RemotePort uint `validate:"required_with=RemoteHost,omitempty,max=65535"` + Congestion string `validate:"required_unless=Method TCP,omitempty,oneof=NewReno None"` + KeepAlive uint Timeout uint RetryWait uint + + DisableConsumer bool `validate:"omitempty,nefield=DisableProducer"` + DisableProducer bool `validate:"omitempty,nefield=DisableConsumer"` +} + +func (p Peer) GetLocalHost() string { + if p.LocalHost == "" { + return "" + } + + if err := v.Var(p.LocalHost, "ip"); err == nil { + return p.LocalHost + } + + iface, err := net.InterfaceByName(p.LocalHost) + if err != nil { + panic(err) + } + + if iface != nil { + addrs, err := iface.Addrs() + if err != nil { + panic(err) + } + + if len(addrs) > 0 { + addr := addrs[0].String() + addr = strings.Split(addr, "/")[0] + log.Printf("resolved interface `%s` to `%v`", p.LocalHost, addr) + return addr + } + } + + return "invalid" } func (c Configuration) Validate() error { diff --git a/crypto/none.go b/crypto/none.go new file mode 100644 index 0000000..82a74f2 --- /dev/null +++ b/crypto/none.go @@ -0,0 +1,15 @@ +package crypto + +type None struct{} + +func (None) CodeLength() int { + return 0 +} + +func (None) Generate([]byte) []byte { + return nil +} + +func (None) Verify([]byte, []byte) error { + return nil +} diff --git a/crypto/sharedkey/blake2s.go b/crypto/sharedkey/blake2s.go new file mode 100644 index 0000000..3ded854 --- /dev/null +++ b/crypto/sharedkey/blake2s.go @@ -0,0 +1,40 @@ +package sharedkey + +import ( + "bytes" + "golang.org/x/crypto/blake2s" + "mpbl3p/shared" +) + +type Blake2s struct { + key []byte +} + +func NewBlake2s(key []byte) (*Blake2s, error) { + _, err := blake2s.New128(key) + if err != nil { + return nil, err + } + + return &Blake2s{key: key}, nil +} + +func (b Blake2s) CodeLength() int { + return blake2s.Size128 +} + +func (b Blake2s) Generate(d []byte) []byte { + h, _ := blake2s.New128(b.key) + h.Write(d) + return h.Sum([]byte{}) +} + +func (b Blake2s) Verify(d []byte, s []byte) error { + h, _ := blake2s.New128(b.key) + h.Write(d) + sum := h.Sum([]byte{}) + if !bytes.Equal(sum, s) { + return shared.ErrBadChecksum + } + return nil +} diff --git a/crypto/sharedkey/blake2s_test.go b/crypto/sharedkey/blake2s_test.go new file mode 100644 index 0000000..f6cc546 --- /dev/null +++ b/crypto/sharedkey/blake2s_test.go @@ -0,0 +1,45 @@ +package sharedkey + +import ( + "github.com/stretchr/testify/assert" + "math/rand" + "mpbl3p/shared" + "testing" +) + +func TestBlake2s_GenerateVerify(t *testing.T) { + t.Run("GeneratedVerifies", func(t *testing.T) { + // ASSIGN + key := make([]byte, 16) + rand.Read(key) + buf := make([]byte, 500) + rand.Read(buf) + + // ACT + b := Blake2s{key} + code := b.Generate(buf) + + // ASSERT + err := b.Verify(buf, code) + assert.Nil(t, err) + }) + + t.Run("FlippedBitFailsVerify", func(t *testing.T) { + // ASSIGN + key := make([]byte, 16) + rand.Read(key) + buf := make([]byte, 500) + rand.Read(buf) + + // ACT + b := Blake2s{key} + code := b.Generate(buf) + + offset := rand.Intn(len(buf) * 8) + buf[offset/8] ^= 1 << (offset % 8) + + // ASSERT + err := b.Verify(buf, code) + assert.Equal(t, shared.ErrBadChecksum, err) + }) +} diff --git a/flags/flags.go b/flags/flags.go new file mode 100644 index 0000000..236d72a --- /dev/null +++ b/flags/flags.go @@ -0,0 +1,41 @@ +package flags + +import ( + "errors" + "fmt" + goflags "github.com/jessevdk/go-flags" + "os" +) + +var PrintedHelpErr = goflags.ErrHelp +var NotEnoughArgs = errors.New("not enough arguments") + +type Options struct { + Foreground bool `short:"f" long:"foreground" description:"Run in the foreground"` + ConfigFile string `short:"c" long:"config" description:"Configuration file location" value-name:"FILE"` + PidFile string `short:"p" long:"pid" description:"PID file location"` + + Positional struct { + InterfaceName string `required:"yes" positional-arg-name:"INTERFACE-NAME" description:"Interface name"` + } `positional-args:"yes"` +} + +func ParseFlags() (*Options, error) { + o := new(Options) + parser := goflags.NewParser(o, goflags.Default) + + _, err := parser.Parse() + if err != nil { + parser.WriteHelp(os.Stdout) + return nil, err + } + + if o.ConfigFile == "" { + o.ConfigFile = fmt.Sprintf(DefaultConfigFile, o.Positional.InterfaceName) + } + if o.PidFile == "" { + o.PidFile = fmt.Sprintf(DefaultPidFile, o.Positional.InterfaceName) + } + + return o, nil +} diff --git a/flags/locs_freebsd.go b/flags/locs_freebsd.go new file mode 100644 index 0000000..13bbda0 --- /dev/null +++ b/flags/locs_freebsd.go @@ -0,0 +1,4 @@ +package flags + +const DefaultConfigFile = "/usr/local/etc/netcombiner/%s" +const DefaultPidFile = "/var/run/netcombiner/%s.pid" diff --git a/flags/locs_linux.go b/flags/locs_linux.go new file mode 100644 index 0000000..82a7f21 --- /dev/null +++ b/flags/locs_linux.go @@ -0,0 +1,4 @@ +package flags + +const DefaultConfigFile = "/etc/netcombiner/%s" +const DefaultPidFile = "/var/run/netcombiner/%s.pid" diff --git a/go.mod b/go.mod index 22a3f7d..7a73cee 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,12 @@ module mpbl3p go 1.15 require ( - github.com/go-playground/assert/v2 v2.0.1 - github.com/go-playground/validator/v10 v10.4.1 - github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab + github.com/go-playground/validator/v10 v10.6.0 + github.com/jessevdk/go-flags v1.5.0 github.com/smartystreets/goconvey v1.6.4 // indirect github.com/stretchr/testify v1.4.0 + golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 + golang.org/x/net v0.0.0-20210326060303-6b1517762897 // indirect + golang.zx2c4.com/wireguard v0.0.0-20201118132417-da19db415a58 gopkg.in/ini.v1 v1.62.0 ) diff --git a/go.sum b/go.sum index 2c6ebc9..a201b0b 100644 --- a/go.sum +++ b/go.sum @@ -8,14 +8,16 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87 github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/go-playground/validator/v10 v10.6.0 h1:UGIt4xR++fD9QrBOoo/ascJfGe3AGHEB9s6COnss4Rk= +github.com/go-playground/validator/v10 v10.6.0/go.mod h1:xm76BBt941f7yWdGnI2DVPFFg1UK3YY04qifoXU3lOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= +github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab h1:dAXDRtXYxj4sTR5WeRuTFJGH18QMT6AUpUgRwedI6es= -github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab/go.mod h1:N5a/Ll2ZNk5wjiLNW9LIiNtO9RNYcaYmcXSYKMYrlDg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= @@ -26,17 +28,35 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210326060303-6b1517762897 h1:KrsHThm5nFk34YtATK1LsThyGhGbGe1olrte/HInHvs= +golang.org/x/net v0.0.0-20210326060303-6b1517762897/go.mod h1:uSPa2vr4CLtc/ILN5odXGNXS6mhrKVzTaCXzk9m6W3k= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210324051608-47abb6519492 h1:Paq34FxTluEPvVyayQqMPgHm+vTOrIifmcYxFBx9TLg= +golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.zx2c4.com/wireguard v0.0.0-20201118132417-da19db415a58 h1:HiPOx0boQr3qv0HkZ4fGLtTXJ5tmkbv0d8UmkNcxdv0= +golang.zx2c4.com/wireguard v0.0.0-20201118132417-da19db415a58/go.mod h1:Dz+cq5bnrai9EpgYj4GDof/+qaGzbRWbeaAOs1bUYa0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU= diff --git a/main.go b/main.go index d7f2633..db3e4b4 100644 --- a/main.go +++ b/main.go @@ -1,36 +1,155 @@ package main import ( + "context" + "errors" + "fmt" "log" "mpbl3p/config" + "mpbl3p/flags" + "mpbl3p/tun" "os" "os/signal" + "strconv" "syscall" ) +const ( + ENV_NC_TUN_FD = "NC_TUN_FD" + ENV_NC_CONFIG_PATH = "NC_CONFIG_PATH" +) + func main() { log.SetFlags(log.Ldate | log.Ltime | log.Llongfile) + if _, exists := os.LookupEnv(ENV_NC_TUN_FD); !exists { + // we are the parent process + // 1) process arguments + // 2) validate config + // 2) create a tun adapter + // 3) spawn a child + // 4) exit + + o, err := flags.ParseFlags() + if err != nil { + if errors.Is(err, flags.PrintedHelpErr) { + return + } + panic(err) + } + + log.Println("loading config...") + + c, err := config.LoadConfig(o.ConfigFile) + if err != nil { + log.Fatalf("error validating config: %s", err.Error()) + return + } + + log.Println("creating tun adapter...") + t, err := tun.NewTun(o.Positional.InterfaceName, int(c.Host.MTU)) + if err != nil { + panic(err) + } + + if o.Foreground { + if err := os.Setenv(ENV_NC_TUN_FD, fmt.Sprintf("%d", t.File().Fd())); err != nil { + panic(err) + } + if err := os.Setenv(ENV_NC_CONFIG_PATH, o.ConfigFile); err != nil { + panic(err) + } + log.Println("switching to foreground") + goto FOREGROUND + } + + files := make([]*os.File, 4) + files[0], _ = os.Open(os.DevNull) // stdin + files[1], _ = os.Open(os.DevNull) // stderr + files[2], _ = os.Open(os.DevNull) // stdout + files[3] = t.File() + + env := os.Environ() + env = append(env, fmt.Sprintf("%s=3", ENV_NC_TUN_FD)) + env = append(env, fmt.Sprintf("%s=%s", ENV_NC_CONFIG_PATH, o.ConfigFile)) + + attr := os.ProcAttr{ + Env: env, + Files: files, + } + + path, err := os.Executable() + if err != nil { + panic(err) + } + + process, err := os.StartProcess( + path, + os.Args, + &attr, + ) + if err != nil { + panic(err) + } + + pidFile, err := os.Create(o.PidFile) + if err != nil { + panic(err) + } + if _, err := fmt.Fprintf(pidFile, "%d", process.Pid); err != nil { + panic(err) + } + + _ = process.Release() + return + } + + // we are the child process + // 1) recreate tun adapter from file descriptor + // 2) launch proxy + +FOREGROUND: + log.Println("loading config...") - c, err := config.LoadConfig("config.ini") + c, err := config.LoadConfig(os.Getenv(ENV_NC_CONFIG_PATH)) if err != nil { panic(err) } + log.Println("connecting tun adapter...") + tunFd, err := strconv.ParseUint(os.Getenv(ENV_NC_TUN_FD), 10, 64) + if err != nil { + panic(err) + } + + t, err := tun.NewFromFile(uintptr(tunFd), int(c.Host.MTU)) + if err != nil { + panic(err) + } + defer func() { + if err := t.Close(); err != nil { + panic(err) + } + }() + log.Println("building config...") - p, err := c.Build() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p, err := c.Build(ctx, t, t) if err != nil { panic(err) } - log.Println("starting...") + log.Println("starting proxy...") p.Start() - log.Println("running") + log.Println("proxy started") signals := make(chan os.Signal) signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) <-signals + log.Println("exiting...") } diff --git a/mocks/conn.go b/mocks/conn.go deleted file mode 100644 index 3ad384f..0000000 --- a/mocks/conn.go +++ /dev/null @@ -1,67 +0,0 @@ -package mocks - -import "time" - -type MockPerfectBiConn struct { - directionA chan byte - directionB chan byte -} - -func NewMockPerfectBiConn(bufSize int) MockPerfectBiConn { - return MockPerfectBiConn{ - directionA: make(chan byte, bufSize), - directionB: make(chan byte, bufSize), - } -} - -func (bc MockPerfectBiConn) SideA() MockPerfectConn { - return MockPerfectConn{inbound: bc.directionA, outbound: bc.directionB} -} - -func (bc MockPerfectBiConn) SideB() MockPerfectConn { - return MockPerfectConn{inbound: bc.directionB, outbound: bc.directionA} -} - -type MockPerfectConn struct { - inbound chan byte - outbound chan byte -} - -func (c MockPerfectConn) SetWriteDeadline(time.Time) error { - return nil -} - -func (c MockPerfectConn) Read(p []byte) (n int, err error) { - for i := range p { - if i == 0 { - p[i] = <-c.inbound - } else { - select { - case b := <-c.inbound: - p[i] = b - default: - return i, nil - } - } - } - return len(p), nil -} - -func (c MockPerfectConn) Write(p []byte) (n int, err error) { - for _, b := range p { - c.outbound <- b - } - return len(p), nil -} - -func (c MockPerfectConn) NonBlockingRead(p []byte) (n int, err error) { - for i := range p { - select { - case b := <-c.inbound: - p[i] = b - default: - return i, nil - } - } - return len(p), nil -} diff --git a/mocks/mac.go b/mocks/mac.go index 60d2167..5d7ab7a 100644 --- a/mocks/mac.go +++ b/mocks/mac.go @@ -1,6 +1,8 @@ package mocks -import "mpbl3p/shared" +import ( + "mpbl3p/shared" +) type AlmostUselessMac struct{} diff --git a/mocks/packetconn.go b/mocks/packetconn.go new file mode 100644 index 0000000..9360db4 --- /dev/null +++ b/mocks/packetconn.go @@ -0,0 +1,62 @@ +package mocks + +import "net" + +type MockPerfectBiPacketConn struct { + directionA chan []byte + directionB chan []byte +} + +func NewMockPerfectBiPacketConn(bufSize int) MockPerfectBiPacketConn { + return MockPerfectBiPacketConn{ + directionA: make(chan []byte, bufSize), + directionB: make(chan []byte, bufSize), + } +} + +func (bc MockPerfectBiPacketConn) SideA() MockPerfectPacketConn { + return MockPerfectPacketConn{inbound: bc.directionA, outbound: bc.directionB} +} + +func (bc MockPerfectBiPacketConn) SideB() MockPerfectPacketConn { + return MockPerfectPacketConn{inbound: bc.directionB, outbound: bc.directionA} +} + +type MockPerfectPacketConn struct { + inbound chan []byte + outbound chan []byte +} + +func (c MockPerfectPacketConn) Write(b []byte) (int, error) { + c.outbound <- b + return len(b), nil +} + +func (c MockPerfectPacketConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + c.outbound <- b + return len(b), nil +} + +func (c MockPerfectPacketConn) LocalAddr() net.Addr { + return &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1234, + } +} + +func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { + p := <-c.inbound + return copy(b, p), &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1234, + }, nil +} + +func (c MockPerfectPacketConn) NonBlockingRead(p []byte) (n int, err error) { + select { + case b := <-c.inbound: + return copy(p, b), nil + default: + return 0, nil + } +} diff --git a/mocks/streamconn.go b/mocks/streamconn.go new file mode 100644 index 0000000..956d2d2 --- /dev/null +++ b/mocks/streamconn.go @@ -0,0 +1,95 @@ +package mocks + +import ( + "net" + "time" +) + +type MockPerfectBiStreamConn struct { + directionA chan byte + directionB chan byte +} + +func NewMockPerfectBiStreamConn(bufSize int) MockPerfectBiStreamConn { + return MockPerfectBiStreamConn{ + directionA: make(chan byte, bufSize), + directionB: make(chan byte, bufSize), + } +} + +func (bc MockPerfectBiStreamConn) SideA() MockPerfectStreamConn { + return MockPerfectStreamConn{inbound: bc.directionA, outbound: bc.directionB} +} + +func (bc MockPerfectBiStreamConn) SideB() MockPerfectStreamConn { + return MockPerfectStreamConn{inbound: bc.directionB, outbound: bc.directionA} +} + +type MockPerfectStreamConn struct { + inbound chan byte + outbound chan byte +} + +type Conn interface { + Read(b []byte) (n int, err error) + Write(b []byte) (n int, err error) + SetWriteDeadline(time.Time) error + + // For printing + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +func (c MockPerfectStreamConn) Read(p []byte) (n int, err error) { + for i := range p { + if i == 0 { + p[i] = <-c.inbound + } else { + select { + case b := <-c.inbound: + p[i] = b + default: + return i, nil + } + } + } + return len(p), nil +} + +func (c MockPerfectStreamConn) Write(p []byte) (n int, err error) { + for _, b := range p { + c.outbound <- b + } + return len(p), nil +} + +func (c MockPerfectStreamConn) SetWriteDeadline(time.Time) error { + return nil +} + +// Only used for printing flow information +func (c MockPerfectStreamConn) LocalAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 499, + } +} + +func (c MockPerfectStreamConn) RemoteAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 500, + } +} + +func (c MockPerfectStreamConn) NonBlockingRead(p []byte) (n int, err error) { + for i := range p { + select { + case b := <-c.inbound: + p[i] = b + default: + return i, nil + } + } + return len(p), nil +} diff --git a/proxy/packet.go b/proxy/packet.go index 1588278..9e0f592 100644 --- a/proxy/packet.go +++ b/proxy/packet.go @@ -5,56 +5,42 @@ import ( "time" ) -type Packet struct { - Data []byte - timestamp time.Time +type Packet interface { + Marshal() []byte + Contents() []byte } -// create a packet from the raw data of an IP packet -func NewPacket(data []byte) Packet { - return Packet{ - Data: data, - timestamp: time.Now(), - } -} - -// rebuild a packet from the wrapped format -func UnmarshalPacket(raw []byte, verifier MacVerifier) (Packet, error) { - // the MAC is the last N bytes - data := raw[:len(raw)-verifier.CodeLength()] - sum := raw[len(raw)-verifier.CodeLength():] - - if err := verifier.Verify(data, sum); err != nil { - return Packet{}, err - } - - p := Packet{ - Data: data[:len(data)-8], - } - - unixTime := int64(binary.LittleEndian.Uint64(data[len(data)-8:])) - p.timestamp = time.Unix(unixTime, 0) - - return p, nil -} +type SimplePacket []byte // get the raw data of the IP packet -func (p Packet) Raw() []byte { - return p.Data +func (p SimplePacket) Marshal() []byte { + return p } -// produce the wrapped format of a packet -func (p Packet) Marshal(generator MacGenerator) []byte { - // length of data + length of timestamp (8 byte) + length of checksum - slice := make([]byte, len(p.Data)+8+generator.CodeLength()) - - copy(slice, p.Data) - - unixTime := uint64(p.timestamp.Unix()) - binary.LittleEndian.PutUint64(slice[len(p.Data):], unixTime) - - mac := generator.Generate(slice) - copy(slice[len(p.Data)+8:], mac) - - return slice +func (p SimplePacket) Contents() []byte { + return p +} + +func AppendMac(b []byte, g MacGenerator) []byte { + footer := make([]byte, 8) + unixTime := uint64(time.Now().Unix()) + binary.LittleEndian.PutUint64(footer, unixTime) + + b = append(b, footer...) + + mac := g.Generate(b) + return append(b, mac...) +} + +func StripMac(b []byte, v MacVerifier) ([]byte, error) { + data := b[:len(b)-v.CodeLength()] + sum := b[len(b)-v.CodeLength():] + + if err := v.Verify(data, sum); err != nil { + return nil, err + } + + // TODO: Verify timestamp + + return data[:len(data)-8], nil } diff --git a/proxy/packet_test.go b/proxy/packet_test.go index 5e14843..af2d626 100644 --- a/proxy/packet_test.go +++ b/proxy/packet_test.go @@ -4,31 +4,59 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "mpbl3p/mocks" + "mpbl3p/shared" "testing" ) -func TestPacket_Marshal(t *testing.T) { +func TestAppendMac(t *testing.T) { testContent := []byte("A test string is the content of this packet.") - testPacket := NewPacket(testContent) testMac := mocks.AlmostUselessMac{} + testPacket := SimplePacket(testContent) + testMarshalled := testPacket.Marshal() + + appended := AppendMac(testMarshalled, testMac) t.Run("Length", func(t *testing.T) { - marshalled := testPacket.Marshal(testMac) + assert.Len(t, appended, len(testMarshalled)+8+4) + }) - assert.Len(t, marshalled, len(testContent)+8+4) + t.Run("Mac", func(t *testing.T) { + assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled)+8:]) + }) + + t.Run("Original", func(t *testing.T) { + assert.Equal(t, testMarshalled, appended[:len(testMarshalled)]) }) } -func TestUnmarshalPacket(t *testing.T) { +func TestStripMac(t *testing.T) { testContent := []byte("A test string is the content of this packet.") - testPacket := NewPacket(testContent) testMac := mocks.AlmostUselessMac{} - testMarshalled := testPacket.Marshal(testMac) + testPacket := SimplePacket(testContent) + testMarshalled := testPacket.Marshal() + + appended := AppendMac(testMarshalled, testMac) t.Run("Length", func(t *testing.T) { - p, err := UnmarshalPacket(testMarshalled, testMac) + cut, err := StripMac(appended, testMac) require.Nil(t, err) - assert.Len(t, p.Raw(), len(testContent)) + assert.Len(t, cut, len(testMarshalled)) + }) + + t.Run("IncorrectMac", func(t *testing.T) { + badMac := make([]byte, len(testMarshalled)+4) + copy(badMac, testMarshalled) + copy(badMac[:len(testMarshalled)], "dcba") + _, err := StripMac(badMac, testMac) + + assert.Error(t, err, shared.ErrBadChecksum) + }) + + t.Run("Original", func(t *testing.T) { + cut, err := StripMac(appended, testMac) + + require.Nil(t, err) + assert.Equal(t, testMarshalled, cut) }) } diff --git a/proxy/proxy.go b/proxy/proxy.go index e98a836..593d65b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1,22 +1,24 @@ package proxy import ( + "context" + "errors" "log" "time" ) type Producer interface { IsAlive() bool - Produce(MacVerifier) (Packet, error) + Produce(context.Context, MacVerifier) (Packet, error) } type Consumer interface { IsAlive() bool - Consume(Packet, MacGenerator) error + Consume(context.Context, Packet, MacGenerator) error } type Reconnectable interface { - Reconnect() error + Reconnect(context.Context) error } type Source interface { @@ -31,8 +33,6 @@ type Proxy struct { Source Source Sink Sink - Generator MacGenerator - proxyChan chan Packet sinkChan chan Packet } @@ -67,7 +67,7 @@ func (p Proxy) Start() { }() } -func (p Proxy) AddConsumer(c Consumer) { +func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) { go func() { _, reconnectable := c.(Reconnectable) @@ -75,28 +75,42 @@ func (p Proxy) AddConsumer(c Consumer) { if reconnectable { var err error for once := true; err != nil || once; once = false { - log.Printf("attempting to connect `%v`\n", c) - err = c.(Reconnectable).Reconnect() + log.Printf("attempting to connect consumer `%v`\n", c) + err = c.(Reconnectable).Reconnect(ctx) + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("closed consumer `%v` (context)\n", c) + return + } if !once { time.Sleep(time.Second) } } - log.Printf("connected `%v`\n", c) + log.Printf("connected consumer `%v`\n", c) } for c.IsAlive() { - if err := c.Consume(<-p.proxyChan, p.Generator); err != nil { - log.Println(err) - break + select { + case <-ctx.Done(): + log.Printf("closed consumer `%v` (context)\n", c) + return + case packet := <-p.proxyChan: + if err := c.Consume(ctx, packet, g); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("closed consumer `%v` (context)\n", c) + return + } + log.Println(err) + break + } } } } - log.Printf("closed connection `%v`\n", c) + log.Printf("closed consumer `%v`\n", c) }() } -func (p Proxy) AddProducer(pr Producer, v MacVerifier) { +func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) { go func() { _, reconnectable := pr.(Reconnectable) @@ -104,25 +118,42 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) { if reconnectable { var err error for once := true; err != nil || once; once = false { - log.Printf("attempting to connect `%v`\n", pr) - err = pr.(Reconnectable).Reconnect() + log.Printf("attempting to connect producer `%v`\n", pr) + err = pr.(Reconnectable).Reconnect(ctx) + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("closed producer `%v` (context)\n", pr) + return + } if !once { time.Sleep(time.Second) } + if ctx.Err() != nil { + return + } + } - log.Printf("connected `%v`\n", pr) + log.Printf("connected producer `%v`\n", pr) } for pr.IsAlive() { - if packet, err := pr.Produce(v); err != nil { + if packet, err := pr.Produce(ctx, v); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("closed producer `%v` (context)\n", pr) + return + } log.Println(err) break } else { - p.sinkChan <- packet + select { + case <-ctx.Done(): + log.Printf("closed producer `%v` (context)\n", pr) + return + case p.sinkChan <- packet: + } } } } - log.Printf("closed connection `%v`\n", pr) + log.Printf("closed producer `%v`\n", pr) }() } diff --git a/shared/errors.go b/shared/errors.go index aaae8ba..0db0c92 100644 --- a/shared/errors.go +++ b/shared/errors.go @@ -4,3 +4,4 @@ import "errors" var ErrBadChecksum = errors.New("the packet had a bad checksum") var ErrDeadConnection = errors.New("the connection is dead") +var ErrNotEnoughBytes = errors.New("not enough bytes") diff --git a/tcp/flow.go b/tcp/flow.go index 26607e8..bf54675 100644 --- a/tcp/flow.go +++ b/tcp/flow.go @@ -1,8 +1,10 @@ package tcp import ( + "bufio" + "context" "encoding/binary" - "errors" + "fmt" "io" "mpbl3p/proxy" "mpbl3p/shared" @@ -11,16 +13,18 @@ import ( "time" ) -var ErrNotEnoughBytes = errors.New("not enough bytes") - type Conn interface { Read(b []byte) (n int, err error) Write(b []byte) (n int, err error) SetWriteDeadline(time.Time) error + + // For printing + LocalAddr() net.Addr + RemoteAddr() net.Addr } type InitiatedFlow struct { - Local string + Local func() string Remote string mu sync.RWMutex @@ -28,21 +32,64 @@ type InitiatedFlow struct { Flow } +func (f *InitiatedFlow) String() string { + return fmt.Sprintf("TcpOutbound{%v -> %v}", f.Local(), f.Remote) +} + type Flow struct { conn Conn isAlive bool + + toConsume, produced chan []byte + consumeErrors, produceErrors chan error } -func InitiateFlow(local, remote string) (*InitiatedFlow, error) { +func NewFlow() Flow { + return Flow{ + toConsume: make(chan []byte), + produced: make(chan []byte), + consumeErrors: make(chan error), + produceErrors: make(chan error), + } +} + +func NewFlowConn(ctx context.Context, conn Conn) Flow { + f := Flow{ + conn: conn, + isAlive: true, + + toConsume: make(chan []byte), + produced: make(chan []byte), + consumeErrors: make(chan error), + produceErrors: make(chan error), + } + + go f.produceMarshalled(ctx) + go f.consumeMarshalled(ctx) + + return f +} + +func (f Flow) String() string { + return fmt.Sprintf("TcpInbound{%v -> %v}", f.conn.RemoteAddr(), f.conn.LocalAddr()) +} + +func (f *Flow) IsAlive() bool { + return f.isAlive +} + +func InitiateFlow(local func() string, remote string) (*InitiatedFlow, error) { f := InitiatedFlow{ Local: local, Remote: remote, + + Flow: NewFlow(), } return &f, nil } -func (f *InitiatedFlow) Reconnect() error { +func (f *InitiatedFlow) Reconnect(ctx context.Context) error { f.mu.Lock() defer f.mu.Unlock() @@ -50,7 +97,7 @@ func (f *InitiatedFlow) Reconnect() error { return nil } - localAddr, err := net.ResolveTCPAddr("tcp", f.Local) + localAddr, err := net.ResolveTCPAddr("tcp", f.Local()) if err != nil { return err } @@ -65,93 +112,128 @@ func (f *InitiatedFlow) Reconnect() error { return err } - err = conn.SetWriteBuffer(0) - if err != nil { + if err := conn.SetWriteBuffer(0); err != nil { return err } f.conn = conn f.isAlive = true + + go f.produceMarshalled(ctx) + go f.consumeMarshalled(ctx) + return nil } -func (f *Flow) IsAlive() bool { - return f.isAlive -} - -func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { +func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Consume(p, g) + return f.Flow.Consume(ctx, p, g) } -func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) (err error) { +func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Produce(ctx, v) +} + +func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { if !f.isAlive { return shared.ErrDeadConnection } - data := p.Marshal(g) - err = f.consumeMarshalled(data) - if err != nil { + select { + case err := <-f.consumeErrors: f.isAlive = false + return err + default: } - return -} -func (f *Flow) consumeMarshalled(data []byte) error { + marshalled := p.Marshal() + data := proxy.AppendMac(marshalled, g) + prefixedData := make([]byte, len(data)+4) binary.LittleEndian.PutUint32(prefixedData, uint32(len(data))) copy(prefixedData[4:], data) - err := f.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err != nil { - return err + select { + case f.toConsume <- prefixedData: + case <-ctx.Done(): + return ctx.Err() } - _, err = f.conn.Write(prefixedData) - return err + + return nil } -func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { - f.mu.RLock() - defer f.mu.RUnlock() - - return f.Flow.Produce(v) -} - -func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { +func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { if !f.isAlive { - return proxy.Packet{}, shared.ErrDeadConnection + return nil, shared.ErrDeadConnection } - data, err := f.produceMarshalled() - if err != nil { + var data []byte + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case data = <-f.produced: + case err := <-f.produceErrors: f.isAlive = false - return proxy.Packet{}, err + return nil, err } - return proxy.UnmarshalPacket(data, v) + b, err := proxy.StripMac(data, v) + if err != nil { + return nil, err + } + + return proxy.SimplePacket(b), nil } -func (f *Flow) produceMarshalled() ([]byte, error) { - lengthBytes := make([]byte, 4) - if n, err := io.LimitReader(f.conn, 4).Read(lengthBytes); err != nil { - return nil, err - } else if n != 4 { - return nil, ErrNotEnoughBytes - } +func (f *Flow) consumeMarshalled(ctx context.Context) { + for { + data := <-f.toConsume - length := binary.LittleEndian.Uint32(lengthBytes) - dataBytes := make([]byte, length) - - var read uint32 - for read < length { - if n, err := io.LimitReader(f.conn, int64(length-read)).Read(dataBytes[read:]); err != nil { - return nil, err - } else { - read += uint32(n) + err := f.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + f.consumeErrors <- err + return + } + _, err = f.conn.Write(data) + if err != nil { + f.consumeErrors <- err + return } } - - return dataBytes, nil +} + +func (f *Flow) produceMarshalled(ctx context.Context) { + buf := bufio.NewReader(f.conn) + + for { + lengthBytes := make([]byte, 4) + if n, err := io.LimitReader(buf, 4).Read(lengthBytes); err != nil { + f.produceErrors <- err + return + } else if n != 4 { + f.produceErrors <- shared.ErrNotEnoughBytes + return + } + + length := binary.LittleEndian.Uint32(lengthBytes) + dataBytes := make([]byte, length) + + var read uint32 + for read < length { + if n, err := io.LimitReader(buf, int64(length-read)).Read(dataBytes[read:]); err != nil { + f.produceErrors <- err + return + } else { + read += uint32(n) + } + } + + f.produced <- dataBytes + } } diff --git a/tcp/flow_test.go b/tcp/flow_test.go index 88fbf1c..eba5a0a 100644 --- a/tcp/flow_test.go +++ b/tcp/flow_test.go @@ -1,8 +1,9 @@ package tcp import ( + "context" "encoding/binary" - "github.com/go-playground/assert/v2" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "mpbl3p/mocks" "mpbl3p/proxy" @@ -11,15 +12,15 @@ import ( func TestFlow_Consume(t *testing.T) { testContent := []byte("A test string is the content of this packet.") - testPacket := proxy.NewPacket(testContent) + testPacket := proxy.SimplePacket(testContent) testMac := mocks.AlmostUselessMac{} t.Run("Length", func(t *testing.T) { - testConn := mocks.NewMockPerfectBiConn(100) + testConn := mocks.NewMockPerfectBiStreamConn(100) - flowA := Flow{conn: testConn.SideA(), isAlive: true} + flowA := NewFlowConn(context.Background(), testConn.SideA()) - err := flowA.Consume(testPacket, testMac) + err := flowA.Consume(context.Background(), testPacket, testMac) require.Nil(t, err) buf := make([]byte, 100) @@ -39,28 +40,28 @@ func TestFlow_Produce(t *testing.T) { testMac := mocks.AlmostUselessMac{} t.Run("Length", func(t *testing.T) { - testConn := mocks.NewMockPerfectBiConn(100) + testConn := mocks.NewMockPerfectBiStreamConn(100) - flowA := Flow{conn: testConn.SideA(), isAlive: true} + flowA := NewFlowConn(context.Background(), testConn.SideA()) _, err := testConn.SideB().Write(testMarshalled) require.Nil(t, err) - p, err := flowA.Produce(testMac) + p, err := flowA.Produce(context.Background(), testMac) require.Nil(t, err) - assert.Equal(t, len(testContent), len(p.Raw())) + assert.Equal(t, len(testContent), len(p.Contents())) }) t.Run("Value", func(t *testing.T) { - testConn := mocks.NewMockPerfectBiConn(100) + testConn := mocks.NewMockPerfectBiStreamConn(100) - flowA := Flow{conn: testConn.SideA(), isAlive: true} + flowA := NewFlowConn(context.Background(), testConn.SideA()) _, err := testConn.SideB().Write(testMarshalled) require.Nil(t, err) - p, err := flowA.Produce(testMac) + p, err := flowA.Produce(context.Background(), testMac) require.Nil(t, err) - assert.Equal(t, testContent, string(p.Raw())) + assert.Equal(t, testContent, string(p.Contents())) }) } diff --git a/tcp/listener.go b/tcp/listener.go index 5da9761..cf87400 100644 --- a/tcp/listener.go +++ b/tcp/listener.go @@ -1,12 +1,13 @@ package tcp import ( + "context" "log" "mpbl3p/proxy" "net" ) -func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error { +func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, enableConsumers bool, enableProducers bool) error { laddr, err := net.ResolveTCPAddr("tcp", local) if err != nil { return err @@ -24,17 +25,20 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error { panic(err) } - err = conn.SetWriteBuffer(0) - if err != nil { + if err := conn.SetWriteBuffer(0); err != nil { panic(err) } - f := Flow{conn: conn, isAlive: true} + f := NewFlowConn(ctx, conn) - log.Printf("received new connection: %v\n", f) + log.Printf("received new tcp connection: %v\n", f) - p.AddConsumer(&f) - p.AddProducer(&f, v) + if enableConsumers { + p.AddConsumer(ctx, &f, g()) + } + if enableProducers { + p.AddProducer(ctx, &f, v()) + } } }() diff --git a/tun/tun.go b/tun/tun.go index 2f5c5cb..252140e 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -1,85 +1,65 @@ package tun import ( - "github.com/pkg/taptun" + wgtun "golang.zx2c4.com/wireguard/tun" "io" "log" "mpbl3p/proxy" - "net" "os" - "strings" - "sync" - "time" ) type SourceSink struct { - tun *taptun.Tun - bufferSize int - - up bool - upMu sync.Mutex + tun wgtun.Device } -func NewTun(namingScheme string, bufferSize int) (ss *SourceSink, err error) { - ss = &SourceSink{} +func NewTun(name string, mtu int) (t wgtun.Device, err error) { + return wgtun.CreateTUN(name, mtu) +} - ss.tun, err = taptun.NewTun(namingScheme) +func NewFromFile(fd uintptr, mtu int) (ss *SourceSink, err error) { + ss = new(SourceSink) + + file := os.NewFile(fd, "") + ss.tun, err = wgtun.CreateTUNFromFile(file, mtu) if err != nil { return } - ss.bufferSize = bufferSize - - ss.upMu.Lock() - go func() { - defer ss.upMu.Unlock() - - for { - iface, err := net.InterfaceByName(ss.tun.String()) - if err != nil { - panic(err) - } - if strings.Contains(iface.Flags.String(), "up") { - log.Println("tun is up") - ss.up = true - return - } - time.Sleep(100 * time.Millisecond) - } - }() - return } +func (t *SourceSink) Close() error { + return t.tun.Close() +} + func (t *SourceSink) Source() (proxy.Packet, error) { - if !t.up { - t.upMu.Lock() - t.upMu.Unlock() + mtu, err := t.tun.MTU() + if err != nil { + return nil, err } - buf := make([]byte, t.bufferSize) + buf := make([]byte, mtu+4) - read, err := t.tun.Read(buf) + read, err := t.tun.Read(buf, 4) if err != nil { - return proxy.Packet{}, err + return nil, err } if read == 0 { - return proxy.Packet{}, io.EOF + return nil, io.EOF } - return proxy.NewPacket(buf[:read]), nil + return proxy.SimplePacket(buf[4 : read+4]), nil } var good, bad float64 func (t *SourceSink) Sink(packet proxy.Packet) error { - if !t.up { - t.upMu.Lock() - t.upMu.Unlock() - } + // make space for tun header + content := make([]byte, len(packet.Contents())+4) + copy(content[4:], packet.Contents()) - _, err := t.tun.Write(packet.Raw()) + _, err := t.tun.Write(content, 4) if err != nil { switch err.(type) { case *os.PathError: diff --git a/udp/congestion.go b/udp/congestion.go new file mode 100644 index 0000000..239645b --- /dev/null +++ b/udp/congestion.go @@ -0,0 +1,16 @@ +package udp + +import ( + "context" + "time" +) + +type Congestion interface { + Sequence(ctx context.Context) (uint32, error) + NextAck() uint32 + NextNack() uint32 + + ReceivedPacket(seq, nack, ack uint32) + + AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error) +} diff --git a/udp/congestion/newreno.go b/udp/congestion/newreno.go new file mode 100644 index 0000000..d796e11 --- /dev/null +++ b/udp/congestion/newreno.go @@ -0,0 +1,284 @@ +package congestion + +import ( + "context" + "fmt" + "math" + "sort" + "sync" + "sync/atomic" + "time" +) + +const RttExponentialFactor = 0.1 +const RttLossDelay = 1.5 + +type NewReno struct { + sequence chan uint32 + + inFlight []flightInfo + lastSent time.Time + inFlightMu sync.Mutex + + awaitingAck sortableFlights + ack, nack uint32 + lastAck, lastNack uint32 + ackNackMu sync.Mutex + + rttNanos float64 + windowSize, windowCount uint32 + slowStart bool + windowNotifier chan struct{} +} + +type flightInfo struct { + time time.Time + sequence uint32 +} + +type sortableFlights []flightInfo + +func (f sortableFlights) Len() int { + return len(f) +} + +func (f sortableFlights) Swap(i, j int) { + f[i], f[j] = f[j], f[i] +} + +func (f sortableFlights) Less(i, j int) bool { + return f[i].sequence < f[j].sequence +} + +func (c *NewReno) String() string { + return fmt.Sprintf("{NewReno %t %d %d %d %d}", c.slowStart, c.windowSize, len(c.inFlight), c.lastAck, c.lastNack) +} + +func NewNewReno() *NewReno { + c := NewReno{ + sequence: make(chan uint32), + windowNotifier: make(chan struct{}), + + windowSize: 1, + rttNanos: float64((10 * time.Millisecond).Nanoseconds()), + slowStart: true, + } + + go func() { + var s uint32 + for { + if s == 0 { + s++ + continue + } + + c.sequence <- s + s++ + } + }() + + return &c +} + +func (c *NewReno) ReceivedPacket(seq, nack, ack uint32) { + // decide what acks and nacks to send + if seq != 0 { + c.receivedSequence(seq) + } + + // decide how window size was affected + if nack != 0 { + c.receivedNack(nack) + } + if ack != 0 { + c.receivedAck(ack) + } +} + +func (c *NewReno) Sequence(ctx context.Context) (uint32, error) { + for len(c.inFlight) >= int(c.windowSize) { + <-c.windowNotifier + } + + c.inFlightMu.Lock() + defer c.inFlightMu.Unlock() + + var s uint32 + select { + case s = <-c.sequence: + case <-ctx.Done(): + return 0, ctx.Err() + } + + t := time.Now() + + c.inFlight = append(c.inFlight, flightInfo{ + time: t, + sequence: s, + }) + c.lastSent = t + + return s, nil +} + +func (c *NewReno) NextAck() uint32 { + a := c.ack + atomic.StoreUint32(&c.lastAck, a) + return a +} + +func (c *NewReno) NextNack() uint32 { + n := c.nack + atomic.StoreUint32(&c.lastNack, n) + return n +} + +func (c *NewReno) AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error) { + for { + rtt := time.Duration(math.Round(c.rttNanos)) + time.Sleep(rtt / 2) + + c.checkNack() + + // CASE 1: waiting ACKs or NACKs and no message sent in the last half-RTT + // this targets arrival in 0.5+0.5 ± 0.5 RTTs (1±0.5 RTTs) + if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt/2)) { + return 0, nil // no ack needed + } + + // CASE 2: No message sent within the keepalive time + if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) { + return c.Sequence(ctx) // require an ack + } + } +} + +func (c *NewReno) receivedSequence(seq uint32) { + c.ackNackMu.Lock() + defer c.ackNackMu.Unlock() + + if seq < c.nack || seq < c.ack { + // packet received out of order has already been cumulatively NACKed + // or duplicate packet received and already ACKed + return + } + + if seq != c.ack+1 && seq != c.nack+1 { + c.awaitingAck = append(c.awaitingAck, flightInfo{ + time: time.Now(), + sequence: seq, + }) + return // if this seq doesn't change the ack field, updateAck will still not do anything useful + } + + sort.Sort(c.awaitingAck) + c.updateAck(seq) +} + +func (c *NewReno) checkNack() { + c.ackNackMu.Lock() + defer c.ackNackMu.Unlock() + + if len(c.awaitingAck) == 0 { + return + } + + sort.Sort(c.awaitingAck) + + lossThreshold := time.Duration(c.rttNanos * RttLossDelay) + if c.awaitingAck[0].time.Before(time.Now().Add(-lossThreshold)) { + // if the next packet sequence to ack was received more than an rttlossdelay ago + // mark the packet(s) blocking it as missing with a nack + // then update ack from the delayed packet + c.nack = c.awaitingAck[0].sequence - 1 + c.updateAck(c.nack) + } +} + +func (c *NewReno) updateAck(start uint32) { + a := start + + var removed int + for _, e := range c.awaitingAck { + if e.sequence == a+1 { + a = e.sequence + removed++ + } else { + break + } + } + + c.ack = a + c.awaitingAck = c.awaitingAck[removed:] +} + +func (c *NewReno) receivedNack(nack uint32) { + c.ackNackMu.Lock() + defer c.ackNackMu.Unlock() + + c.inFlightMu.Lock() + defer c.inFlightMu.Unlock() + + // as both ack and nack are cumulative, inflight will always be ordered by seq + var i int + for i < len(c.inFlight) && c.inFlight[i].sequence <= nack { + i++ + } + + if i == 0 { + return + } + + c.slowStart = false + c.inFlight = c.inFlight[i:] + + for { + s := c.windowSize + if s == 1 || atomic.CompareAndSwapUint32(&c.windowSize, s, s/2) { + break + } + } + + select { + case c.windowNotifier <- struct{}{}: + default: + } +} + +func (c *NewReno) receivedAck(ack uint32) { + c.ackNackMu.Lock() + defer c.ackNackMu.Unlock() + + c.inFlightMu.Lock() + defer c.inFlightMu.Unlock() + + // as both ack and nack are cumulative, inflight will always be ordered by seq + var i int + for i < len(c.inFlight) && c.inFlight[i].sequence <= ack { + rtt := float64(time.Now().Sub(c.inFlight[i].time).Nanoseconds()) + c.rttNanos = c.rttNanos*(1-RttExponentialFactor) + rtt*RttExponentialFactor + + i++ + } + + if i == 0 { + return + } + + c.inFlight = c.inFlight[i:] + if c.slowStart { + atomic.AddUint32(&c.windowSize, uint32(i)) + } else { + c.windowCount += uint32(i) + s := c.windowSize + if c.windowCount > s { + c.windowCount -= s + atomic.AddUint32(&c.windowSize, 1) + } + } + + select { + case c.windowNotifier <- struct{}{}: + default: + } +} diff --git a/udp/congestion/newreno_test.go b/udp/congestion/newreno_test.go new file mode 100644 index 0000000..f97dfb9 --- /dev/null +++ b/udp/congestion/newreno_test.go @@ -0,0 +1,369 @@ +package congestion + +import ( + "context" + "github.com/stretchr/testify/assert" + "sort" + "testing" + "time" +) + +type congestionPacket struct { + seq, nack, ack uint32 +} + +type newRenoTest struct { + sideA, sideB *NewReno + + aOutbound, bOutbound chan congestionPacket + aInbound, bInbound chan congestionPacket + + halfRtt time.Duration +} + +func newNewRenoTest(rtt time.Duration) *newRenoTest { + return &newRenoTest{ + sideA: NewNewReno(), + sideB: NewNewReno(), + + aOutbound: make(chan congestionPacket), + bOutbound: make(chan congestionPacket), + + aInbound: make(chan congestionPacket), + bInbound: make(chan congestionPacket), + + halfRtt: rtt / 2, + } +} + +func (n *newRenoTest) Start(ctx context.Context) { + type packetWithTime struct { + t time.Time + p congestionPacket + } + + go func() { + aOutboundDelayed := make(chan packetWithTime, 128) + bOutboundDelayed := make(chan packetWithTime, 128) + + delayer := func(tp chan packetWithTime, cp chan congestionPacket) { + for { + select { + case p := <-tp: + s := p.t.Add(n.halfRtt).Sub(time.Now()) + time.Sleep(s) + cp <- p.p + case <-ctx.Done(): + return + } + } + } + + go delayer(aOutboundDelayed, n.bInbound) + go delayer(bOutboundDelayed, n.aInbound) + + for { + select { + case <-ctx.Done(): + return + case p := <-n.aOutbound: + aOutboundDelayed <- packetWithTime{t: time.Now(), p: p} + case p := <-n.bOutbound: + bOutboundDelayed <- packetWithTime{t: time.Now(), p: p} + } + } + }() +} + +func (n *newRenoTest) RunSideA(ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case p := <-n.aInbound: + n.sideA.ReceivedPacket(p.seq, p.nack, p.ack) + } + } + }() + + go func() { + for { + seq, err := n.sideA.AwaitEarlyUpdate(ctx, 500*time.Millisecond) + if err != nil { + return + } + if seq != 0 { + // skip keepalive + // required to ensure AwaitEarlyUpdate terminates + continue + } + p := congestionPacket{ + seq: seq, + nack: n.sideA.NextNack(), + ack: n.sideA.NextAck(), + } + n.aOutbound <- p + } + }() +} + +func (n *newRenoTest) RunSideB(ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case p := <-n.bInbound: + n.sideB.ReceivedPacket(p.seq, p.nack, p.ack) + } + } + }() + + go func() { + for { + seq, err := n.sideB.AwaitEarlyUpdate(ctx, 500*time.Millisecond) + if err != nil { + return + } + if seq != 0 { + // skip keepalive + // required to ensure AwaitEarlyUpdate terminates + continue + } + p := congestionPacket{ + seq: seq, + nack: n.sideB.NextNack(), + ack: n.sideB.NextAck(), + } + n.bOutbound <- p + } + }() +} + +func TestNewReno_Congestion(t *testing.T) { + t.Run("OneWay", func(t *testing.T) { + t.Run("Lossless", func(t *testing.T) { + // ASSIGN + rtt := 80 * time.Millisecond + numPackets := 50 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := newNewRenoTest(rtt) + c.Start(ctx) + c.RunSideA(ctx) + c.RunSideB(ctx) + + // ACT + for i := 0; i < numPackets; i++ { + // sleep to simulate preparing packet + time.Sleep(1 * time.Millisecond) + seq, _ := c.sideA.Sequence(ctx) + + c.aOutbound <- congestionPacket{ + seq: seq, + nack: c.sideA.NextNack(), + ack: c.sideA.NextAck(), + } + } + + // allow the systems to catch up before asserting + time.Sleep(rtt + 30*time.Millisecond) + + // ASSERT + + assert.Equal(t, uint32(0), c.sideA.nack) + assert.Equal(t, uint32(0), c.sideA.ack) + + assert.Equal(t, uint32(0), c.sideB.nack) + assert.Equal(t, uint32(numPackets), c.sideB.ack) + }) + + t.Run("SequenceLoss", func(t *testing.T) { + // ASSIGN + rtt := 80 * time.Millisecond + numPackets := 50 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := newNewRenoTest(rtt) + c.Start(ctx) + c.RunSideA(ctx) + c.RunSideB(ctx) + + // ACT + for i := 0; i < numPackets; i++ { + // sleep to simulate preparing packet + time.Sleep(1 * time.Millisecond) + seq, _ := c.sideA.Sequence(ctx) + + if seq == 20 { + // Simulate packet loss of sequence 20 + continue + } + + c.aOutbound <- congestionPacket{ + seq: seq, + nack: c.sideA.NextNack(), + ack: c.sideA.NextAck(), + } + } + + time.Sleep(rtt + 30*time.Millisecond) + + // ASSERT + + assert.Equal(t, uint32(0), c.sideA.nack) + assert.Equal(t, uint32(0), c.sideA.ack) + + assert.Equal(t, uint32(20), c.sideB.nack) + assert.Equal(t, uint32(numPackets), c.sideB.ack) + }) + }) + + t.Run("TwoWay", func(t *testing.T) { + t.Run("Lossless", func(t *testing.T) { + // ASSIGN + rtt := 80 * time.Millisecond + numPackets := 50 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := newNewRenoTest(rtt) + c.Start(ctx) + c.RunSideA(ctx) + c.RunSideB(ctx) + + // ACT + done := make(chan struct{}) + + go func() { + for i := 0; i < numPackets; i++ { + time.Sleep(1 * time.Millisecond) + seq, _ := c.sideA.Sequence(ctx) + + c.aOutbound <- congestionPacket{ + seq: seq, + nack: c.sideA.NextNack(), + ack: c.sideA.NextAck(), + } + } + + done <- struct{}{} + }() + + go func() { + for i := 0; i < numPackets; i++ { + time.Sleep(1 * time.Millisecond) + seq, _ := c.sideB.Sequence(ctx) + + c.bOutbound <- congestionPacket{ + seq: seq, + nack: c.sideB.NextNack(), + ack: c.sideB.NextAck(), + } + } + + done <- struct{}{} + }() + + <-done + <-done + + time.Sleep(rtt + 30*time.Millisecond) + + // ASSERT + + assert.Equal(t, uint32(0), c.sideA.nack) + assert.Equal(t, uint32(numPackets), c.sideA.ack) + + assert.Equal(t, uint32(0), c.sideB.nack) + assert.Equal(t, uint32(numPackets), c.sideB.ack) + }) + + t.Run("SequenceLoss", func(t *testing.T) { + // ASSIGN + rtt := 80 * time.Millisecond + numPackets := 100 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := newNewRenoTest(rtt) + c.Start(ctx) + c.RunSideA(ctx) + c.RunSideB(ctx) + + // ACT + done := make(chan struct{}) + + go func() { + for i := 0; i < numPackets; i++ { + time.Sleep(1 * time.Millisecond) + seq, _ := c.sideA.Sequence(ctx) + + if seq == 9 { + // Simulate packet loss of sequence 9 + continue + } + + c.aOutbound <- congestionPacket{ + seq: seq, + nack: c.sideA.NextNack(), + ack: c.sideA.NextAck(), + } + } + + done <- struct{}{} + }() + + go func() { + for i := 0; i < numPackets; i++ { + time.Sleep(1 * time.Millisecond) + seq, _ := c.sideB.Sequence(ctx) + + if seq == 13 { + // Simulate packet loss of sequence 13 + continue + } + + c.bOutbound <- congestionPacket{ + seq: seq, + nack: c.sideB.NextNack(), + ack: c.sideB.NextAck(), + } + } + + done <- struct{}{} + }() + + <-done + <-done + + time.Sleep(rtt + 30*time.Millisecond) + + // ASSERT + + assert.Equal(t, uint32(13), c.sideA.nack) + assert.Equal(t, uint32(numPackets), c.sideA.ack) + + assert.Equal(t, uint32(9), c.sideB.nack) + assert.Equal(t, uint32(numPackets), c.sideB.ack) + }) + }) +} + +func TestSortableFlights_Less(t *testing.T) { + // ASSIGN + a := []flightInfo{{sequence: 0}, {sequence: 6}, {sequence: 3}, {sequence: 2}} + + // ACT + sort.Sort(sortableFlights(a)) + + // ASSERT + assert.Equal(t, []flightInfo{{sequence: 0}, {sequence: 2}, {sequence: 3}, {sequence: 6}}, a) +} diff --git a/udp/congestion/none.go b/udp/congestion/none.go new file mode 100644 index 0000000..ded9f2b --- /dev/null +++ b/udp/congestion/none.go @@ -0,0 +1,26 @@ +package congestion + +import ( + "context" + "fmt" + "time" +) + +type None struct{} + +func NewNone() None { + return None{} +} + +func (c None) String() string { + return fmt.Sprintf("{None}") +} + +func (c None) ReceivedPacket(uint32, uint32, uint32) {} +func (c None) NextNack() uint32 { return 0 } +func (c None) NextAck() uint32 { return 0 } +func (c None) AwaitEarlyUpdate(ctx context.Context, _ time.Duration) (uint32, error) { + <-ctx.Done() + return 0, ctx.Err() +} +func (c None) Sequence(context.Context) (uint32, error) { return 0, nil } diff --git a/udp/flow.go b/udp/flow.go new file mode 100644 index 0000000..d552d93 --- /dev/null +++ b/udp/flow.go @@ -0,0 +1,306 @@ +package udp + +import ( + "context" + "fmt" + "log" + "mpbl3p/proxy" + "mpbl3p/shared" + "net" + "sync" + "time" +) + +type PacketWriter interface { + Write(b []byte) (int, error) + WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) + LocalAddr() net.Addr +} + +type PacketConn interface { + PacketWriter + ReadFromUDP(b []byte) (int, *net.UDPAddr, error) +} + +type InitiatedFlow struct { + Local func() string + Remote string + + g proxy.MacGenerator + keepalive time.Duration + + mu sync.RWMutex + Flow +} + +func (f *InitiatedFlow) String() string { + return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote) +} + +type Flow struct { + writer PacketWriter + raddr *net.UDPAddr + + isAlive bool + startup bool + congestion Congestion + + v proxy.MacVerifier + + inboundDatagrams chan []byte +} + +func (f Flow) String() string { + return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr()) +} + +func InitiateFlow( + local func() string, + remote string, + v proxy.MacVerifier, + g proxy.MacGenerator, + c Congestion, + keepalive time.Duration, +) (*InitiatedFlow, error) { + f := InitiatedFlow{ + Local: local, + Remote: remote, + Flow: newFlow(c, v), + g: g, + keepalive: keepalive, + } + + return &f, nil +} + +func newFlow(c Congestion, v proxy.MacVerifier) Flow { + return Flow{ + inboundDatagrams: make(chan []byte), + congestion: c, + v: v, + } +} + +func (f *InitiatedFlow) Reconnect(ctx context.Context) error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.isAlive { + return nil + } + + localAddr, err := net.ResolveUDPAddr("udp", f.Local()) + if err != nil { + return err + } + + remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote) + if err != nil { + return err + } + + conn, err := net.DialUDP("udp", localAddr, remoteAddr) + if err != nil { + return err + } + + f.writer = conn + f.startup = true + + // prod the connection once a second until we get an ack, then consider it alive + go func() { + seq, err := f.congestion.Sequence(ctx) + if err != nil { + return + } + + for !f.isAlive { + if ctx.Err() != nil { + return + } + + p := Packet{ + ack: 0, + nack: 0, + seq: seq, + data: proxy.SimplePacket(nil), + } + + _ = f.sendPacket(p, f.g) + time.Sleep(1 * time.Second) + } + }() + + go func() { + _, _ = f.produceInternal(ctx, f.v, false) + }() + go f.earlyUpdateLoop(ctx, f.g, f.keepalive) + + if err := f.readQueuePacket(ctx, conn); err != nil { + return err + } + + f.isAlive = true + f.startup = false + + go func() { + lockedAccept := func() { + f.mu.RLock() + defer f.mu.RUnlock() + + if err := f.readQueuePacket(ctx, conn); err != nil { + log.Println(err) + } + } + + for f.isAlive { + log.Println("alive and listening for packets") + lockedAccept() + } + log.Println("no longer alive") + }() + + return nil +} + +func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Consume(ctx, p, g) +} + +func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Produce(ctx, v) +} + +func (f *Flow) IsAlive() bool { + return f.isAlive +} + +func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error { + if !f.isAlive { + return shared.ErrDeadConnection + } + + log.Println(f.congestion) + + // Sequence is the congestion controllers opportunity to block + log.Println("awaiting sequence") + seq, err := f.congestion.Sequence(ctx) + if err != nil { + return err + } + log.Println("received sequence") + + // Choose up to date ACK/NACK even after blocking + p := Packet{ + seq: seq, + data: pp, + ack: f.congestion.NextAck(), + nack: f.congestion.NextNack(), + } + + return f.sendPacket(p, g) +} + +func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { + if !f.isAlive { + return nil, shared.ErrDeadConnection + } + + return f.produceInternal(ctx, v, true) +} + +func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) { + for once := true; mustReturn || once; once = false { + log.Println(f.congestion) + + var received []byte + select { + case received = <-f.inboundDatagrams: + case <-ctx.Done(): + return nil, ctx.Err() + } + + b, err := proxy.StripMac(received, v) + if err != nil { + return nil, err + } + + p, err := UnmarshalPacket(b) + if err != nil { + return nil, err + } + + // adjust congestion control based on this packet's congestion header + f.congestion.ReceivedPacket(p.seq, p.nack, p.ack) + + // 12 bytes for header + the MAC + a timestamp + if len(b) == 12+f.v.CodeLength()+8 { + log.Println("handled keepalive/ack only packet") + continue + } + + return p, nil + } + + return nil, nil +} + +func (f *Flow) queueDatagram(ctx context.Context, p []byte) error { + select { + case f.inboundDatagrams <- p: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { + b := p.Marshal() + b = proxy.AppendMac(b, g) + + if f.raddr == nil { + _, err := f.writer.Write(b) + return err + } else { + _, err := f.writer.WriteToUDP(b, f.raddr) + return err + } +} + +func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) { + for f.isAlive { + seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive) + if err != nil { + fmt.Printf("terminating earlyupdateloop for `%v`\n", f) + return + } + p := Packet{ + seq: seq, + data: proxy.SimplePacket(nil), + ack: f.congestion.NextAck(), + nack: f.congestion.NextNack(), + } + + err = f.sendPacket(p, g) + if err != nil { + fmt.Printf("error sending early update packet: `%v`\n", err) + } + } +} + +func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error { + // TODO: Replace 6000 with MTU+header size + buf := make([]byte, 6000) + n, _, err := c.ReadFromUDP(buf) + if err != nil { + return err + } + + return f.queueDatagram(ctx, buf[:n]) +} diff --git a/udp/flow_test.go b/udp/flow_test.go new file mode 100644 index 0000000..d044477 --- /dev/null +++ b/udp/flow_test.go @@ -0,0 +1,86 @@ +package udp + +import ( + "context" + "fmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "mpbl3p/mocks" + "mpbl3p/proxy" + "mpbl3p/udp/congestion" + "testing" + "time" +) + +func TestFlow_Consume(t *testing.T) { + testContent := []byte("A test string is the content of this packet.") + testPacket := proxy.SimplePacket(testContent) + testMac := mocks.AlmostUselessMac{} + + t.Run("Length", func(t *testing.T) { + testConn := mocks.NewMockPerfectBiPacketConn(10) + + flowA := newFlow(congestion.NewNone(), testMac) + + flowA.writer = testConn.SideB() + flowA.isAlive = true + + err := flowA.Consume(context.Background(), testPacket, testMac) + require.Nil(t, err) + + buf := make([]byte, 100) + n, _, err := testConn.SideA().ReadFromUDP(buf) + require.Nil(t, err) + + // 12 header, 8 timestamp, 4 MAC + assert.Equal(t, len(testContent)+12+8+4, n) + }) +} + +func TestFlow_Produce(t *testing.T) { + testContent := []byte("A test string is the content of this packet.") + testPacket := Packet{ + ack: 42, + nack: 26, + seq: 128, + data: proxy.SimplePacket(testContent), + } + testMac := mocks.AlmostUselessMac{} + + testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac) + + t.Run("Length", func(t *testing.T) { + done := make(chan struct{}) + + go func() { + testConn := mocks.NewMockPerfectBiPacketConn(10) + + _, err := testConn.SideA().Write(testMarshalled) + require.Nil(t, err) + + flowA := newFlow(congestion.NewNone(), testMac) + + flowA.writer = testConn.SideB() + flowA.isAlive = true + + go func() { + err := flowA.readQueuePacket(context.Background(), testConn.SideB()) + assert.Nil(t, err) + }() + p, err := flowA.Produce(context.Background(), testMac) + + require.Nil(t, err) + assert.Len(t, p.Contents(), len(testContent)) + + done <- struct{}{} + }() + + timer := time.NewTimer(500 * time.Millisecond) + select { + case <-done: + case <-timer.C: + fmt.Println("timed out") + t.FailNow() + } + }) +} diff --git a/udp/listener.go b/udp/listener.go new file mode 100644 index 0000000..7c91679 --- /dev/null +++ b/udp/listener.go @@ -0,0 +1,104 @@ +package udp + +import ( + "context" + "log" + "mpbl3p/proxy" + "net" + "time" +) + +type ComparableUdpAddress struct { + IP [16]byte + Port int + Zone string +} + +func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress { + var ip [16]byte + for i, b := range []byte(address.IP) { + ip[i] = b + } + + return ComparableUdpAddress{ + IP: ip, + Port: address.Port, + Zone: address.Zone, + } +} + +func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion, enableConsumers bool, enableProducers bool) error { + laddr, err := net.ResolveUDPAddr("udp", local) + if err != nil { + return err + } + + pconn, err := net.ListenUDP("udp", laddr) + if err != nil { + return err + } + + err = pconn.SetWriteBuffer(0) + if err != nil { + panic(err) + } + + receivedConnections := make(map[ComparableUdpAddress]*Flow) + + go func() { + for ctx.Err() == nil { + buf := make([]byte, 6000) + + if err := pconn.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + panic(err) + } + n, addr, err := pconn.ReadFromUDP(buf) + if err != nil { + if e, ok := err.(net.Error); ok && e.Timeout() { + continue + } + panic(err) + } + + raddr := fromUdpAddress(*addr) + if f, exists := receivedConnections[raddr]; exists { + log.Println("existing flow. queuing...") + if err := f.queueDatagram(ctx, buf[:n]); err != nil { + + } + log.Println("queued") + continue + } + + v := v() + g := g() + + f := newFlow(c(), v) + + f.writer = pconn + f.raddr = addr + f.isAlive = true + + log.Printf("received new udp connection: %v\n", f) + + go f.earlyUpdateLoop(ctx, g, 0) + + receivedConnections[raddr] = &f + + if enableConsumers { + p.AddConsumer(ctx, &f, g) + } + if enableProducers { + p.AddProducer(ctx, &f, v) + } + + log.Println("handling...") + if err := f.queueDatagram(ctx, buf[:n]); err != nil { + return + } + log.Println("handled") + } + }() + + return nil +} diff --git a/udp/packet.go b/udp/packet.go new file mode 100644 index 0000000..08757d8 --- /dev/null +++ b/udp/packet.go @@ -0,0 +1,43 @@ +package udp + +import ( + "encoding/binary" + "mpbl3p/proxy" + "mpbl3p/shared" +) + +type Packet struct { + ack uint32 + nack uint32 + seq uint32 + + data proxy.Packet +} + +func UnmarshalPacket(b []byte) (p Packet, err error) { + if len(b) < 12 { + return Packet{}, shared.ErrNotEnoughBytes + } + + p.ack = binary.LittleEndian.Uint32(b[0:4]) + p.nack = binary.LittleEndian.Uint32(b[4:8]) + p.seq = binary.LittleEndian.Uint32(b[8:12]) + + p.data = proxy.SimplePacket(b[12:]) + return p, nil +} + +func (p Packet) Marshal() []byte { + data := p.data.Marshal() + header := make([]byte, 12) + + binary.LittleEndian.PutUint32(header[0:4], p.ack) + binary.LittleEndian.PutUint32(header[4:8], p.nack) + binary.LittleEndian.PutUint32(header[8:12], p.seq) + + return append(header, data...) +} + +func (p Packet) Contents() []byte { + return p.data.Contents() +} diff --git a/udp/packet_test.go b/udp/packet_test.go new file mode 100644 index 0000000..6ffc04c --- /dev/null +++ b/udp/packet_test.go @@ -0,0 +1,59 @@ +package udp + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "mpbl3p/proxy" + "testing" +) + +func TestPacket_Marshal(t *testing.T) { + testContent := []byte("A test string is the content of this packet.") + testPacket := Packet{ + ack: 18, + nack: 29, + seq: 431, + data: proxy.SimplePacket(testContent), + } + + t.Run("Length", func(t *testing.T) { + marshalled := testPacket.Marshal() + + // 12 header + 8 timestamp + assert.Len(t, marshalled, len(testContent)+12) + }) +} + +func TestUnmarshalPacket(t *testing.T) { + testContent := []byte("A test string is the content of this packet.") + testPacket := Packet{ + ack: 18, + nack: 29, + seq: 431, + data: proxy.SimplePacket(testContent), + } + testMarshalled := testPacket.Marshal() + + t.Run("Length", func(t *testing.T) { + p, err := UnmarshalPacket(testMarshalled) + + require.Nil(t, err) + assert.Len(t, p.Contents(), len(testContent)) + }) + + t.Run("Contents", func(t *testing.T) { + p, err := UnmarshalPacket(testMarshalled) + + require.Nil(t, err) + assert.Equal(t, p.Contents(), testContent) + }) + + t.Run("Header", func(t *testing.T) { + p, err := UnmarshalPacket(testMarshalled) + require.Nil(t, err) + + assert.Equal(t, p.ack, uint32(18)) + assert.Equal(t, p.nack, uint32(29)) + assert.Equal(t, p.seq, uint32(431)) + }) +} diff --git a/udp/wireshark_dissector.lua b/udp/wireshark_dissector.lua new file mode 100644 index 0000000..6803f97 --- /dev/null +++ b/udp/wireshark_dissector.lua @@ -0,0 +1,36 @@ +mpbl3p_udp = Proto("mpbl3p_udp", "Multi Path Proxy Custom UDP") + +ack_F = ProtoField.uint32("mpbl3p_udp.ack", "Acknowledgement") +nack_F = ProtoField.uint32("mpbl3p_udp.nack", "Negative Acknowledgement") +seq_F = ProtoField.uint32("mpbl3p_udp.seq", "Sequence Number") +time_F = ProtoField.absolute_time("mpbl3p_udp.time", "Timestamp") +proxied_F = ProtoField.bytes("mpbl3p_udp.data", "Proxied Data") + +mpbl3p_udp.fields = { ack_F, nack_F, seq_F, time_F, proxied_F } + +function mpbl3p_udp.dissector(buffer, pinfo, tree) + if buffer:len() < 20 then + return + end + + pinfo.cols.protocol = "MPBL3P_UDP" + + local ack = buffer(0, 4):le_uint() + local nack = buffer(4, 4):le_uint() + local seq = buffer(8, 4):le_uint() + + local unix_time = buffer(buffer:len() - 8, 8):le_uint64() + + local subtree = tree:add(mpbl3p_udp, buffer(), "Multi Path Proxy Header, SEQ: " .. seq .. " ACK: " .. ack .. " NACK: " .. nack) + + subtree:add(ack_F, ack) + subtree:add(nack_F, nack) + subtree:add(seq_F, seq) + subtree:add(time_F, NSTime.new(unix_time:tonumber())) + if buffer:len() > 20 then + subtree:add(proxied_F, buffer(12, buffer:len() - 12 - 8)) + end +end + +DissectorTable.get("udp.port"):add(4724, mpbl3p_udp) +DissectorTable.get("udp.port"):add(4725, mpbl3p_udp) diff --git a/utils/utils.go b/utils/utils.go deleted file mode 100644 index 2548b3b..0000000 --- a/utils/utils.go +++ /dev/null @@ -1,13 +0,0 @@ -package utils - -var NextId = make(chan int) - -func init() { - go func() { - i := 0 - for { - NextId <- i - i += 1 - } - }() -}