merge develop into master #21
34
.drone.yml
34
.drone.yml
@ -1,12 +1,18 @@
|
|||||||
|
---
|
||||||
kind: pipeline
|
kind: pipeline
|
||||||
type: docker
|
type: docker
|
||||||
name: default
|
name: default
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
- name: format
|
||||||
|
image: golang:1.16
|
||||||
|
commands:
|
||||||
|
- bash -c "gofmt -l . | wc -l | cmp -s <(echo 0) || (gofmt -l . && exit 1)"
|
||||||
|
|
||||||
- name: install
|
- name: install
|
||||||
image: golang:1.15
|
image: golang:1.16
|
||||||
environment:
|
environment:
|
||||||
GOPROXY: http://10.20.0.25:3142|direct
|
GOPROXY: http://containers.internal.hillion.co.uk:3142,direct
|
||||||
volumes:
|
volumes:
|
||||||
- name: cache
|
- name: cache
|
||||||
path: /go
|
path: /go
|
||||||
@ -14,7 +20,7 @@ steps:
|
|||||||
- go test -i ./...
|
- go test -i ./...
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
image: golang:1.15
|
image: golang:1.16
|
||||||
volumes:
|
volumes:
|
||||||
- name: cache
|
- name: cache
|
||||||
path: /go
|
path: /go
|
||||||
@ -22,7 +28,7 @@ steps:
|
|||||||
- go test ./...
|
- go test ./...
|
||||||
|
|
||||||
- name: build (debian)
|
- name: build (debian)
|
||||||
image: golang:1.15-buster
|
image: golang:1.16-buster
|
||||||
when:
|
when:
|
||||||
event:
|
event:
|
||||||
- push
|
- push
|
||||||
@ -30,7 +36,10 @@ steps:
|
|||||||
- name: cache
|
- name: cache
|
||||||
path: /go
|
path: /go
|
||||||
commands:
|
commands:
|
||||||
- go build
|
- GOOS=linux GOARCH=amd64 go build -o linux_amd64
|
||||||
|
- GOOS=linux GOARCH=arm GOARM=7 go build -o linux_arm_v7
|
||||||
|
- GOOS=freebsd GOARCH=amd64 go build -o freebsd_amd64
|
||||||
|
- GOOS=freebsd GOARCH=arm64 go build -o freebsd_arm64_v8a
|
||||||
|
|
||||||
- name: upload
|
- name: upload
|
||||||
image: minio/mc
|
image: minio/mc
|
||||||
@ -42,9 +51,18 @@ steps:
|
|||||||
SECRET_KEY:
|
SECRET_KEY:
|
||||||
from_secret: s3_secret_key
|
from_secret: s3_secret_key
|
||||||
commands:
|
commands:
|
||||||
- mc alias set s3 http://10.20.0.25:3900 $${ACCESS_KEY} $${SECRET_KEY}
|
- mc alias set s3 https://s3.us-west-001.backblazeb2.com $${ACCESS_KEY} $${SECRET_KEY}
|
||||||
- mc cp mpbl3p s3/dissertation/binaries/debian/${DRONE_BRANCH}
|
- mc cp linux_amd64 s3/dissertation/binaries/debian/${DRONE_BRANCH}_linux_amd64
|
||||||
|
- mc cp linux_arm_v7 s3/dissertation/binaries/debian/${DRONE_BRANCH}_linux_arm_v7
|
||||||
|
- mc cp freebsd_amd64 s3/dissertation/binaries/debian/${DRONE_BRANCH}_freebsd_amd64
|
||||||
|
- mc cp freebsd_arm64_v8a s3/dissertation/binaries/debian/${DRONE_BRANCH}_freebsd_arm64_v8a
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
- name: cache
|
- name: cache
|
||||||
temp: {}
|
temp: { }
|
||||||
|
|
||||||
|
---
|
||||||
|
kind: signature
|
||||||
|
hmac: 7960420c7d02f9bce56d6429b612676d24cbe1d1608cf44a77da9afc411eccb8
|
||||||
|
|
||||||
|
...
|
||||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,4 +1,4 @@
|
|||||||
config.ini
|
*.conf
|
||||||
logs/
|
logs/
|
||||||
|
|
||||||
# Created by https://www.toptal.com/developers/gitignore/api/intellij+all,go
|
# Created by https://www.toptal.com/developers/gitignore/api/intellij+all,go
|
||||||
|
129
README.md
129
README.md
@ -4,15 +4,130 @@
|
|||||||
### Linux
|
### Linux
|
||||||
#### Policy Based Routing
|
#### Policy Based Routing
|
||||||
|
|
||||||
ip route flush 11
|
ip route flush table 10
|
||||||
ip route add table 11 to 1.1.1.0/24 dev eth1
|
ip route add table 10 to 1.1.1.0/24 dev eth1
|
||||||
ip rule add from 1.1.1.4 table 11 priority 11
|
ip rule add from 1.1.1.4 table 10 priority 10
|
||||||
|
|
||||||
ip route flush 10
|
ip route flush table 11
|
||||||
ip route add table 10 to 1.1.1.0/24 dev eth2
|
ip route add table 11 to 1.1.1.0/24 dev eth2
|
||||||
ip rule add from 1.1.1.5 table 10 priority 10
|
ip rule add from 1.1.1.5 table 11 priority 11
|
||||||
|
|
||||||
#### ARP Flux
|
#### ARP Flux
|
||||||
|
|
||||||
sysctl -w net.ipv4.conf.all.arp_announce=1
|
sysctl -w net.ipv4.conf.all.arp_announce=1
|
||||||
sysctl -w net.ipv4.conf.all.arp_ignore=2
|
sysctl -w net.ipv4.conf.all.arp_ignore=1
|
||||||
|
|
||||||
|
See http://kb.linuxvirtualserver.org/wiki/Using_arp_announce/arp_ignore_to_disable_ARP
|
||||||
|
|
||||||
|
### Setup Scripts
|
||||||
|
These are functional setup scripts that make the application run as intended on Linux.
|
||||||
|
|
||||||
|
### Remote Portal
|
||||||
|
#### Pre-Start
|
||||||
|
|
||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
## Set up variables
|
||||||
|
REMOTE_PORTAL_ADDRESS=A.B.C.D
|
||||||
|
|
||||||
|
## IPv4 Forwarding
|
||||||
|
sysctl -w net.ipv4.ip_forward=1
|
||||||
|
sysctl -w net.ipv4.conf.eth0.proxy_arp=1
|
||||||
|
|
||||||
|
## Transfer the local routing table to a much lower priority
|
||||||
|
(ip rule show | grep '20:') > /dev/null || ip rule add from all table local priority 20
|
||||||
|
ip rule del priority 0 2> /dev/null || true
|
||||||
|
|
||||||
|
## Ports to route locally
|
||||||
|
|
||||||
|
### MPBL3P
|
||||||
|
ip rule del priority 1 2> /dev/null || true
|
||||||
|
ip rule add to "$REMOTE_PORTAL_ADDRESS" dport 1234 table local priority 1
|
||||||
|
|
||||||
|
### SSH
|
||||||
|
ip rule del priority 2 2> /dev/null || true
|
||||||
|
ip rule add to "$REMOTE_PORTAL_ADDRESS" dport 22 table local priority 2
|
||||||
|
|
||||||
|
#### Post-Start
|
||||||
|
|
||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
## Set up variables
|
||||||
|
REMOTE_PORTAL_ADDRESS=A.B.C.D
|
||||||
|
|
||||||
|
## Tunnel addr/up
|
||||||
|
ip addr add 172.19.152.2/31 dev nc0
|
||||||
|
ip link set up nc0
|
||||||
|
|
||||||
|
# Route packets to the interface but not for nc via the tunnel
|
||||||
|
ip route flush table 19
|
||||||
|
ip route add table 19 to "$REMOTE_PORTAL_ADDRESS" via 172.19.152.3 dev nc0
|
||||||
|
ip rule del priority 19 2> /dev/null || true
|
||||||
|
ip rule add to "$REMOTE_PORTAL_ADDRESS" table 19 priority 19
|
||||||
|
|
||||||
|
### Local Portal
|
||||||
|
#### Pre-Start
|
||||||
|
|
||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
## Set up variables
|
||||||
|
GATEWAY_INTERFACE=eth0
|
||||||
|
GATEWAY_ADDRESS=10.36.12.1
|
||||||
|
|
||||||
|
## Fix ARP
|
||||||
|
sysctl -w net.ipv4.conf.all.arp_announce=1
|
||||||
|
sysctl -w net.ipv4.conf.all.arp_ignore=1
|
||||||
|
|
||||||
|
## IPv4 Forwarding
|
||||||
|
sysctl -w net.ipv4.ip_forward=1
|
||||||
|
|
||||||
|
## Gateway Interface Setup
|
||||||
|
ip addr flush dev "$GATEWAY_INTERFACE"
|
||||||
|
ip addr add "$GATEWAY_ADDRESS"/32 dev "$GATEWAY_INTERFACE"
|
||||||
|
ip link set up "$GATEWAY_INTERFACE"
|
||||||
|
|
||||||
|
## Per-Interface Routing Tables
|
||||||
|
|
||||||
|
### 10.10.0.0/24
|
||||||
|
ip route flush table 10
|
||||||
|
ip route add table 10 default via 10.10.0.1
|
||||||
|
ip rule del priority 10 2> /dev/null || true
|
||||||
|
ip rule add from 10.10.0.0/24 table 10 priority 10
|
||||||
|
|
||||||
|
### 192.168.0.0/24
|
||||||
|
ip route flush table 11
|
||||||
|
ip route add table 11 default via 192.168.0.1
|
||||||
|
ip rule del priority 11 2> /dev/null || true
|
||||||
|
ip rule add from 192.168.0.0/24 table 11 priority 11
|
||||||
|
|
||||||
|
#### Post-Start
|
||||||
|
|
||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
## Set up variables
|
||||||
|
REMOTE_PORTAL_ADDRESS=A.B.C.D
|
||||||
|
GATEWAY_INTERFACE=eth0
|
||||||
|
|
||||||
|
## Tunnel Address and Enable
|
||||||
|
ip addr add 172.19.152.3/31 dev nc0
|
||||||
|
ip link set up nc0
|
||||||
|
|
||||||
|
## Route Outbound Packets Correctly
|
||||||
|
ip route flush table 20
|
||||||
|
ip route add table 20 default via 172.19.152.2 dev nc0
|
||||||
|
ip rule del priority 20 2> /dev/null || true
|
||||||
|
ip rule add from "$REMOTE_PORTAL_ADDRESS" iif "$GATEWAY_INTERFACE" table 20 priority 20
|
||||||
|
|
||||||
|
## Route Inbound Packets Correctly
|
||||||
|
ip route flush table 21
|
||||||
|
ip route add table 21 to "$REMOTE_PORTAL_ADDRESS" dev "$GATEWAY_INTERFACE"
|
||||||
|
ip rule del priority 21 2> /dev/null || true
|
||||||
|
ip rule add to "$REMOTE_PORTAL_ADDRESS" table 21 priority 21
|
||||||
|
|
||||||
|
#### Client
|
||||||
|
|
||||||
|
Connect to `GATEWAY_INTERFACE` and set the IP to `REMOTE_PORTAL_ADDRESS`/32 with a gateway of `GATEWAY_ADDRESS`.
|
||||||
|
@ -1,48 +1,57 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"mpbl3p/crypto"
|
||||||
|
"mpbl3p/crypto/sharedkey"
|
||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
"mpbl3p/tcp"
|
"mpbl3p/tcp"
|
||||||
"mpbl3p/tun"
|
"mpbl3p/udp"
|
||||||
|
"mpbl3p/udp/congestion"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: Delete this code as soon as an alternative is available
|
func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
|
||||||
type UselessMac struct{}
|
|
||||||
|
|
||||||
func (UselessMac) CodeLength() int {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (UselessMac) Generate([]byte) []byte {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u UselessMac) Verify([]byte, []byte) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Configuration) Build() (*proxy.Proxy, error) {
|
|
||||||
p := proxy.NewProxy(0)
|
p := proxy.NewProxy(0)
|
||||||
p.Generator = UselessMac{}
|
|
||||||
|
|
||||||
if c.Host.InterfaceName == "" {
|
var g func() proxy.MacGenerator
|
||||||
c.Host.InterfaceName = "nc%d"
|
var v func() proxy.MacVerifier
|
||||||
|
|
||||||
|
switch c.Host.Crypto {
|
||||||
|
case "None":
|
||||||
|
g = func() proxy.MacGenerator { return crypto.None{} }
|
||||||
|
v = func() proxy.MacVerifier { return crypto.None{} }
|
||||||
|
case "Blake2s":
|
||||||
|
key, err := base64.StdEncoding.DecodeString(c.Host.SharedKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := sharedkey.NewBlake2s(key); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
g = func() proxy.MacGenerator {
|
||||||
|
g, _ := sharedkey.NewBlake2s(key)
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
v = func() proxy.MacVerifier {
|
||||||
|
v, _ := sharedkey.NewBlake2s(key)
|
||||||
|
return v
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ss, err := tun.NewTun(c.Host.InterfaceName, 1500)
|
p.Source = source
|
||||||
if err != nil {
|
p.Sink = sink
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
p.Source = ss
|
|
||||||
p.Sink = ss
|
|
||||||
|
|
||||||
for _, peer := range c.Peers {
|
for _, peer := range c.Peers {
|
||||||
switch peer.Method {
|
switch peer.Method {
|
||||||
case "TCP":
|
case "TCP":
|
||||||
err := buildTcp(p, peer)
|
if err := buildTcp(ctx, p, peer, g, v); err != nil {
|
||||||
if err != nil {
|
return nil, err
|
||||||
|
}
|
||||||
|
case "UDP":
|
||||||
|
if err := buildUdp(ctx, p, peer, g, v); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -51,20 +60,82 @@ func (c Configuration) Build() (*proxy.Proxy, error) {
|
|||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTcp(p *proxy.Proxy, peer Peer) error {
|
func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
|
||||||
if peer.RemoteHost != "" {
|
var laddr func() string
|
||||||
f, err := tcp.InitiateFlow(
|
if peer.LocalPort == 0 {
|
||||||
fmt.Sprintf("%s:", peer.LocalHost),
|
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
|
||||||
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
|
} else {
|
||||||
)
|
laddr = func() string { return fmt.Sprintf("%s:%d", peer.GetLocalHost(), peer.LocalPort) }
|
||||||
|
|
||||||
p.AddConsumer(f)
|
|
||||||
p.AddProducer(f, UselessMac{})
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := tcp.NewListener(p, fmt.Sprintf("%s:%d", peer.LocalHost, peer.LocalPort), UselessMac{})
|
if peer.RemoteHost != "" {
|
||||||
|
f, err := tcp.InitiateFlow(laddr, fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !peer.DisableConsumer {
|
||||||
|
p.AddConsumer(ctx, f, g())
|
||||||
|
}
|
||||||
|
if !peer.DisableProducer {
|
||||||
|
p.AddProducer(ctx, f, v())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := tcp.NewListener(ctx, p, laddr(), v, g, !peer.DisableConsumer, !peer.DisableProducer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
|
||||||
|
var laddr func() string
|
||||||
|
if peer.LocalPort == 0 {
|
||||||
|
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
|
||||||
|
} else {
|
||||||
|
laddr = func() string { return fmt.Sprintf("%s:%d", peer.GetLocalHost(), peer.LocalPort) }
|
||||||
|
}
|
||||||
|
|
||||||
|
var c func() udp.Congestion
|
||||||
|
switch peer.Congestion {
|
||||||
|
case "None":
|
||||||
|
c = func() udp.Congestion { return congestion.NewNone() }
|
||||||
|
default:
|
||||||
|
fallthrough
|
||||||
|
case "NewReno":
|
||||||
|
c = func() udp.Congestion { return congestion.NewNewReno() }
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.RemoteHost != "" {
|
||||||
|
f, err := udp.InitiateFlow(
|
||||||
|
laddr,
|
||||||
|
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
|
||||||
|
v(),
|
||||||
|
g(),
|
||||||
|
c(),
|
||||||
|
time.Duration(peer.KeepAlive)*time.Second,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !peer.DisableConsumer {
|
||||||
|
p.AddConsumer(ctx, f, g())
|
||||||
|
}
|
||||||
|
if !peer.DisableProducer {
|
||||||
|
p.AddProducer(ctx, f, v())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := udp.NewListener(ctx, p, laddr(), v, g, c, !peer.DisableConsumer, !peer.DisableProducer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1,32 +1,96 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import "github.com/go-playground/validator/v10"
|
import (
|
||||||
|
"github.com/go-playground/validator/v10"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
var v = validator.New()
|
var v = validator.New()
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if err := v.RegisterValidation("iface", func(fl validator.FieldLevel) bool {
|
||||||
|
name, ok := fl.Field().Interface().(string)
|
||||||
|
if ok && name != "" {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error getting interfaces: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, i := range ifaces {
|
||||||
|
if i.Name == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Configuration struct {
|
type Configuration struct {
|
||||||
Host Host
|
Host Host
|
||||||
Peers []Peer `validate:"dive"`
|
Peers []Peer `validate:"dive"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Host struct {
|
type Host struct {
|
||||||
PrivateKey string `validate:"required"`
|
Crypto string `validate:"required,oneof=None Blake2s"`
|
||||||
InterfaceName string
|
SharedKey string `validate:"required_if=Crypto Blake2s"`
|
||||||
|
MTU uint `validate:"required,min=576"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
PublicKey string `validate:"required"`
|
Method string `validate:"oneof=TCP UDP"`
|
||||||
Method string `validate:"oneof=TCP"`
|
|
||||||
|
|
||||||
LocalHost string `validate:"omitempty,ip"`
|
LocalHost string `validate:"omitempty,ip|iface"`
|
||||||
LocalPort uint `validate:"max=65535"`
|
LocalPort uint `validate:"max=65535"`
|
||||||
|
|
||||||
RemoteHost string `validate:"required_with=RemotePort,omitempty,fqdn|ip"`
|
RemoteHost string `validate:"required_with=RemotePort,omitempty,fqdn|ip"`
|
||||||
RemotePort uint `validate:"required_with=RemoteHost,omitempty,max=65535"`
|
RemotePort uint `validate:"required_with=RemoteHost,omitempty,max=65535"`
|
||||||
|
|
||||||
|
Congestion string `validate:"required_unless=Method TCP,omitempty,oneof=NewReno None"`
|
||||||
|
|
||||||
KeepAlive uint
|
KeepAlive uint
|
||||||
Timeout uint
|
Timeout uint
|
||||||
RetryWait uint
|
RetryWait uint
|
||||||
|
|
||||||
|
DisableConsumer bool `validate:"omitempty,nefield=DisableProducer"`
|
||||||
|
DisableProducer bool `validate:"omitempty,nefield=DisableConsumer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Peer) GetLocalHost() string {
|
||||||
|
if p.LocalHost == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.Var(p.LocalHost, "ip"); err == nil {
|
||||||
|
return p.LocalHost
|
||||||
|
}
|
||||||
|
|
||||||
|
iface, err := net.InterfaceByName(p.LocalHost)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface != nil {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addrs) > 0 {
|
||||||
|
addr := addrs[0].String()
|
||||||
|
addr = strings.Split(addr, "/")[0]
|
||||||
|
log.Printf("resolved interface `%s` to `%v`", p.LocalHost, addr)
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "invalid"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Configuration) Validate() error {
|
func (c Configuration) Validate() error {
|
||||||
|
15
crypto/none.go
Normal file
15
crypto/none.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
type None struct{}
|
||||||
|
|
||||||
|
func (None) CodeLength() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (None) Generate([]byte) []byte {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (None) Verify([]byte, []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
40
crypto/sharedkey/blake2s.go
Normal file
40
crypto/sharedkey/blake2s.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package sharedkey
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
|
"mpbl3p/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Blake2s struct {
|
||||||
|
key []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBlake2s(key []byte) (*Blake2s, error) {
|
||||||
|
_, err := blake2s.New128(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Blake2s{key: key}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b Blake2s) CodeLength() int {
|
||||||
|
return blake2s.Size128
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b Blake2s) Generate(d []byte) []byte {
|
||||||
|
h, _ := blake2s.New128(b.key)
|
||||||
|
h.Write(d)
|
||||||
|
return h.Sum([]byte{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b Blake2s) Verify(d []byte, s []byte) error {
|
||||||
|
h, _ := blake2s.New128(b.key)
|
||||||
|
h.Write(d)
|
||||||
|
sum := h.Sum([]byte{})
|
||||||
|
if !bytes.Equal(sum, s) {
|
||||||
|
return shared.ErrBadChecksum
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
45
crypto/sharedkey/blake2s_test.go
Normal file
45
crypto/sharedkey/blake2s_test.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package sharedkey
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"math/rand"
|
||||||
|
"mpbl3p/shared"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBlake2s_GenerateVerify(t *testing.T) {
|
||||||
|
t.Run("GeneratedVerifies", func(t *testing.T) {
|
||||||
|
// ASSIGN
|
||||||
|
key := make([]byte, 16)
|
||||||
|
rand.Read(key)
|
||||||
|
buf := make([]byte, 500)
|
||||||
|
rand.Read(buf)
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
b := Blake2s{key}
|
||||||
|
code := b.Generate(buf)
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
err := b.Verify(buf, code)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("FlippedBitFailsVerify", func(t *testing.T) {
|
||||||
|
// ASSIGN
|
||||||
|
key := make([]byte, 16)
|
||||||
|
rand.Read(key)
|
||||||
|
buf := make([]byte, 500)
|
||||||
|
rand.Read(buf)
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
b := Blake2s{key}
|
||||||
|
code := b.Generate(buf)
|
||||||
|
|
||||||
|
offset := rand.Intn(len(buf) * 8)
|
||||||
|
buf[offset/8] ^= 1 << (offset % 8)
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
err := b.Verify(buf, code)
|
||||||
|
assert.Equal(t, shared.ErrBadChecksum, err)
|
||||||
|
})
|
||||||
|
}
|
41
flags/flags.go
Normal file
41
flags/flags.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package flags
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
goflags "github.com/jessevdk/go-flags"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
var PrintedHelpErr = goflags.ErrHelp
|
||||||
|
var NotEnoughArgs = errors.New("not enough arguments")
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
Foreground bool `short:"f" long:"foreground" description:"Run in the foreground"`
|
||||||
|
ConfigFile string `short:"c" long:"config" description:"Configuration file location" value-name:"FILE"`
|
||||||
|
PidFile string `short:"p" long:"pid" description:"PID file location"`
|
||||||
|
|
||||||
|
Positional struct {
|
||||||
|
InterfaceName string `required:"yes" positional-arg-name:"INTERFACE-NAME" description:"Interface name"`
|
||||||
|
} `positional-args:"yes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseFlags() (*Options, error) {
|
||||||
|
o := new(Options)
|
||||||
|
parser := goflags.NewParser(o, goflags.Default)
|
||||||
|
|
||||||
|
_, err := parser.Parse()
|
||||||
|
if err != nil {
|
||||||
|
parser.WriteHelp(os.Stdout)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.ConfigFile == "" {
|
||||||
|
o.ConfigFile = fmt.Sprintf(DefaultConfigFile, o.Positional.InterfaceName)
|
||||||
|
}
|
||||||
|
if o.PidFile == "" {
|
||||||
|
o.PidFile = fmt.Sprintf(DefaultPidFile, o.Positional.InterfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return o, nil
|
||||||
|
}
|
4
flags/locs_freebsd.go
Normal file
4
flags/locs_freebsd.go
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
package flags
|
||||||
|
|
||||||
|
const DefaultConfigFile = "/usr/local/etc/netcombiner/%s"
|
||||||
|
const DefaultPidFile = "/var/run/netcombiner/%s.pid"
|
4
flags/locs_linux.go
Normal file
4
flags/locs_linux.go
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
package flags
|
||||||
|
|
||||||
|
const DefaultConfigFile = "/etc/netcombiner/%s"
|
||||||
|
const DefaultPidFile = "/var/run/netcombiner/%s.pid"
|
8
go.mod
8
go.mod
@ -3,10 +3,12 @@ module mpbl3p
|
|||||||
go 1.15
|
go 1.15
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-playground/assert/v2 v2.0.1
|
github.com/go-playground/validator/v10 v10.6.0
|
||||||
github.com/go-playground/validator/v10 v10.4.1
|
github.com/jessevdk/go-flags v1.5.0
|
||||||
github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab
|
|
||||||
github.com/smartystreets/goconvey v1.6.4 // indirect
|
github.com/smartystreets/goconvey v1.6.4 // indirect
|
||||||
github.com/stretchr/testify v1.4.0
|
github.com/stretchr/testify v1.4.0
|
||||||
|
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
|
||||||
|
golang.org/x/net v0.0.0-20210326060303-6b1517762897 // indirect
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20201118132417-da19db415a58
|
||||||
gopkg.in/ini.v1 v1.62.0
|
gopkg.in/ini.v1 v1.62.0
|
||||||
)
|
)
|
||||||
|
28
go.sum
28
go.sum
@ -8,14 +8,16 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87
|
|||||||
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
|
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
|
||||||
github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE=
|
github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE=
|
||||||
github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4=
|
github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4=
|
||||||
|
github.com/go-playground/validator/v10 v10.6.0 h1:UGIt4xR++fD9QrBOoo/ascJfGe3AGHEB9s6COnss4Rk=
|
||||||
|
github.com/go-playground/validator/v10 v10.6.0/go.mod h1:xm76BBt941f7yWdGnI2DVPFFg1UK3YY04qifoXU3lOk=
|
||||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
|
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
|
||||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||||
|
github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc=
|
||||||
|
github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4=
|
||||||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||||
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||||
github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
|
github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
|
||||||
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
|
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
|
||||||
github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab h1:dAXDRtXYxj4sTR5WeRuTFJGH18QMT6AUpUgRwedI6es=
|
|
||||||
github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab/go.mod h1:N5a/Ll2ZNk5wjiLNW9LIiNtO9RNYcaYmcXSYKMYrlDg=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
|
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
|
||||||
@ -26,17 +28,35 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
|||||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=
|
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
|
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||||
|
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w=
|
||||||
|
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
|
golang.org/x/net v0.0.0-20210326060303-6b1517762897 h1:KrsHThm5nFk34YtATK1LsThyGhGbGe1olrte/HInHvs=
|
||||||
|
golang.org/x/net v0.0.0-20210326060303-6b1517762897/go.mod h1:uSPa2vr4CLtc/ILN5odXGNXS6mhrKVzTaCXzk9m6W3k=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
|
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210324051608-47abb6519492 h1:Paq34FxTluEPvVyayQqMPgHm+vTOrIifmcYxFBx9TLg=
|
||||||
|
golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||||
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
|
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
|
||||||
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20201118132417-da19db415a58 h1:HiPOx0boQr3qv0HkZ4fGLtTXJ5tmkbv0d8UmkNcxdv0=
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20201118132417-da19db415a58/go.mod h1:Dz+cq5bnrai9EpgYj4GDof/+qaGzbRWbeaAOs1bUYa0=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU=
|
gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU=
|
||||||
|
127
main.go
127
main.go
@ -1,36 +1,155 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"mpbl3p/config"
|
"mpbl3p/config"
|
||||||
|
"mpbl3p/flags"
|
||||||
|
"mpbl3p/tun"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strconv"
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ENV_NC_TUN_FD = "NC_TUN_FD"
|
||||||
|
ENV_NC_CONFIG_PATH = "NC_CONFIG_PATH"
|
||||||
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
|
log.SetFlags(log.Ldate | log.Ltime | log.Llongfile)
|
||||||
|
|
||||||
|
if _, exists := os.LookupEnv(ENV_NC_TUN_FD); !exists {
|
||||||
|
// we are the parent process
|
||||||
|
// 1) process arguments
|
||||||
|
// 2) validate config
|
||||||
|
// 2) create a tun adapter
|
||||||
|
// 3) spawn a child
|
||||||
|
// 4) exit
|
||||||
|
|
||||||
|
o, err := flags.ParseFlags()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, flags.PrintedHelpErr) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("loading config...")
|
||||||
|
|
||||||
|
c, err := config.LoadConfig(o.ConfigFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("error validating config: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("creating tun adapter...")
|
||||||
|
t, err := tun.NewTun(o.Positional.InterfaceName, int(c.Host.MTU))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.Foreground {
|
||||||
|
if err := os.Setenv(ENV_NC_TUN_FD, fmt.Sprintf("%d", t.File().Fd())); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := os.Setenv(ENV_NC_CONFIG_PATH, o.ConfigFile); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
log.Println("switching to foreground")
|
||||||
|
goto FOREGROUND
|
||||||
|
}
|
||||||
|
|
||||||
|
files := make([]*os.File, 4)
|
||||||
|
files[0], _ = os.Open(os.DevNull) // stdin
|
||||||
|
files[1], _ = os.Open(os.DevNull) // stderr
|
||||||
|
files[2], _ = os.Open(os.DevNull) // stdout
|
||||||
|
files[3] = t.File()
|
||||||
|
|
||||||
|
env := os.Environ()
|
||||||
|
env = append(env, fmt.Sprintf("%s=3", ENV_NC_TUN_FD))
|
||||||
|
env = append(env, fmt.Sprintf("%s=%s", ENV_NC_CONFIG_PATH, o.ConfigFile))
|
||||||
|
|
||||||
|
attr := os.ProcAttr{
|
||||||
|
Env: env,
|
||||||
|
Files: files,
|
||||||
|
}
|
||||||
|
|
||||||
|
path, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
process, err := os.StartProcess(
|
||||||
|
path,
|
||||||
|
os.Args,
|
||||||
|
&attr,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pidFile, err := os.Create(o.PidFile)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(pidFile, "%d", process.Pid); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = process.Release()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// we are the child process
|
||||||
|
// 1) recreate tun adapter from file descriptor
|
||||||
|
// 2) launch proxy
|
||||||
|
|
||||||
|
FOREGROUND:
|
||||||
|
|
||||||
log.Println("loading config...")
|
log.Println("loading config...")
|
||||||
|
|
||||||
c, err := config.LoadConfig("config.ini")
|
c, err := config.LoadConfig(os.Getenv(ENV_NC_CONFIG_PATH))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Println("connecting tun adapter...")
|
||||||
|
tunFd, err := strconv.ParseUint(os.Getenv(ENV_NC_TUN_FD), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err := tun.NewFromFile(uintptr(tunFd), int(c.Host.MTU))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := t.Close(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
log.Println("building config...")
|
log.Println("building config...")
|
||||||
p, err := c.Build()
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
p, err := c.Build(ctx, t, t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("starting...")
|
log.Println("starting proxy...")
|
||||||
p.Start()
|
p.Start()
|
||||||
|
|
||||||
log.Println("running")
|
log.Println("proxy started")
|
||||||
|
|
||||||
signals := make(chan os.Signal)
|
signals := make(chan os.Signal)
|
||||||
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
|
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
|
||||||
<-signals
|
<-signals
|
||||||
|
log.Println("exiting...")
|
||||||
}
|
}
|
||||||
|
@ -1,67 +0,0 @@
|
|||||||
package mocks
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
type MockPerfectBiConn struct {
|
|
||||||
directionA chan byte
|
|
||||||
directionB chan byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMockPerfectBiConn(bufSize int) MockPerfectBiConn {
|
|
||||||
return MockPerfectBiConn{
|
|
||||||
directionA: make(chan byte, bufSize),
|
|
||||||
directionB: make(chan byte, bufSize),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bc MockPerfectBiConn) SideA() MockPerfectConn {
|
|
||||||
return MockPerfectConn{inbound: bc.directionA, outbound: bc.directionB}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bc MockPerfectBiConn) SideB() MockPerfectConn {
|
|
||||||
return MockPerfectConn{inbound: bc.directionB, outbound: bc.directionA}
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockPerfectConn struct {
|
|
||||||
inbound chan byte
|
|
||||||
outbound chan byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c MockPerfectConn) SetWriteDeadline(time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c MockPerfectConn) Read(p []byte) (n int, err error) {
|
|
||||||
for i := range p {
|
|
||||||
if i == 0 {
|
|
||||||
p[i] = <-c.inbound
|
|
||||||
} else {
|
|
||||||
select {
|
|
||||||
case b := <-c.inbound:
|
|
||||||
p[i] = b
|
|
||||||
default:
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c MockPerfectConn) Write(p []byte) (n int, err error) {
|
|
||||||
for _, b := range p {
|
|
||||||
c.outbound <- b
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c MockPerfectConn) NonBlockingRead(p []byte) (n int, err error) {
|
|
||||||
for i := range p {
|
|
||||||
select {
|
|
||||||
case b := <-c.inbound:
|
|
||||||
p[i] = b
|
|
||||||
default:
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
@ -1,6 +1,8 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import "mpbl3p/shared"
|
import (
|
||||||
|
"mpbl3p/shared"
|
||||||
|
)
|
||||||
|
|
||||||
type AlmostUselessMac struct{}
|
type AlmostUselessMac struct{}
|
||||||
|
|
||||||
|
62
mocks/packetconn.go
Normal file
62
mocks/packetconn.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package mocks
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
type MockPerfectBiPacketConn struct {
|
||||||
|
directionA chan []byte
|
||||||
|
directionB chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockPerfectBiPacketConn(bufSize int) MockPerfectBiPacketConn {
|
||||||
|
return MockPerfectBiPacketConn{
|
||||||
|
directionA: make(chan []byte, bufSize),
|
||||||
|
directionB: make(chan []byte, bufSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc MockPerfectBiPacketConn) SideA() MockPerfectPacketConn {
|
||||||
|
return MockPerfectPacketConn{inbound: bc.directionA, outbound: bc.directionB}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc MockPerfectBiPacketConn) SideB() MockPerfectPacketConn {
|
||||||
|
return MockPerfectPacketConn{inbound: bc.directionB, outbound: bc.directionA}
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockPerfectPacketConn struct {
|
||||||
|
inbound chan []byte
|
||||||
|
outbound chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) Write(b []byte) (int, error) {
|
||||||
|
c.outbound <- b
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||||
|
c.outbound <- b
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) LocalAddr() net.Addr {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
|
||||||
|
p := <-c.inbound
|
||||||
|
return copy(b, p), &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) NonBlockingRead(p []byte) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case b := <-c.inbound:
|
||||||
|
return copy(p, b), nil
|
||||||
|
default:
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
}
|
95
mocks/streamconn.go
Normal file
95
mocks/streamconn.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockPerfectBiStreamConn struct {
|
||||||
|
directionA chan byte
|
||||||
|
directionB chan byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockPerfectBiStreamConn(bufSize int) MockPerfectBiStreamConn {
|
||||||
|
return MockPerfectBiStreamConn{
|
||||||
|
directionA: make(chan byte, bufSize),
|
||||||
|
directionB: make(chan byte, bufSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc MockPerfectBiStreamConn) SideA() MockPerfectStreamConn {
|
||||||
|
return MockPerfectStreamConn{inbound: bc.directionA, outbound: bc.directionB}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc MockPerfectBiStreamConn) SideB() MockPerfectStreamConn {
|
||||||
|
return MockPerfectStreamConn{inbound: bc.directionB, outbound: bc.directionA}
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockPerfectStreamConn struct {
|
||||||
|
inbound chan byte
|
||||||
|
outbound chan byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn interface {
|
||||||
|
Read(b []byte) (n int, err error)
|
||||||
|
Write(b []byte) (n int, err error)
|
||||||
|
SetWriteDeadline(time.Time) error
|
||||||
|
|
||||||
|
// For printing
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
RemoteAddr() net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectStreamConn) Read(p []byte) (n int, err error) {
|
||||||
|
for i := range p {
|
||||||
|
if i == 0 {
|
||||||
|
p[i] = <-c.inbound
|
||||||
|
} else {
|
||||||
|
select {
|
||||||
|
case b := <-c.inbound:
|
||||||
|
p[i] = b
|
||||||
|
default:
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectStreamConn) Write(p []byte) (n int, err error) {
|
||||||
|
for _, b := range p {
|
||||||
|
c.outbound <- b
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectStreamConn) SetWriteDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only used for printing flow information
|
||||||
|
func (c MockPerfectStreamConn) LocalAddr() net.Addr {
|
||||||
|
return &net.TCPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 499,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectStreamConn) RemoteAddr() net.Addr {
|
||||||
|
return &net.TCPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 500,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectStreamConn) NonBlockingRead(p []byte) (n int, err error) {
|
||||||
|
for i := range p {
|
||||||
|
select {
|
||||||
|
case b := <-c.inbound:
|
||||||
|
p[i] = b
|
||||||
|
default:
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
@ -5,56 +5,42 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Packet struct {
|
type Packet interface {
|
||||||
Data []byte
|
Marshal() []byte
|
||||||
timestamp time.Time
|
Contents() []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a packet from the raw data of an IP packet
|
type SimplePacket []byte
|
||||||
func NewPacket(data []byte) Packet {
|
|
||||||
return Packet{
|
|
||||||
Data: data,
|
|
||||||
timestamp: time.Now(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// rebuild a packet from the wrapped format
|
|
||||||
func UnmarshalPacket(raw []byte, verifier MacVerifier) (Packet, error) {
|
|
||||||
// the MAC is the last N bytes
|
|
||||||
data := raw[:len(raw)-verifier.CodeLength()]
|
|
||||||
sum := raw[len(raw)-verifier.CodeLength():]
|
|
||||||
|
|
||||||
if err := verifier.Verify(data, sum); err != nil {
|
|
||||||
return Packet{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
p := Packet{
|
|
||||||
Data: data[:len(data)-8],
|
|
||||||
}
|
|
||||||
|
|
||||||
unixTime := int64(binary.LittleEndian.Uint64(data[len(data)-8:]))
|
|
||||||
p.timestamp = time.Unix(unixTime, 0)
|
|
||||||
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the raw data of the IP packet
|
// get the raw data of the IP packet
|
||||||
func (p Packet) Raw() []byte {
|
func (p SimplePacket) Marshal() []byte {
|
||||||
return p.Data
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// produce the wrapped format of a packet
|
func (p SimplePacket) Contents() []byte {
|
||||||
func (p Packet) Marshal(generator MacGenerator) []byte {
|
return p
|
||||||
// length of data + length of timestamp (8 byte) + length of checksum
|
}
|
||||||
slice := make([]byte, len(p.Data)+8+generator.CodeLength())
|
|
||||||
|
func AppendMac(b []byte, g MacGenerator) []byte {
|
||||||
copy(slice, p.Data)
|
footer := make([]byte, 8)
|
||||||
|
unixTime := uint64(time.Now().Unix())
|
||||||
unixTime := uint64(p.timestamp.Unix())
|
binary.LittleEndian.PutUint64(footer, unixTime)
|
||||||
binary.LittleEndian.PutUint64(slice[len(p.Data):], unixTime)
|
|
||||||
|
b = append(b, footer...)
|
||||||
mac := generator.Generate(slice)
|
|
||||||
copy(slice[len(p.Data)+8:], mac)
|
mac := g.Generate(b)
|
||||||
|
return append(b, mac...)
|
||||||
return slice
|
}
|
||||||
|
|
||||||
|
func StripMac(b []byte, v MacVerifier) ([]byte, error) {
|
||||||
|
data := b[:len(b)-v.CodeLength()]
|
||||||
|
sum := b[len(b)-v.CodeLength():]
|
||||||
|
|
||||||
|
if err := v.Verify(data, sum); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Verify timestamp
|
||||||
|
|
||||||
|
return data[:len(data)-8], nil
|
||||||
}
|
}
|
||||||
|
@ -4,31 +4,59 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"mpbl3p/mocks"
|
"mpbl3p/mocks"
|
||||||
|
"mpbl3p/shared"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPacket_Marshal(t *testing.T) {
|
func TestAppendMac(t *testing.T) {
|
||||||
testContent := []byte("A test string is the content of this packet.")
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
testPacket := NewPacket(testContent)
|
|
||||||
testMac := mocks.AlmostUselessMac{}
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
testPacket := SimplePacket(testContent)
|
||||||
|
testMarshalled := testPacket.Marshal()
|
||||||
|
|
||||||
|
appended := AppendMac(testMarshalled, testMac)
|
||||||
|
|
||||||
t.Run("Length", func(t *testing.T) {
|
t.Run("Length", func(t *testing.T) {
|
||||||
marshalled := testPacket.Marshal(testMac)
|
assert.Len(t, appended, len(testMarshalled)+8+4)
|
||||||
|
})
|
||||||
|
|
||||||
assert.Len(t, marshalled, len(testContent)+8+4)
|
t.Run("Mac", func(t *testing.T) {
|
||||||
|
assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled)+8:])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Original", func(t *testing.T) {
|
||||||
|
assert.Equal(t, testMarshalled, appended[:len(testMarshalled)])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalPacket(t *testing.T) {
|
func TestStripMac(t *testing.T) {
|
||||||
testContent := []byte("A test string is the content of this packet.")
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
testPacket := NewPacket(testContent)
|
|
||||||
testMac := mocks.AlmostUselessMac{}
|
testMac := mocks.AlmostUselessMac{}
|
||||||
testMarshalled := testPacket.Marshal(testMac)
|
testPacket := SimplePacket(testContent)
|
||||||
|
testMarshalled := testPacket.Marshal()
|
||||||
|
|
||||||
|
appended := AppendMac(testMarshalled, testMac)
|
||||||
|
|
||||||
t.Run("Length", func(t *testing.T) {
|
t.Run("Length", func(t *testing.T) {
|
||||||
p, err := UnmarshalPacket(testMarshalled, testMac)
|
cut, err := StripMac(appended, testMac)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Len(t, p.Raw(), len(testContent))
|
assert.Len(t, cut, len(testMarshalled))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IncorrectMac", func(t *testing.T) {
|
||||||
|
badMac := make([]byte, len(testMarshalled)+4)
|
||||||
|
copy(badMac, testMarshalled)
|
||||||
|
copy(badMac[:len(testMarshalled)], "dcba")
|
||||||
|
_, err := StripMac(badMac, testMac)
|
||||||
|
|
||||||
|
assert.Error(t, err, shared.ErrBadChecksum)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Original", func(t *testing.T) {
|
||||||
|
cut, err := StripMac(appended, testMac)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Equal(t, testMarshalled, cut)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,22 +1,24 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Producer interface {
|
type Producer interface {
|
||||||
IsAlive() bool
|
IsAlive() bool
|
||||||
Produce(MacVerifier) (Packet, error)
|
Produce(context.Context, MacVerifier) (Packet, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Consumer interface {
|
type Consumer interface {
|
||||||
IsAlive() bool
|
IsAlive() bool
|
||||||
Consume(Packet, MacGenerator) error
|
Consume(context.Context, Packet, MacGenerator) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Reconnectable interface {
|
type Reconnectable interface {
|
||||||
Reconnect() error
|
Reconnect(context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Source interface {
|
type Source interface {
|
||||||
@ -31,8 +33,6 @@ type Proxy struct {
|
|||||||
Source Source
|
Source Source
|
||||||
Sink Sink
|
Sink Sink
|
||||||
|
|
||||||
Generator MacGenerator
|
|
||||||
|
|
||||||
proxyChan chan Packet
|
proxyChan chan Packet
|
||||||
sinkChan chan Packet
|
sinkChan chan Packet
|
||||||
}
|
}
|
||||||
@ -67,7 +67,7 @@ func (p Proxy) Start() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p Proxy) AddConsumer(c Consumer) {
|
func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) {
|
||||||
go func() {
|
go func() {
|
||||||
_, reconnectable := c.(Reconnectable)
|
_, reconnectable := c.(Reconnectable)
|
||||||
|
|
||||||
@ -75,28 +75,42 @@ func (p Proxy) AddConsumer(c Consumer) {
|
|||||||
if reconnectable {
|
if reconnectable {
|
||||||
var err error
|
var err error
|
||||||
for once := true; err != nil || once; once = false {
|
for once := true; err != nil || once; once = false {
|
||||||
log.Printf("attempting to connect `%v`\n", c)
|
log.Printf("attempting to connect consumer `%v`\n", c)
|
||||||
err = c.(Reconnectable).Reconnect()
|
err = c.(Reconnectable).Reconnect(ctx)
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("closed consumer `%v` (context)\n", c)
|
||||||
|
return
|
||||||
|
}
|
||||||
if !once {
|
if !once {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("connected `%v`\n", c)
|
log.Printf("connected consumer `%v`\n", c)
|
||||||
}
|
}
|
||||||
|
|
||||||
for c.IsAlive() {
|
for c.IsAlive() {
|
||||||
if err := c.Consume(<-p.proxyChan, p.Generator); err != nil {
|
select {
|
||||||
log.Println(err)
|
case <-ctx.Done():
|
||||||
break
|
log.Printf("closed consumer `%v` (context)\n", c)
|
||||||
|
return
|
||||||
|
case packet := <-p.proxyChan:
|
||||||
|
if err := c.Consume(ctx, packet, g); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("closed consumer `%v` (context)\n", c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println(err)
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("closed connection `%v`\n", c)
|
log.Printf("closed consumer `%v`\n", c)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
|
func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) {
|
||||||
go func() {
|
go func() {
|
||||||
_, reconnectable := pr.(Reconnectable)
|
_, reconnectable := pr.(Reconnectable)
|
||||||
|
|
||||||
@ -104,25 +118,42 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
|
|||||||
if reconnectable {
|
if reconnectable {
|
||||||
var err error
|
var err error
|
||||||
for once := true; err != nil || once; once = false {
|
for once := true; err != nil || once; once = false {
|
||||||
log.Printf("attempting to connect `%v`\n", pr)
|
log.Printf("attempting to connect producer `%v`\n", pr)
|
||||||
err = pr.(Reconnectable).Reconnect()
|
err = pr.(Reconnectable).Reconnect(ctx)
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("closed producer `%v` (context)\n", pr)
|
||||||
|
return
|
||||||
|
}
|
||||||
if !once {
|
if !once {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
log.Printf("connected `%v`\n", pr)
|
log.Printf("connected producer `%v`\n", pr)
|
||||||
}
|
}
|
||||||
|
|
||||||
for pr.IsAlive() {
|
for pr.IsAlive() {
|
||||||
if packet, err := pr.Produce(v); err != nil {
|
if packet, err := pr.Produce(ctx, v); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("closed producer `%v` (context)\n", pr)
|
||||||
|
return
|
||||||
|
}
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
p.sinkChan <- packet
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Printf("closed producer `%v` (context)\n", pr)
|
||||||
|
return
|
||||||
|
case p.sinkChan <- packet:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("closed connection `%v`\n", pr)
|
log.Printf("closed producer `%v`\n", pr)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -4,3 +4,4 @@ import "errors"
|
|||||||
|
|
||||||
var ErrBadChecksum = errors.New("the packet had a bad checksum")
|
var ErrBadChecksum = errors.New("the packet had a bad checksum")
|
||||||
var ErrDeadConnection = errors.New("the connection is dead")
|
var ErrDeadConnection = errors.New("the connection is dead")
|
||||||
|
var ErrNotEnoughBytes = errors.New("not enough bytes")
|
||||||
|
198
tcp/flow.go
198
tcp/flow.go
@ -1,8 +1,10 @@
|
|||||||
package tcp
|
package tcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
"mpbl3p/shared"
|
"mpbl3p/shared"
|
||||||
@ -11,16 +13,18 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrNotEnoughBytes = errors.New("not enough bytes")
|
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
Read(b []byte) (n int, err error)
|
Read(b []byte) (n int, err error)
|
||||||
Write(b []byte) (n int, err error)
|
Write(b []byte) (n int, err error)
|
||||||
SetWriteDeadline(time.Time) error
|
SetWriteDeadline(time.Time) error
|
||||||
|
|
||||||
|
// For printing
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
RemoteAddr() net.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
type InitiatedFlow struct {
|
type InitiatedFlow struct {
|
||||||
Local string
|
Local func() string
|
||||||
Remote string
|
Remote string
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@ -28,21 +32,64 @@ type InitiatedFlow struct {
|
|||||||
Flow
|
Flow
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) String() string {
|
||||||
|
return fmt.Sprintf("TcpOutbound{%v -> %v}", f.Local(), f.Remote)
|
||||||
|
}
|
||||||
|
|
||||||
type Flow struct {
|
type Flow struct {
|
||||||
conn Conn
|
conn Conn
|
||||||
isAlive bool
|
isAlive bool
|
||||||
|
|
||||||
|
toConsume, produced chan []byte
|
||||||
|
consumeErrors, produceErrors chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitiateFlow(local, remote string) (*InitiatedFlow, error) {
|
func NewFlow() Flow {
|
||||||
|
return Flow{
|
||||||
|
toConsume: make(chan []byte),
|
||||||
|
produced: make(chan []byte),
|
||||||
|
consumeErrors: make(chan error),
|
||||||
|
produceErrors: make(chan error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFlowConn(ctx context.Context, conn Conn) Flow {
|
||||||
|
f := Flow{
|
||||||
|
conn: conn,
|
||||||
|
isAlive: true,
|
||||||
|
|
||||||
|
toConsume: make(chan []byte),
|
||||||
|
produced: make(chan []byte),
|
||||||
|
consumeErrors: make(chan error),
|
||||||
|
produceErrors: make(chan error),
|
||||||
|
}
|
||||||
|
|
||||||
|
go f.produceMarshalled(ctx)
|
||||||
|
go f.consumeMarshalled(ctx)
|
||||||
|
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Flow) String() string {
|
||||||
|
return fmt.Sprintf("TcpInbound{%v -> %v}", f.conn.RemoteAddr(), f.conn.LocalAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) IsAlive() bool {
|
||||||
|
return f.isAlive
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitiateFlow(local func() string, remote string) (*InitiatedFlow, error) {
|
||||||
f := InitiatedFlow{
|
f := InitiatedFlow{
|
||||||
Local: local,
|
Local: local,
|
||||||
Remote: remote,
|
Remote: remote,
|
||||||
|
|
||||||
|
Flow: NewFlow(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return &f, nil
|
return &f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *InitiatedFlow) Reconnect() error {
|
func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
|
||||||
f.mu.Lock()
|
f.mu.Lock()
|
||||||
defer f.mu.Unlock()
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
@ -50,7 +97,7 @@ func (f *InitiatedFlow) Reconnect() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
localAddr, err := net.ResolveTCPAddr("tcp", f.Local)
|
localAddr, err := net.ResolveTCPAddr("tcp", f.Local())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -65,93 +112,128 @@ func (f *InitiatedFlow) Reconnect() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = conn.SetWriteBuffer(0)
|
if err := conn.SetWriteBuffer(0); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
f.conn = conn
|
f.conn = conn
|
||||||
f.isAlive = true
|
f.isAlive = true
|
||||||
|
|
||||||
|
go f.produceMarshalled(ctx)
|
||||||
|
go f.consumeMarshalled(ctx)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) IsAlive() bool {
|
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
|
||||||
return f.isAlive
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
|
||||||
f.mu.RLock()
|
f.mu.RLock()
|
||||||
defer f.mu.RUnlock()
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
return f.Flow.Consume(p, g)
|
return f.Flow.Consume(ctx, p, g)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) (err error) {
|
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
return f.Flow.Produce(ctx, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
|
||||||
if !f.isAlive {
|
if !f.isAlive {
|
||||||
return shared.ErrDeadConnection
|
return shared.ErrDeadConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
data := p.Marshal(g)
|
select {
|
||||||
err = f.consumeMarshalled(data)
|
case err := <-f.consumeErrors:
|
||||||
if err != nil {
|
|
||||||
f.isAlive = false
|
f.isAlive = false
|
||||||
|
return err
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Flow) consumeMarshalled(data []byte) error {
|
marshalled := p.Marshal()
|
||||||
|
data := proxy.AppendMac(marshalled, g)
|
||||||
|
|
||||||
prefixedData := make([]byte, len(data)+4)
|
prefixedData := make([]byte, len(data)+4)
|
||||||
binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
|
binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
|
||||||
copy(prefixedData[4:], data)
|
copy(prefixedData[4:], data)
|
||||||
|
|
||||||
err := f.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
select {
|
||||||
if err != nil {
|
case f.toConsume <- prefixedData:
|
||||||
return err
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
_, err = f.conn.Write(prefixedData)
|
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
f.mu.RLock()
|
|
||||||
defer f.mu.RUnlock()
|
|
||||||
|
|
||||||
return f.Flow.Produce(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
|
||||||
if !f.isAlive {
|
if !f.isAlive {
|
||||||
return proxy.Packet{}, shared.ErrDeadConnection
|
return nil, shared.ErrDeadConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := f.produceMarshalled()
|
var data []byte
|
||||||
if err != nil {
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case data = <-f.produced:
|
||||||
|
case err := <-f.produceErrors:
|
||||||
f.isAlive = false
|
f.isAlive = false
|
||||||
return proxy.Packet{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return proxy.UnmarshalPacket(data, v)
|
b, err := proxy.StripMac(data, v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return proxy.SimplePacket(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) produceMarshalled() ([]byte, error) {
|
func (f *Flow) consumeMarshalled(ctx context.Context) {
|
||||||
lengthBytes := make([]byte, 4)
|
for {
|
||||||
if n, err := io.LimitReader(f.conn, 4).Read(lengthBytes); err != nil {
|
data := <-f.toConsume
|
||||||
return nil, err
|
|
||||||
} else if n != 4 {
|
|
||||||
return nil, ErrNotEnoughBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
length := binary.LittleEndian.Uint32(lengthBytes)
|
err := f.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||||
dataBytes := make([]byte, length)
|
if err != nil {
|
||||||
|
f.consumeErrors <- err
|
||||||
var read uint32
|
return
|
||||||
for read < length {
|
}
|
||||||
if n, err := io.LimitReader(f.conn, int64(length-read)).Read(dataBytes[read:]); err != nil {
|
_, err = f.conn.Write(data)
|
||||||
return nil, err
|
if err != nil {
|
||||||
} else {
|
f.consumeErrors <- err
|
||||||
read += uint32(n)
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return dataBytes, nil
|
|
||||||
|
func (f *Flow) produceMarshalled(ctx context.Context) {
|
||||||
|
buf := bufio.NewReader(f.conn)
|
||||||
|
|
||||||
|
for {
|
||||||
|
lengthBytes := make([]byte, 4)
|
||||||
|
if n, err := io.LimitReader(buf, 4).Read(lengthBytes); err != nil {
|
||||||
|
f.produceErrors <- err
|
||||||
|
return
|
||||||
|
} else if n != 4 {
|
||||||
|
f.produceErrors <- shared.ErrNotEnoughBytes
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
length := binary.LittleEndian.Uint32(lengthBytes)
|
||||||
|
dataBytes := make([]byte, length)
|
||||||
|
|
||||||
|
var read uint32
|
||||||
|
for read < length {
|
||||||
|
if n, err := io.LimitReader(buf, int64(length-read)).Read(dataBytes[read:]); err != nil {
|
||||||
|
f.produceErrors <- err
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
read += uint32(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.produced <- dataBytes
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package tcp
|
package tcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"github.com/go-playground/assert/v2"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"mpbl3p/mocks"
|
"mpbl3p/mocks"
|
||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
@ -11,15 +12,15 @@ import (
|
|||||||
|
|
||||||
func TestFlow_Consume(t *testing.T) {
|
func TestFlow_Consume(t *testing.T) {
|
||||||
testContent := []byte("A test string is the content of this packet.")
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
testPacket := proxy.NewPacket(testContent)
|
testPacket := proxy.SimplePacket(testContent)
|
||||||
testMac := mocks.AlmostUselessMac{}
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
|
||||||
t.Run("Length", func(t *testing.T) {
|
t.Run("Length", func(t *testing.T) {
|
||||||
testConn := mocks.NewMockPerfectBiConn(100)
|
testConn := mocks.NewMockPerfectBiStreamConn(100)
|
||||||
|
|
||||||
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
flowA := NewFlowConn(context.Background(), testConn.SideA())
|
||||||
|
|
||||||
err := flowA.Consume(testPacket, testMac)
|
err := flowA.Consume(context.Background(), testPacket, testMac)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
buf := make([]byte, 100)
|
buf := make([]byte, 100)
|
||||||
@ -39,28 +40,28 @@ func TestFlow_Produce(t *testing.T) {
|
|||||||
testMac := mocks.AlmostUselessMac{}
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
|
||||||
t.Run("Length", func(t *testing.T) {
|
t.Run("Length", func(t *testing.T) {
|
||||||
testConn := mocks.NewMockPerfectBiConn(100)
|
testConn := mocks.NewMockPerfectBiStreamConn(100)
|
||||||
|
|
||||||
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
flowA := NewFlowConn(context.Background(), testConn.SideA())
|
||||||
|
|
||||||
_, err := testConn.SideB().Write(testMarshalled)
|
_, err := testConn.SideB().Write(testMarshalled)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
p, err := flowA.Produce(testMac)
|
p, err := flowA.Produce(context.Background(), testMac)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, len(testContent), len(p.Raw()))
|
assert.Equal(t, len(testContent), len(p.Contents()))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Value", func(t *testing.T) {
|
t.Run("Value", func(t *testing.T) {
|
||||||
testConn := mocks.NewMockPerfectBiConn(100)
|
testConn := mocks.NewMockPerfectBiStreamConn(100)
|
||||||
|
|
||||||
flowA := Flow{conn: testConn.SideA(), isAlive: true}
|
flowA := NewFlowConn(context.Background(), testConn.SideA())
|
||||||
|
|
||||||
_, err := testConn.SideB().Write(testMarshalled)
|
_, err := testConn.SideB().Write(testMarshalled)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
p, err := flowA.Produce(testMac)
|
p, err := flowA.Produce(context.Background(), testMac)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, testContent, string(p.Raw()))
|
assert.Equal(t, testContent, string(p.Contents()))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
package tcp
|
package tcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log"
|
"log"
|
||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error {
|
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, enableConsumers bool, enableProducers bool) error {
|
||||||
laddr, err := net.ResolveTCPAddr("tcp", local)
|
laddr, err := net.ResolveTCPAddr("tcp", local)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -24,17 +25,20 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = conn.SetWriteBuffer(0)
|
if err := conn.SetWriteBuffer(0); err != nil {
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f := Flow{conn: conn, isAlive: true}
|
f := NewFlowConn(ctx, conn)
|
||||||
|
|
||||||
log.Printf("received new connection: %v\n", f)
|
log.Printf("received new tcp connection: %v\n", f)
|
||||||
|
|
||||||
p.AddConsumer(&f)
|
if enableConsumers {
|
||||||
p.AddProducer(&f, v)
|
p.AddConsumer(ctx, &f, g())
|
||||||
|
}
|
||||||
|
if enableProducers {
|
||||||
|
p.AddProducer(ctx, &f, v())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
72
tun/tun.go
72
tun/tun.go
@ -1,85 +1,65 @@
|
|||||||
package tun
|
package tun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/pkg/taptun"
|
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type SourceSink struct {
|
type SourceSink struct {
|
||||||
tun *taptun.Tun
|
tun wgtun.Device
|
||||||
bufferSize int
|
|
||||||
|
|
||||||
up bool
|
|
||||||
upMu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTun(namingScheme string, bufferSize int) (ss *SourceSink, err error) {
|
func NewTun(name string, mtu int) (t wgtun.Device, err error) {
|
||||||
ss = &SourceSink{}
|
return wgtun.CreateTUN(name, mtu)
|
||||||
|
}
|
||||||
|
|
||||||
ss.tun, err = taptun.NewTun(namingScheme)
|
func NewFromFile(fd uintptr, mtu int) (ss *SourceSink, err error) {
|
||||||
|
ss = new(SourceSink)
|
||||||
|
|
||||||
|
file := os.NewFile(fd, "")
|
||||||
|
ss.tun, err = wgtun.CreateTUNFromFile(file, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ss.bufferSize = bufferSize
|
|
||||||
|
|
||||||
ss.upMu.Lock()
|
|
||||||
go func() {
|
|
||||||
defer ss.upMu.Unlock()
|
|
||||||
|
|
||||||
for {
|
|
||||||
iface, err := net.InterfaceByName(ss.tun.String())
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
if strings.Contains(iface.Flags.String(), "up") {
|
|
||||||
log.Println("tun is up")
|
|
||||||
ss.up = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *SourceSink) Close() error {
|
||||||
|
return t.tun.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func (t *SourceSink) Source() (proxy.Packet, error) {
|
func (t *SourceSink) Source() (proxy.Packet, error) {
|
||||||
if !t.up {
|
mtu, err := t.tun.MTU()
|
||||||
t.upMu.Lock()
|
if err != nil {
|
||||||
t.upMu.Unlock()
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := make([]byte, t.bufferSize)
|
buf := make([]byte, mtu+4)
|
||||||
|
|
||||||
read, err := t.tun.Read(buf)
|
read, err := t.tun.Read(buf, 4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return proxy.Packet{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if read == 0 {
|
if read == 0 {
|
||||||
return proxy.Packet{}, io.EOF
|
return nil, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
return proxy.NewPacket(buf[:read]), nil
|
return proxy.SimplePacket(buf[4 : read+4]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var good, bad float64
|
var good, bad float64
|
||||||
|
|
||||||
func (t *SourceSink) Sink(packet proxy.Packet) error {
|
func (t *SourceSink) Sink(packet proxy.Packet) error {
|
||||||
if !t.up {
|
// make space for tun header
|
||||||
t.upMu.Lock()
|
content := make([]byte, len(packet.Contents())+4)
|
||||||
t.upMu.Unlock()
|
copy(content[4:], packet.Contents())
|
||||||
}
|
|
||||||
|
|
||||||
_, err := t.tun.Write(packet.Raw())
|
_, err := t.tun.Write(content, 4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case *os.PathError:
|
case *os.PathError:
|
||||||
|
16
udp/congestion.go
Normal file
16
udp/congestion.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Congestion interface {
|
||||||
|
Sequence(ctx context.Context) (uint32, error)
|
||||||
|
NextAck() uint32
|
||||||
|
NextNack() uint32
|
||||||
|
|
||||||
|
ReceivedPacket(seq, nack, ack uint32)
|
||||||
|
|
||||||
|
AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error)
|
||||||
|
}
|
284
udp/congestion/newreno.go
Normal file
284
udp/congestion/newreno.go
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const RttExponentialFactor = 0.1
|
||||||
|
const RttLossDelay = 1.5
|
||||||
|
|
||||||
|
type NewReno struct {
|
||||||
|
sequence chan uint32
|
||||||
|
|
||||||
|
inFlight []flightInfo
|
||||||
|
lastSent time.Time
|
||||||
|
inFlightMu sync.Mutex
|
||||||
|
|
||||||
|
awaitingAck sortableFlights
|
||||||
|
ack, nack uint32
|
||||||
|
lastAck, lastNack uint32
|
||||||
|
ackNackMu sync.Mutex
|
||||||
|
|
||||||
|
rttNanos float64
|
||||||
|
windowSize, windowCount uint32
|
||||||
|
slowStart bool
|
||||||
|
windowNotifier chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type flightInfo struct {
|
||||||
|
time time.Time
|
||||||
|
sequence uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type sortableFlights []flightInfo
|
||||||
|
|
||||||
|
func (f sortableFlights) Len() int {
|
||||||
|
return len(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f sortableFlights) Swap(i, j int) {
|
||||||
|
f[i], f[j] = f[j], f[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f sortableFlights) Less(i, j int) bool {
|
||||||
|
return f[i].sequence < f[j].sequence
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) String() string {
|
||||||
|
return fmt.Sprintf("{NewReno %t %d %d %d %d}", c.slowStart, c.windowSize, len(c.inFlight), c.lastAck, c.lastNack)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNewReno() *NewReno {
|
||||||
|
c := NewReno{
|
||||||
|
sequence: make(chan uint32),
|
||||||
|
windowNotifier: make(chan struct{}),
|
||||||
|
|
||||||
|
windowSize: 1,
|
||||||
|
rttNanos: float64((10 * time.Millisecond).Nanoseconds()),
|
||||||
|
slowStart: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var s uint32
|
||||||
|
for {
|
||||||
|
if s == 0 {
|
||||||
|
s++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.sequence <- s
|
||||||
|
s++
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) ReceivedPacket(seq, nack, ack uint32) {
|
||||||
|
// decide what acks and nacks to send
|
||||||
|
if seq != 0 {
|
||||||
|
c.receivedSequence(seq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decide how window size was affected
|
||||||
|
if nack != 0 {
|
||||||
|
c.receivedNack(nack)
|
||||||
|
}
|
||||||
|
if ack != 0 {
|
||||||
|
c.receivedAck(ack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) Sequence(ctx context.Context) (uint32, error) {
|
||||||
|
for len(c.inFlight) >= int(c.windowSize) {
|
||||||
|
<-c.windowNotifier
|
||||||
|
}
|
||||||
|
|
||||||
|
c.inFlightMu.Lock()
|
||||||
|
defer c.inFlightMu.Unlock()
|
||||||
|
|
||||||
|
var s uint32
|
||||||
|
select {
|
||||||
|
case s = <-c.sequence:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return 0, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
t := time.Now()
|
||||||
|
|
||||||
|
c.inFlight = append(c.inFlight, flightInfo{
|
||||||
|
time: t,
|
||||||
|
sequence: s,
|
||||||
|
})
|
||||||
|
c.lastSent = t
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) NextAck() uint32 {
|
||||||
|
a := c.ack
|
||||||
|
atomic.StoreUint32(&c.lastAck, a)
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) NextNack() uint32 {
|
||||||
|
n := c.nack
|
||||||
|
atomic.StoreUint32(&c.lastNack, n)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error) {
|
||||||
|
for {
|
||||||
|
rtt := time.Duration(math.Round(c.rttNanos))
|
||||||
|
time.Sleep(rtt / 2)
|
||||||
|
|
||||||
|
c.checkNack()
|
||||||
|
|
||||||
|
// CASE 1: waiting ACKs or NACKs and no message sent in the last half-RTT
|
||||||
|
// this targets arrival in 0.5+0.5 ± 0.5 RTTs (1±0.5 RTTs)
|
||||||
|
if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt/2)) {
|
||||||
|
return 0, nil // no ack needed
|
||||||
|
}
|
||||||
|
|
||||||
|
// CASE 2: No message sent within the keepalive time
|
||||||
|
if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) {
|
||||||
|
return c.Sequence(ctx) // require an ack
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) receivedSequence(seq uint32) {
|
||||||
|
c.ackNackMu.Lock()
|
||||||
|
defer c.ackNackMu.Unlock()
|
||||||
|
|
||||||
|
if seq < c.nack || seq < c.ack {
|
||||||
|
// packet received out of order has already been cumulatively NACKed
|
||||||
|
// or duplicate packet received and already ACKed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if seq != c.ack+1 && seq != c.nack+1 {
|
||||||
|
c.awaitingAck = append(c.awaitingAck, flightInfo{
|
||||||
|
time: time.Now(),
|
||||||
|
sequence: seq,
|
||||||
|
})
|
||||||
|
return // if this seq doesn't change the ack field, updateAck will still not do anything useful
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(c.awaitingAck)
|
||||||
|
c.updateAck(seq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) checkNack() {
|
||||||
|
c.ackNackMu.Lock()
|
||||||
|
defer c.ackNackMu.Unlock()
|
||||||
|
|
||||||
|
if len(c.awaitingAck) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(c.awaitingAck)
|
||||||
|
|
||||||
|
lossThreshold := time.Duration(c.rttNanos * RttLossDelay)
|
||||||
|
if c.awaitingAck[0].time.Before(time.Now().Add(-lossThreshold)) {
|
||||||
|
// if the next packet sequence to ack was received more than an rttlossdelay ago
|
||||||
|
// mark the packet(s) blocking it as missing with a nack
|
||||||
|
// then update ack from the delayed packet
|
||||||
|
c.nack = c.awaitingAck[0].sequence - 1
|
||||||
|
c.updateAck(c.nack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) updateAck(start uint32) {
|
||||||
|
a := start
|
||||||
|
|
||||||
|
var removed int
|
||||||
|
for _, e := range c.awaitingAck {
|
||||||
|
if e.sequence == a+1 {
|
||||||
|
a = e.sequence
|
||||||
|
removed++
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.ack = a
|
||||||
|
c.awaitingAck = c.awaitingAck[removed:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) receivedNack(nack uint32) {
|
||||||
|
c.ackNackMu.Lock()
|
||||||
|
defer c.ackNackMu.Unlock()
|
||||||
|
|
||||||
|
c.inFlightMu.Lock()
|
||||||
|
defer c.inFlightMu.Unlock()
|
||||||
|
|
||||||
|
// as both ack and nack are cumulative, inflight will always be ordered by seq
|
||||||
|
var i int
|
||||||
|
for i < len(c.inFlight) && c.inFlight[i].sequence <= nack {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.slowStart = false
|
||||||
|
c.inFlight = c.inFlight[i:]
|
||||||
|
|
||||||
|
for {
|
||||||
|
s := c.windowSize
|
||||||
|
if s == 1 || atomic.CompareAndSwapUint32(&c.windowSize, s, s/2) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.windowNotifier <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) receivedAck(ack uint32) {
|
||||||
|
c.ackNackMu.Lock()
|
||||||
|
defer c.ackNackMu.Unlock()
|
||||||
|
|
||||||
|
c.inFlightMu.Lock()
|
||||||
|
defer c.inFlightMu.Unlock()
|
||||||
|
|
||||||
|
// as both ack and nack are cumulative, inflight will always be ordered by seq
|
||||||
|
var i int
|
||||||
|
for i < len(c.inFlight) && c.inFlight[i].sequence <= ack {
|
||||||
|
rtt := float64(time.Now().Sub(c.inFlight[i].time).Nanoseconds())
|
||||||
|
c.rttNanos = c.rttNanos*(1-RttExponentialFactor) + rtt*RttExponentialFactor
|
||||||
|
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.inFlight = c.inFlight[i:]
|
||||||
|
if c.slowStart {
|
||||||
|
atomic.AddUint32(&c.windowSize, uint32(i))
|
||||||
|
} else {
|
||||||
|
c.windowCount += uint32(i)
|
||||||
|
s := c.windowSize
|
||||||
|
if c.windowCount > s {
|
||||||
|
c.windowCount -= s
|
||||||
|
atomic.AddUint32(&c.windowSize, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.windowNotifier <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
369
udp/congestion/newreno_test.go
Normal file
369
udp/congestion/newreno_test.go
Normal file
@ -0,0 +1,369 @@
|
|||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type congestionPacket struct {
|
||||||
|
seq, nack, ack uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type newRenoTest struct {
|
||||||
|
sideA, sideB *NewReno
|
||||||
|
|
||||||
|
aOutbound, bOutbound chan congestionPacket
|
||||||
|
aInbound, bInbound chan congestionPacket
|
||||||
|
|
||||||
|
halfRtt time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNewRenoTest(rtt time.Duration) *newRenoTest {
|
||||||
|
return &newRenoTest{
|
||||||
|
sideA: NewNewReno(),
|
||||||
|
sideB: NewNewReno(),
|
||||||
|
|
||||||
|
aOutbound: make(chan congestionPacket),
|
||||||
|
bOutbound: make(chan congestionPacket),
|
||||||
|
|
||||||
|
aInbound: make(chan congestionPacket),
|
||||||
|
bInbound: make(chan congestionPacket),
|
||||||
|
|
||||||
|
halfRtt: rtt / 2,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *newRenoTest) Start(ctx context.Context) {
|
||||||
|
type packetWithTime struct {
|
||||||
|
t time.Time
|
||||||
|
p congestionPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
aOutboundDelayed := make(chan packetWithTime, 128)
|
||||||
|
bOutboundDelayed := make(chan packetWithTime, 128)
|
||||||
|
|
||||||
|
delayer := func(tp chan packetWithTime, cp chan congestionPacket) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case p := <-tp:
|
||||||
|
s := p.t.Add(n.halfRtt).Sub(time.Now())
|
||||||
|
time.Sleep(s)
|
||||||
|
cp <- p.p
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go delayer(aOutboundDelayed, n.bInbound)
|
||||||
|
go delayer(bOutboundDelayed, n.aInbound)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case p := <-n.aOutbound:
|
||||||
|
aOutboundDelayed <- packetWithTime{t: time.Now(), p: p}
|
||||||
|
case p := <-n.bOutbound:
|
||||||
|
bOutboundDelayed <- packetWithTime{t: time.Now(), p: p}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *newRenoTest) RunSideA(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case p := <-n.aInbound:
|
||||||
|
n.sideA.ReceivedPacket(p.seq, p.nack, p.ack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
seq, err := n.sideA.AwaitEarlyUpdate(ctx, 500*time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if seq != 0 {
|
||||||
|
// skip keepalive
|
||||||
|
// required to ensure AwaitEarlyUpdate terminates
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p := congestionPacket{
|
||||||
|
seq: seq,
|
||||||
|
nack: n.sideA.NextNack(),
|
||||||
|
ack: n.sideA.NextAck(),
|
||||||
|
}
|
||||||
|
n.aOutbound <- p
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *newRenoTest) RunSideB(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case p := <-n.bInbound:
|
||||||
|
n.sideB.ReceivedPacket(p.seq, p.nack, p.ack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
seq, err := n.sideB.AwaitEarlyUpdate(ctx, 500*time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if seq != 0 {
|
||||||
|
// skip keepalive
|
||||||
|
// required to ensure AwaitEarlyUpdate terminates
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p := congestionPacket{
|
||||||
|
seq: seq,
|
||||||
|
nack: n.sideB.NextNack(),
|
||||||
|
ack: n.sideB.NextAck(),
|
||||||
|
}
|
||||||
|
n.bOutbound <- p
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewReno_Congestion(t *testing.T) {
|
||||||
|
t.Run("OneWay", func(t *testing.T) {
|
||||||
|
t.Run("Lossless", func(t *testing.T) {
|
||||||
|
// ASSIGN
|
||||||
|
rtt := 80 * time.Millisecond
|
||||||
|
numPackets := 50
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
c := newNewRenoTest(rtt)
|
||||||
|
c.Start(ctx)
|
||||||
|
c.RunSideA(ctx)
|
||||||
|
c.RunSideB(ctx)
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
for i := 0; i < numPackets; i++ {
|
||||||
|
// sleep to simulate preparing packet
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
|
c.aOutbound <- congestionPacket{
|
||||||
|
seq: seq,
|
||||||
|
nack: c.sideA.NextNack(),
|
||||||
|
ack: c.sideA.NextAck(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow the systems to catch up before asserting
|
||||||
|
time.Sleep(rtt + 30*time.Millisecond)
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(0), c.sideA.nack)
|
||||||
|
assert.Equal(t, uint32(0), c.sideA.ack)
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(0), c.sideB.nack)
|
||||||
|
assert.Equal(t, uint32(numPackets), c.sideB.ack)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SequenceLoss", func(t *testing.T) {
|
||||||
|
// ASSIGN
|
||||||
|
rtt := 80 * time.Millisecond
|
||||||
|
numPackets := 50
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
c := newNewRenoTest(rtt)
|
||||||
|
c.Start(ctx)
|
||||||
|
c.RunSideA(ctx)
|
||||||
|
c.RunSideB(ctx)
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
for i := 0; i < numPackets; i++ {
|
||||||
|
// sleep to simulate preparing packet
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
|
if seq == 20 {
|
||||||
|
// Simulate packet loss of sequence 20
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.aOutbound <- congestionPacket{
|
||||||
|
seq: seq,
|
||||||
|
nack: c.sideA.NextNack(),
|
||||||
|
ack: c.sideA.NextAck(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(rtt + 30*time.Millisecond)
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(0), c.sideA.nack)
|
||||||
|
assert.Equal(t, uint32(0), c.sideA.ack)
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(20), c.sideB.nack)
|
||||||
|
assert.Equal(t, uint32(numPackets), c.sideB.ack)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TwoWay", func(t *testing.T) {
|
||||||
|
t.Run("Lossless", func(t *testing.T) {
|
||||||
|
// ASSIGN
|
||||||
|
rtt := 80 * time.Millisecond
|
||||||
|
numPackets := 50
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
c := newNewRenoTest(rtt)
|
||||||
|
c.Start(ctx)
|
||||||
|
c.RunSideA(ctx)
|
||||||
|
c.RunSideB(ctx)
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < numPackets; i++ {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
|
c.aOutbound <- congestionPacket{
|
||||||
|
seq: seq,
|
||||||
|
nack: c.sideA.NextNack(),
|
||||||
|
ack: c.sideA.NextAck(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < numPackets; i++ {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
seq, _ := c.sideB.Sequence(ctx)
|
||||||
|
|
||||||
|
c.bOutbound <- congestionPacket{
|
||||||
|
seq: seq,
|
||||||
|
nack: c.sideB.NextNack(),
|
||||||
|
ack: c.sideB.NextAck(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-done
|
||||||
|
<-done
|
||||||
|
|
||||||
|
time.Sleep(rtt + 30*time.Millisecond)
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(0), c.sideA.nack)
|
||||||
|
assert.Equal(t, uint32(numPackets), c.sideA.ack)
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(0), c.sideB.nack)
|
||||||
|
assert.Equal(t, uint32(numPackets), c.sideB.ack)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SequenceLoss", func(t *testing.T) {
|
||||||
|
// ASSIGN
|
||||||
|
rtt := 80 * time.Millisecond
|
||||||
|
numPackets := 100
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
c := newNewRenoTest(rtt)
|
||||||
|
c.Start(ctx)
|
||||||
|
c.RunSideA(ctx)
|
||||||
|
c.RunSideB(ctx)
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < numPackets; i++ {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
|
if seq == 9 {
|
||||||
|
// Simulate packet loss of sequence 9
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.aOutbound <- congestionPacket{
|
||||||
|
seq: seq,
|
||||||
|
nack: c.sideA.NextNack(),
|
||||||
|
ack: c.sideA.NextAck(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < numPackets; i++ {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
seq, _ := c.sideB.Sequence(ctx)
|
||||||
|
|
||||||
|
if seq == 13 {
|
||||||
|
// Simulate packet loss of sequence 13
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.bOutbound <- congestionPacket{
|
||||||
|
seq: seq,
|
||||||
|
nack: c.sideB.NextNack(),
|
||||||
|
ack: c.sideB.NextAck(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-done
|
||||||
|
<-done
|
||||||
|
|
||||||
|
time.Sleep(rtt + 30*time.Millisecond)
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(13), c.sideA.nack)
|
||||||
|
assert.Equal(t, uint32(numPackets), c.sideA.ack)
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(9), c.sideB.nack)
|
||||||
|
assert.Equal(t, uint32(numPackets), c.sideB.ack)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortableFlights_Less(t *testing.T) {
|
||||||
|
// ASSIGN
|
||||||
|
a := []flightInfo{{sequence: 0}, {sequence: 6}, {sequence: 3}, {sequence: 2}}
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
sort.Sort(sortableFlights(a))
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
assert.Equal(t, []flightInfo{{sequence: 0}, {sequence: 2}, {sequence: 3}, {sequence: 6}}, a)
|
||||||
|
}
|
26
udp/congestion/none.go
Normal file
26
udp/congestion/none.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type None struct{}
|
||||||
|
|
||||||
|
func NewNone() None {
|
||||||
|
return None{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c None) String() string {
|
||||||
|
return fmt.Sprintf("{None}")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c None) ReceivedPacket(uint32, uint32, uint32) {}
|
||||||
|
func (c None) NextNack() uint32 { return 0 }
|
||||||
|
func (c None) NextAck() uint32 { return 0 }
|
||||||
|
func (c None) AwaitEarlyUpdate(ctx context.Context, _ time.Duration) (uint32, error) {
|
||||||
|
<-ctx.Done()
|
||||||
|
return 0, ctx.Err()
|
||||||
|
}
|
||||||
|
func (c None) Sequence(context.Context) (uint32, error) { return 0, nil }
|
306
udp/flow.go
Normal file
306
udp/flow.go
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"mpbl3p/shared"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PacketWriter interface {
|
||||||
|
Write(b []byte) (int, error)
|
||||||
|
WriteToUDP(b []byte, addr *net.UDPAddr) (int, error)
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketConn interface {
|
||||||
|
PacketWriter
|
||||||
|
ReadFromUDP(b []byte) (int, *net.UDPAddr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type InitiatedFlow struct {
|
||||||
|
Local func() string
|
||||||
|
Remote string
|
||||||
|
|
||||||
|
g proxy.MacGenerator
|
||||||
|
keepalive time.Duration
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
Flow
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) String() string {
|
||||||
|
return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Flow struct {
|
||||||
|
writer PacketWriter
|
||||||
|
raddr *net.UDPAddr
|
||||||
|
|
||||||
|
isAlive bool
|
||||||
|
startup bool
|
||||||
|
congestion Congestion
|
||||||
|
|
||||||
|
v proxy.MacVerifier
|
||||||
|
|
||||||
|
inboundDatagrams chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Flow) String() string {
|
||||||
|
return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitiateFlow(
|
||||||
|
local func() string,
|
||||||
|
remote string,
|
||||||
|
v proxy.MacVerifier,
|
||||||
|
g proxy.MacGenerator,
|
||||||
|
c Congestion,
|
||||||
|
keepalive time.Duration,
|
||||||
|
) (*InitiatedFlow, error) {
|
||||||
|
f := InitiatedFlow{
|
||||||
|
Local: local,
|
||||||
|
Remote: remote,
|
||||||
|
Flow: newFlow(c, v),
|
||||||
|
g: g,
|
||||||
|
keepalive: keepalive,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFlow(c Congestion, v proxy.MacVerifier) Flow {
|
||||||
|
return Flow{
|
||||||
|
inboundDatagrams: make(chan []byte),
|
||||||
|
congestion: c,
|
||||||
|
v: v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
|
if f.isAlive {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr, err := net.ResolveUDPAddr("udp", f.Local())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f.writer = conn
|
||||||
|
f.startup = true
|
||||||
|
|
||||||
|
// prod the connection once a second until we get an ack, then consider it alive
|
||||||
|
go func() {
|
||||||
|
seq, err := f.congestion.Sequence(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for !f.isAlive {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p := Packet{
|
||||||
|
ack: 0,
|
||||||
|
nack: 0,
|
||||||
|
seq: seq,
|
||||||
|
data: proxy.SimplePacket(nil),
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = f.sendPacket(p, f.g)
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, _ = f.produceInternal(ctx, f.v, false)
|
||||||
|
}()
|
||||||
|
go f.earlyUpdateLoop(ctx, f.g, f.keepalive)
|
||||||
|
|
||||||
|
if err := f.readQueuePacket(ctx, conn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f.isAlive = true
|
||||||
|
f.startup = false
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
lockedAccept := func() {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
if err := f.readQueuePacket(ctx, conn); err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for f.isAlive {
|
||||||
|
log.Println("alive and listening for packets")
|
||||||
|
lockedAccept()
|
||||||
|
}
|
||||||
|
log.Println("no longer alive")
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
return f.Flow.Consume(ctx, p, g)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
return f.Flow.Produce(ctx, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) IsAlive() bool {
|
||||||
|
return f.isAlive
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error {
|
||||||
|
if !f.isAlive {
|
||||||
|
return shared.ErrDeadConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(f.congestion)
|
||||||
|
|
||||||
|
// Sequence is the congestion controllers opportunity to block
|
||||||
|
log.Println("awaiting sequence")
|
||||||
|
seq, err := f.congestion.Sequence(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Println("received sequence")
|
||||||
|
|
||||||
|
// Choose up to date ACK/NACK even after blocking
|
||||||
|
p := Packet{
|
||||||
|
seq: seq,
|
||||||
|
data: pp,
|
||||||
|
ack: f.congestion.NextAck(),
|
||||||
|
nack: f.congestion.NextNack(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.sendPacket(p, g)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
|
if !f.isAlive {
|
||||||
|
return nil, shared.ErrDeadConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.produceInternal(ctx, v, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
|
||||||
|
for once := true; mustReturn || once; once = false {
|
||||||
|
log.Println(f.congestion)
|
||||||
|
|
||||||
|
var received []byte
|
||||||
|
select {
|
||||||
|
case received = <-f.inboundDatagrams:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := proxy.StripMac(received, v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := UnmarshalPacket(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// adjust congestion control based on this packet's congestion header
|
||||||
|
f.congestion.ReceivedPacket(p.seq, p.nack, p.ack)
|
||||||
|
|
||||||
|
// 12 bytes for header + the MAC + a timestamp
|
||||||
|
if len(b) == 12+f.v.CodeLength()+8 {
|
||||||
|
log.Println("handled keepalive/ack only packet")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
|
||||||
|
select {
|
||||||
|
case f.inboundDatagrams <- p:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
|
||||||
|
b := p.Marshal()
|
||||||
|
b = proxy.AppendMac(b, g)
|
||||||
|
|
||||||
|
if f.raddr == nil {
|
||||||
|
_, err := f.writer.Write(b)
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
_, err := f.writer.WriteToUDP(b, f.raddr)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) {
|
||||||
|
for f.isAlive {
|
||||||
|
seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("terminating earlyupdateloop for `%v`\n", f)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p := Packet{
|
||||||
|
seq: seq,
|
||||||
|
data: proxy.SimplePacket(nil),
|
||||||
|
ack: f.congestion.NextAck(),
|
||||||
|
nack: f.congestion.NextNack(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = f.sendPacket(p, g)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("error sending early update packet: `%v`\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error {
|
||||||
|
// TODO: Replace 6000 with MTU+header size
|
||||||
|
buf := make([]byte, 6000)
|
||||||
|
n, _, err := c.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.queueDatagram(ctx, buf[:n])
|
||||||
|
}
|
86
udp/flow_test.go
Normal file
86
udp/flow_test.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"mpbl3p/mocks"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"mpbl3p/udp/congestion"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFlow_Consume(t *testing.T) {
|
||||||
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
|
testPacket := proxy.SimplePacket(testContent)
|
||||||
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
|
||||||
|
t.Run("Length", func(t *testing.T) {
|
||||||
|
testConn := mocks.NewMockPerfectBiPacketConn(10)
|
||||||
|
|
||||||
|
flowA := newFlow(congestion.NewNone(), testMac)
|
||||||
|
|
||||||
|
flowA.writer = testConn.SideB()
|
||||||
|
flowA.isAlive = true
|
||||||
|
|
||||||
|
err := flowA.Consume(context.Background(), testPacket, testMac)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
n, _, err := testConn.SideA().ReadFromUDP(buf)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// 12 header, 8 timestamp, 4 MAC
|
||||||
|
assert.Equal(t, len(testContent)+12+8+4, n)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlow_Produce(t *testing.T) {
|
||||||
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
|
testPacket := Packet{
|
||||||
|
ack: 42,
|
||||||
|
nack: 26,
|
||||||
|
seq: 128,
|
||||||
|
data: proxy.SimplePacket(testContent),
|
||||||
|
}
|
||||||
|
testMac := mocks.AlmostUselessMac{}
|
||||||
|
|
||||||
|
testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac)
|
||||||
|
|
||||||
|
t.Run("Length", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
testConn := mocks.NewMockPerfectBiPacketConn(10)
|
||||||
|
|
||||||
|
_, err := testConn.SideA().Write(testMarshalled)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
flowA := newFlow(congestion.NewNone(), testMac)
|
||||||
|
|
||||||
|
flowA.writer = testConn.SideB()
|
||||||
|
flowA.isAlive = true
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := flowA.readQueuePacket(context.Background(), testConn.SideB())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}()
|
||||||
|
p, err := flowA.Produce(context.Background(), testMac)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Len(t, p.Contents(), len(testContent))
|
||||||
|
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
timer := time.NewTimer(500 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-timer.C:
|
||||||
|
fmt.Println("timed out")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
104
udp/listener.go
Normal file
104
udp/listener.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ComparableUdpAddress struct {
|
||||||
|
IP [16]byte
|
||||||
|
Port int
|
||||||
|
Zone string
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
|
||||||
|
var ip [16]byte
|
||||||
|
for i, b := range []byte(address.IP) {
|
||||||
|
ip[i] = b
|
||||||
|
}
|
||||||
|
|
||||||
|
return ComparableUdpAddress{
|
||||||
|
IP: ip,
|
||||||
|
Port: address.Port,
|
||||||
|
Zone: address.Zone,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion, enableConsumers bool, enableProducers bool) error {
|
||||||
|
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pconn, err := net.ListenUDP("udp", laddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pconn.SetWriteBuffer(0)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedConnections := make(map[ComparableUdpAddress]*Flow)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for ctx.Err() == nil {
|
||||||
|
buf := make([]byte, 6000)
|
||||||
|
|
||||||
|
if err := pconn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
n, addr, err := pconn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
raddr := fromUdpAddress(*addr)
|
||||||
|
if f, exists := receivedConnections[raddr]; exists {
|
||||||
|
log.Println("existing flow. queuing...")
|
||||||
|
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
|
||||||
|
|
||||||
|
}
|
||||||
|
log.Println("queued")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
v := v()
|
||||||
|
g := g()
|
||||||
|
|
||||||
|
f := newFlow(c(), v)
|
||||||
|
|
||||||
|
f.writer = pconn
|
||||||
|
f.raddr = addr
|
||||||
|
f.isAlive = true
|
||||||
|
|
||||||
|
log.Printf("received new udp connection: %v\n", f)
|
||||||
|
|
||||||
|
go f.earlyUpdateLoop(ctx, g, 0)
|
||||||
|
|
||||||
|
receivedConnections[raddr] = &f
|
||||||
|
|
||||||
|
if enableConsumers {
|
||||||
|
p.AddConsumer(ctx, &f, g)
|
||||||
|
}
|
||||||
|
if enableProducers {
|
||||||
|
p.AddProducer(ctx, &f, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("handling...")
|
||||||
|
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println("handled")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
43
udp/packet.go
Normal file
43
udp/packet.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"mpbl3p/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Packet struct {
|
||||||
|
ack uint32
|
||||||
|
nack uint32
|
||||||
|
seq uint32
|
||||||
|
|
||||||
|
data proxy.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalPacket(b []byte) (p Packet, err error) {
|
||||||
|
if len(b) < 12 {
|
||||||
|
return Packet{}, shared.ErrNotEnoughBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
p.ack = binary.LittleEndian.Uint32(b[0:4])
|
||||||
|
p.nack = binary.LittleEndian.Uint32(b[4:8])
|
||||||
|
p.seq = binary.LittleEndian.Uint32(b[8:12])
|
||||||
|
|
||||||
|
p.data = proxy.SimplePacket(b[12:])
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Packet) Marshal() []byte {
|
||||||
|
data := p.data.Marshal()
|
||||||
|
header := make([]byte, 12)
|
||||||
|
|
||||||
|
binary.LittleEndian.PutUint32(header[0:4], p.ack)
|
||||||
|
binary.LittleEndian.PutUint32(header[4:8], p.nack)
|
||||||
|
binary.LittleEndian.PutUint32(header[8:12], p.seq)
|
||||||
|
|
||||||
|
return append(header, data...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Packet) Contents() []byte {
|
||||||
|
return p.data.Contents()
|
||||||
|
}
|
59
udp/packet_test.go
Normal file
59
udp/packet_test.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPacket_Marshal(t *testing.T) {
|
||||||
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
|
testPacket := Packet{
|
||||||
|
ack: 18,
|
||||||
|
nack: 29,
|
||||||
|
seq: 431,
|
||||||
|
data: proxy.SimplePacket(testContent),
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Length", func(t *testing.T) {
|
||||||
|
marshalled := testPacket.Marshal()
|
||||||
|
|
||||||
|
// 12 header + 8 timestamp
|
||||||
|
assert.Len(t, marshalled, len(testContent)+12)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalPacket(t *testing.T) {
|
||||||
|
testContent := []byte("A test string is the content of this packet.")
|
||||||
|
testPacket := Packet{
|
||||||
|
ack: 18,
|
||||||
|
nack: 29,
|
||||||
|
seq: 431,
|
||||||
|
data: proxy.SimplePacket(testContent),
|
||||||
|
}
|
||||||
|
testMarshalled := testPacket.Marshal()
|
||||||
|
|
||||||
|
t.Run("Length", func(t *testing.T) {
|
||||||
|
p, err := UnmarshalPacket(testMarshalled)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Len(t, p.Contents(), len(testContent))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Contents", func(t *testing.T) {
|
||||||
|
p, err := UnmarshalPacket(testMarshalled)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Equal(t, p.Contents(), testContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Header", func(t *testing.T) {
|
||||||
|
p, err := UnmarshalPacket(testMarshalled)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, p.ack, uint32(18))
|
||||||
|
assert.Equal(t, p.nack, uint32(29))
|
||||||
|
assert.Equal(t, p.seq, uint32(431))
|
||||||
|
})
|
||||||
|
}
|
36
udp/wireshark_dissector.lua
Normal file
36
udp/wireshark_dissector.lua
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
mpbl3p_udp = Proto("mpbl3p_udp", "Multi Path Proxy Custom UDP")
|
||||||
|
|
||||||
|
ack_F = ProtoField.uint32("mpbl3p_udp.ack", "Acknowledgement")
|
||||||
|
nack_F = ProtoField.uint32("mpbl3p_udp.nack", "Negative Acknowledgement")
|
||||||
|
seq_F = ProtoField.uint32("mpbl3p_udp.seq", "Sequence Number")
|
||||||
|
time_F = ProtoField.absolute_time("mpbl3p_udp.time", "Timestamp")
|
||||||
|
proxied_F = ProtoField.bytes("mpbl3p_udp.data", "Proxied Data")
|
||||||
|
|
||||||
|
mpbl3p_udp.fields = { ack_F, nack_F, seq_F, time_F, proxied_F }
|
||||||
|
|
||||||
|
function mpbl3p_udp.dissector(buffer, pinfo, tree)
|
||||||
|
if buffer:len() < 20 then
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
pinfo.cols.protocol = "MPBL3P_UDP"
|
||||||
|
|
||||||
|
local ack = buffer(0, 4):le_uint()
|
||||||
|
local nack = buffer(4, 4):le_uint()
|
||||||
|
local seq = buffer(8, 4):le_uint()
|
||||||
|
|
||||||
|
local unix_time = buffer(buffer:len() - 8, 8):le_uint64()
|
||||||
|
|
||||||
|
local subtree = tree:add(mpbl3p_udp, buffer(), "Multi Path Proxy Header, SEQ: " .. seq .. " ACK: " .. ack .. " NACK: " .. nack)
|
||||||
|
|
||||||
|
subtree:add(ack_F, ack)
|
||||||
|
subtree:add(nack_F, nack)
|
||||||
|
subtree:add(seq_F, seq)
|
||||||
|
subtree:add(time_F, NSTime.new(unix_time:tonumber()))
|
||||||
|
if buffer:len() > 20 then
|
||||||
|
subtree:add(proxied_F, buffer(12, buffer:len() - 12 - 8))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
DissectorTable.get("udp.port"):add(4724, mpbl3p_udp)
|
||||||
|
DissectorTable.get("udp.port"):add(4725, mpbl3p_udp)
|
@ -1,13 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
var NextId = make(chan int)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
go func() {
|
|
||||||
i := 0
|
|
||||||
for {
|
|
||||||
NextId <- i
|
|
||||||
i += 1
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user