Compare commits

..

4 Commits
develop ... udp

Author SHA1 Message Date
0e6a258106 Merge branch 'develop' into udp
All checks were successful
continuous-integration/drone/push Build is passing
2021-05-13 18:35:23 +01:00
c6e39e5ba1 Merge branch 'develop' into udp
All checks were successful
continuous-integration/drone/push Build is passing
2021-05-12 00:18:48 +01:00
5396fe1416 Merge branch 'develop' into udp
All checks were successful
continuous-integration/drone/push Build is passing
2021-04-06 15:53:52 +01:00
823b787ab9 better handling of udp keepalives
All checks were successful
continuous-integration/drone/push Build is passing
2021-03-31 18:24:56 +01:00
24 changed files with 320 additions and 978 deletions

View File

@ -18,27 +18,6 @@
sysctl -w net.ipv4.conf.all.arp_ignore=1
See http://kb.linuxvirtualserver.org/wiki/Using_arp_announce/arp_ignore_to_disable_ARP
### Systemd unit
[Unit]
Description=NetCombiner for interface %i
After=network-online.target
[Service]
Type=forking
ExecStartPre=/etc/netcombiner/%i.pre
ExecStart=/usr/local/sbin/netcombiner %i
ExecStartPost=/etc/netcombiner/%i.post
User=root
Group=root
Restart=always
[Install]
WantedBy=multi-user.target
### Setup Scripts
These are functional setup scripts that make the application run as intended on Linux.

View File

@ -7,30 +7,22 @@ import (
"mpbl3p/crypto"
"mpbl3p/crypto/sharedkey"
"mpbl3p/proxy"
"mpbl3p/replay"
"mpbl3p/tcp"
"mpbl3p/udp"
"mpbl3p/udp/congestion"
"mpbl3p/udp/congestion/newreno"
"time"
)
func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
p := proxy.NewProxy(0)
var gs []func() proxy.MacGenerator
var vs []func() proxy.MacVerifier
if c.Host.ReplayProtection {
rp := replay.NewAntiReplay()
gs = append(gs, func() proxy.MacGenerator { return rp })
vs = append(vs, func() proxy.MacVerifier { return rp })
}
var g func() proxy.MacGenerator
var v func() proxy.MacVerifier
switch c.Host.Crypto {
case "None":
gs = append(gs, func() proxy.MacGenerator { return crypto.None{} })
vs = append(vs, func() proxy.MacVerifier { return crypto.None{} })
g = func() proxy.MacGenerator { return crypto.None{} }
v = func() proxy.MacVerifier { return crypto.None{} }
case "Blake2s":
key, err := base64.StdEncoding.DecodeString(c.Host.SharedKey)
if err != nil {
@ -39,14 +31,14 @@ func (c Configuration) Build(ctx context.Context, source proxy.Source, sink prox
if _, err := sharedkey.NewBlake2s(key); err != nil {
return nil, err
}
gs = append(gs, func() proxy.MacGenerator {
g = func() proxy.MacGenerator {
g, _ := sharedkey.NewBlake2s(key)
return g
})
vs = append(vs, func() proxy.MacVerifier {
}
v = func() proxy.MacVerifier {
v, _ := sharedkey.NewBlake2s(key)
return v
})
}
}
p.Source = source
@ -55,11 +47,11 @@ func (c Configuration) Build(ctx context.Context, source proxy.Source, sink prox
for _, peer := range c.Peers {
switch peer.Method {
case "TCP":
if err := buildTcp(ctx, p, peer, gs, vs); err != nil {
if err := buildTcp(ctx, p, peer, g, v); err != nil {
return nil, err
}
case "UDP":
if err := buildUdp(ctx, p, peer, gs, vs); err != nil {
if err := buildUdp(ctx, p, peer, g, v); err != nil {
return nil, err
}
}
@ -68,13 +60,7 @@ func (c Configuration) Build(ctx context.Context, source proxy.Source, sink prox
return p, nil
}
func buildTcp(
ctx context.Context,
p *proxy.Proxy,
peer Peer,
gs []func() proxy.MacGenerator,
vs []func() proxy.MacVerifier,
) error {
func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
var laddr func() string
if peer.LocalPort == 0 {
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
@ -83,23 +69,23 @@ func buildTcp(
}
if peer.RemoteHost != "" {
f, err := tcp.InitiateFlow(laddr, fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort), initialiseVerifiers(vs), initialiseGenerators(gs))
f, err := tcp.InitiateFlow(laddr, fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort))
if err != nil {
return err
}
if !peer.DisableConsumer {
p.AddConsumer(ctx, f)
p.AddConsumer(ctx, f, g())
}
if !peer.DisableProducer {
p.AddProducer(ctx, f)
p.AddProducer(ctx, f, v())
}
return nil
}
err := tcp.NewListener(ctx, p, laddr(), vs, gs, !peer.DisableConsumer, !peer.DisableProducer)
err := tcp.NewListener(ctx, p, laddr(), v, g, !peer.DisableConsumer, !peer.DisableProducer)
if err != nil {
return err
}
@ -107,13 +93,7 @@ func buildTcp(
return nil
}
func buildUdp(
ctx context.Context,
p *proxy.Proxy,
peer Peer,
gs []func() proxy.MacGenerator,
vs []func() proxy.MacVerifier,
) error {
func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
var laddr func() string
if peer.LocalPort == 0 {
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
@ -128,15 +108,15 @@ func buildUdp(
default:
fallthrough
case "NewReno":
c = func() udp.Congestion { return newreno.NewNewReno() }
c = func() udp.Congestion { return congestion.NewNewReno() }
}
if peer.RemoteHost != "" {
f, err := udp.InitiateFlow(
laddr,
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
initialiseVerifiers(vs),
initialiseGenerators(gs),
v(),
g(),
c(),
time.Duration(peer.KeepAlive)*time.Second,
)
@ -146,35 +126,19 @@ func buildUdp(
}
if !peer.DisableConsumer {
p.AddConsumer(ctx, f)
p.AddConsumer(ctx, f, g())
}
if !peer.DisableProducer {
p.AddProducer(ctx, f)
p.AddProducer(ctx, f, v())
}
return nil
}
err := udp.NewListener(ctx, p, laddr(), vs, gs, c, !peer.DisableConsumer, !peer.DisableProducer)
err := udp.NewListener(ctx, p, laddr(), v, g, c, !peer.DisableConsumer, !peer.DisableProducer)
if err != nil {
return err
}
return nil
}
func initialiseVerifiers(vs []func() proxy.MacVerifier) (out []proxy.MacVerifier) {
out = make([]proxy.MacVerifier, len(vs))
for i, v := range vs {
out[i] = v()
}
return
}
func initialiseGenerators(gs []func() proxy.MacGenerator) (out []proxy.MacGenerator) {
out = make([]proxy.MacGenerator, len(gs))
for i, g := range gs {
out[i] = g()
}
return
}

View File

@ -38,10 +38,9 @@ type Configuration struct {
}
type Host struct {
Crypto string `validate:"required,oneof=None Blake2s"`
SharedKey string `validate:"required_if=Crypto Blake2s"`
MTU uint `validate:"required,min=576"`
ReplayProtection bool
Crypto string `validate:"required,oneof=None Blake2s"`
SharedKey string `validate:"required_if=Crypto Blake2s"`
MTU uint `validate:"required,min=576"`
}
type Peer struct {

View File

@ -1,12 +1,14 @@
package flags
import (
"errors"
"fmt"
goflags "github.com/jessevdk/go-flags"
"os"
)
var PrintedHelpErr = goflags.ErrHelp
var NotEnoughArgs = errors.New("not enough arguments")
type Options struct {
Foreground bool `short:"f" long:"foreground" description:"Run in the foreground"`

View File

@ -1,4 +0,0 @@
package flags
const DefaultConfigFile = "/usr/local/etc/netcombiner/%s"
const DefaultPidFile = "/var/run/netcombiner/%s.pid"

View File

@ -4,22 +4,19 @@ import (
"mpbl3p/shared"
)
type AlmostUselessMac string
type AlmostUselessMac struct{}
func (a AlmostUselessMac) CodeLength() int {
return len(a)
func (AlmostUselessMac) CodeLength() int {
return 4
}
func (a AlmostUselessMac) Generate([]byte) []byte {
return []byte(a)
func (AlmostUselessMac) Generate([]byte) []byte {
return []byte{'a', 'b', 'c', 'd'}
}
func (a AlmostUselessMac) Verify(_, sum []byte) error {
for i, c := range sum {
if a[i] != c {
return shared.ErrBadChecksum
}
func (u AlmostUselessMac) Verify(_, sum []byte) error {
if !(sum[0] == 'a' && sum[1] == 'b' && sum[2] == 'c' && sum[3] == 'd') {
return shared.ErrBadChecksum
}
return nil
}

View File

@ -1,9 +1,6 @@
package mocks
import (
"net"
"time"
)
import "net"
type MockPerfectBiPacketConn struct {
directionA chan []byte
@ -47,10 +44,6 @@ func (c MockPerfectPacketConn) LocalAddr() net.Addr {
}
}
func (c MockPerfectPacketConn) SetReadDeadline(time.Time) error {
return nil
}
func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
p := <-c.inbound
return copy(b, p), &net.UDPAddr{

View File

@ -1,9 +0,0 @@
package proxy
import "context"
type Exchange interface {
Initial(ctx context.Context) (out []byte, err error)
Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error)
Complete() bool
}

View File

@ -1,6 +1,9 @@
package proxy
import "mpbl3p/shared"
import (
"encoding/binary"
"time"
)
type Packet interface {
Marshal() []byte
@ -19,15 +22,17 @@ func (p SimplePacket) Contents() []byte {
}
func AppendMac(b []byte, g MacGenerator) []byte {
footer := make([]byte, 8)
unixTime := uint64(time.Now().Unix())
binary.LittleEndian.PutUint64(footer, unixTime)
b = append(b, footer...)
mac := g.Generate(b)
return append(b, mac...)
}
func StripMac(b []byte, v MacVerifier) ([]byte, error) {
if len(b) < v.CodeLength() {
return nil, shared.ErrNotEnoughBytes
}
data := b[:len(b)-v.CodeLength()]
sum := b[len(b)-v.CodeLength():]
@ -35,5 +40,7 @@ func StripMac(b []byte, v MacVerifier) ([]byte, error) {
return nil, err
}
return data, nil
// TODO: Verify timestamp
return data[:len(data)-8], nil
}

View File

@ -10,18 +10,18 @@ import (
func TestAppendMac(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testMac := mocks.AlmostUselessMac("abcd")
testMac := mocks.AlmostUselessMac{}
testPacket := SimplePacket(testContent)
testMarshalled := testPacket.Marshal()
appended := AppendMac(testMarshalled, testMac)
t.Run("Length", func(t *testing.T) {
assert.Len(t, appended, len(testMarshalled)+4)
assert.Len(t, appended, len(testMarshalled)+8+4)
})
t.Run("Mac", func(t *testing.T) {
assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled):])
assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled)+8:])
})
t.Run("Original", func(t *testing.T) {
@ -31,7 +31,7 @@ func TestAppendMac(t *testing.T) {
func TestStripMac(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testMac := mocks.AlmostUselessMac("abcd")
testMac := mocks.AlmostUselessMac{}
testPacket := SimplePacket(testContent)
testMarshalled := testPacket.Marshal()

View File

@ -9,12 +9,12 @@ import (
type Producer interface {
IsAlive() bool
Produce(context.Context) (Packet, error)
Produce(context.Context, MacVerifier) (Packet, error)
}
type Consumer interface {
IsAlive() bool
Consume(context.Context, Packet) error
Consume(context.Context, Packet, MacGenerator) error
}
type Reconnectable interface {
@ -67,7 +67,7 @@ func (p Proxy) Start() {
}()
}
func (p Proxy) AddConsumer(ctx context.Context, c Consumer) {
func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) {
go func() {
_, reconnectable := c.(Reconnectable)
@ -75,13 +75,12 @@ func (p Proxy) AddConsumer(ctx context.Context, c Consumer) {
if reconnectable {
var err error
for once := true; err != nil || once; once = false {
if err := ctx.Err(); err != nil {
log.Printf("attempting to connect consumer `%v`\n", c)
err = c.(Reconnectable).Reconnect(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed consumer `%v` (context)\n", c)
return
}
log.Printf("attempting to connect consumer `%v`\n", c)
err = c.(Reconnectable).Reconnect(ctx)
if !once {
time.Sleep(time.Second)
}
@ -95,7 +94,7 @@ func (p Proxy) AddConsumer(ctx context.Context, c Consumer) {
log.Printf("closed consumer `%v` (context)\n", c)
return
case packet := <-p.proxyChan:
if err := c.Consume(ctx, packet); err != nil {
if err := c.Consume(ctx, packet, g); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed consumer `%v` (context)\n", c)
return
@ -111,7 +110,7 @@ func (p Proxy) AddConsumer(ctx context.Context, c Consumer) {
}()
}
func (p Proxy) AddProducer(ctx context.Context, pr Producer) {
func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) {
go func() {
_, reconnectable := pr.(Reconnectable)
@ -119,13 +118,12 @@ func (p Proxy) AddProducer(ctx context.Context, pr Producer) {
if reconnectable {
var err error
for once := true; err != nil || once; once = false {
if err := ctx.Err(); err != nil {
log.Printf("attempting to connect producer `%v`\n", pr)
err = pr.(Reconnectable).Reconnect(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed producer `%v` (context)\n", pr)
return
}
log.Printf("attempting to connect producer `%v`\n", pr)
err = pr.(Reconnectable).Reconnect(ctx)
if !once {
time.Sleep(time.Second)
}
@ -138,7 +136,7 @@ func (p Proxy) AddProducer(ctx context.Context, pr Producer) {
}
for pr.IsAlive() {
if packet, err := pr.Produce(ctx); err != nil {
if packet, err := pr.Produce(ctx, v); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed producer `%v` (context)\n", pr)
return

View File

@ -5,4 +5,3 @@ import "errors"
var ErrBadChecksum = errors.New("the packet had a bad checksum")
var ErrDeadConnection = errors.New("the connection is dead")
var ErrNotEnoughBytes = errors.New("not enough bytes")
var ErrBadExchange = errors.New("bad exchange")

View File

@ -42,16 +42,10 @@ type Flow struct {
toConsume, produced chan []byte
consumeErrors, produceErrors chan error
generators []proxy.MacGenerator
verifiers []proxy.MacVerifier
}
func NewFlow(vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow {
func NewFlow() Flow {
return Flow{
verifiers: vs,
generators: gs,
toConsume: make(chan []byte),
produced: make(chan []byte),
consumeErrors: make(chan error),
@ -59,14 +53,11 @@ func NewFlow(vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow {
}
}
func NewFlowConn(ctx context.Context, conn Conn, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow {
func NewFlowConn(ctx context.Context, conn Conn) Flow {
f := Flow{
conn: conn,
isAlive: true,
generators: gs,
verifiers: vs,
toConsume: make(chan []byte),
produced: make(chan []byte),
consumeErrors: make(chan error),
@ -87,12 +78,12 @@ func (f *Flow) IsAlive() bool {
return f.isAlive
}
func InitiateFlow(local func() string, remote string, vs []proxy.MacVerifier, gs []proxy.MacGenerator) (*InitiatedFlow, error) {
func InitiateFlow(local func() string, remote string) (*InitiatedFlow, error) {
f := InitiatedFlow{
Local: local,
Remote: remote,
Flow: NewFlow(vs, gs),
Flow: NewFlow(),
}
return &f, nil
@ -134,21 +125,21 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
return nil
}
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet) error {
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Consume(ctx, p)
return f.Flow.Consume(ctx, p, g)
}
func (f *InitiatedFlow) Produce(ctx context.Context) (proxy.Packet, error) {
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(ctx)
return f.Flow.Produce(ctx, v)
}
func (f *Flow) Consume(ctx context.Context, p proxy.Packet) error {
func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
if !f.isAlive {
return shared.ErrDeadConnection
}
@ -160,10 +151,8 @@ func (f *Flow) Consume(ctx context.Context, p proxy.Packet) error {
default:
}
data := p.Marshal()
for _, g := range f.generators {
data = proxy.AppendMac(data, g)
}
marshalled := p.Marshal()
data := proxy.AppendMac(marshalled, g)
prefixedData := make([]byte, len(data)+4)
binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
@ -178,7 +167,7 @@ func (f *Flow) Consume(ctx context.Context, p proxy.Packet) error {
return nil
}
func (f *Flow) Produce(ctx context.Context) (proxy.Packet, error) {
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
if !f.isAlive {
return nil, shared.ErrDeadConnection
}
@ -194,17 +183,12 @@ func (f *Flow) Produce(ctx context.Context) (proxy.Packet, error) {
return nil, err
}
for i := range f.verifiers {
v := f.verifiers[len(f.verifiers)-i-1]
var err error
data, err = proxy.StripMac(data, v)
if err != nil {
return nil, err
}
b, err := proxy.StripMac(data, v)
if err != nil {
return nil, err
}
return proxy.SimplePacket(data), nil
return proxy.SimplePacket(b), nil
}
func (f *Flow) consumeMarshalled(ctx context.Context) {

View File

@ -13,75 +13,41 @@ import (
func TestFlow_Consume(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := proxy.SimplePacket(testContent)
testMac := mocks.AlmostUselessMac("abcd")
testMac2 := mocks.AlmostUselessMac("efgh")
testMac := mocks.AlmostUselessMac{}
t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac})
flowA := NewFlowConn(context.Background(), testConn.SideA())
err := flowA.Consume(context.Background(), testPacket)
err := flowA.Consume(context.Background(), testPacket, testMac)
require.Nil(t, err)
buf := make([]byte, 100)
n, err := testConn.SideB().Read(buf)
require.Nil(t, err)
assert.Equal(t, len(testContent)+4+4, n)
assert.Equal(t, uint32(len(testContent)+4), binary.LittleEndian.Uint32(buf[:len(buf)-4]))
})
t.Run("MultipleGeneratorsLength", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
err := flowA.Consume(context.Background(), testPacket)
require.Nil(t, err)
buf := make([]byte, 100)
n, err := testConn.SideB().Read(buf)
require.Nil(t, err)
assert.Equal(t, len(testContent)+4+4+4, n)
assert.Equal(t, uint32(len(testContent)+4+4), binary.LittleEndian.Uint32(buf[:len(buf)-4]))
})
t.Run("MultipleGeneratorsOrder", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
err := flowA.Consume(context.Background(), testPacket)
require.Nil(t, err)
buf := make([]byte, 100)
n, err := testConn.SideB().Read(buf)
require.Nil(t, err)
assert.Equal(t, len(testContent)+4+4+4, n)
assert.Equal(t, "abcdefgh", string(buf[n-8:n]))
assert.Equal(t, len(testContent)+8+4+4, n)
assert.Equal(t, uint32(len(testContent)+8+4), binary.LittleEndian.Uint32(buf[:len(buf)-4]))
})
}
func TestFlow_Produce(t *testing.T) {
testContent := "A test string is the content of this packet."
testMarshalled := []byte("0000" + testContent + "abcd")
testMarshalled := []byte("0000" + testContent + "00000000abcd")
binary.LittleEndian.PutUint32(testMarshalled, uint32(len(testMarshalled)-4))
testMac := mocks.AlmostUselessMac("abcd")
testMac2 := mocks.AlmostUselessMac("efgh")
testMac := mocks.AlmostUselessMac{}
t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac})
flowA := NewFlowConn(context.Background(), testConn.SideA())
_, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err)
p, err := flowA.Produce(context.Background())
p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err)
assert.Equal(t, len(testContent), len(p.Contents()))
})
@ -89,29 +55,12 @@ func TestFlow_Produce(t *testing.T) {
t.Run("Value", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac})
flowA := NewFlowConn(context.Background(), testConn.SideA())
_, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err)
p, err := flowA.Produce(context.Background())
require.Nil(t, err)
assert.Equal(t, testContent, string(p.Contents()))
})
t.Run("MultipleVerifiersStrip", func(t *testing.T) {
testContent := "A test string is the content of this packet."
testMarshalled := []byte("0000" + testContent + "abcdefgh")
binary.LittleEndian.PutUint32(testMarshalled, uint32(len(testMarshalled)-4))
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
_, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err)
p, err := flowA.Produce(context.Background())
p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err)
assert.Equal(t, testContent, string(p.Contents()))
})

View File

@ -7,15 +7,7 @@ import (
"net"
)
func NewListener(
ctx context.Context,
p *proxy.Proxy,
local string,
vs []func() proxy.MacVerifier,
gs []func() proxy.MacGenerator,
enableConsumers bool,
enableProducers bool,
) error {
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, enableConsumers bool, enableProducers bool) error {
laddr, err := net.ResolveTCPAddr("tcp", local)
if err != nil {
return err
@ -37,24 +29,15 @@ func NewListener(
panic(err)
}
var verifiers = make([]proxy.MacVerifier, len(vs))
for i, v := range vs {
verifiers[i] = v()
}
var generators = make([]proxy.MacGenerator, len(gs))
for i, g := range gs {
generators[i] = g()
}
f := NewFlowConn(ctx, conn, verifiers, generators)
f := NewFlowConn(ctx, conn)
log.Printf("received new tcp connection: %v\n", f)
if enableConsumers {
p.AddConsumer(ctx, &f)
p.AddConsumer(ctx, &f, g())
}
if enableProducers {
p.AddProducer(ctx, &f)
p.AddProducer(ctx, &f, v())
}
}
}()

View File

@ -1,4 +1,4 @@
package newreno
package congestion
import (
"context"
@ -11,12 +11,10 @@ import (
)
const RttExponentialFactor = 0.1
const RttLossDelay = 0.5
const RttLossDelay = 1.5
type NewReno struct {
sequence chan uint32
stopSequence context.CancelFunc
wasInitial, alive bool
sequence chan uint32
inFlight []flightInfo
lastSent time.Time
@ -66,6 +64,19 @@ func NewNewReno() *NewReno {
slowStart: true,
}
go func() {
var s uint32
for {
if s == 0 {
s++
continue
}
c.sequence <- s
s++
}
}()
return &c
}

View File

@ -1,119 +0,0 @@
package newreno
import (
"context"
"encoding/binary"
"math/rand"
"mpbl3p/shared"
"time"
)
func (c *NewReno) Initial(ctx context.Context) (out []byte, err error) {
c.alive = false
c.wasInitial = true
c.startSequenceLoop(ctx)
var s uint32
select {
case s = <-c.sequence:
case <-ctx.Done():
return nil, ctx.Err()
}
b := make([]byte, 12)
binary.LittleEndian.PutUint32(b[8:12], s)
c.inFlight = []flightInfo{{time.Now(), s}}
return b, nil
}
func (c *NewReno) Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error) {
if c.alive || c.stopSequence == nil {
// reset
c.alive = false
c.startSequenceLoop(ctx)
}
// receive
if len(in) != 12 {
return nil, nil, shared.ErrBadExchange
}
rcvAck := binary.LittleEndian.Uint32(in[0:4])
rcvNack := binary.LittleEndian.Uint32(in[4:8])
rcvSeq := binary.LittleEndian.Uint32(in[8:12])
// verify
if rcvNack != 0 {
return nil, nil, shared.ErrBadExchange
}
var seq uint32
if c.wasInitial {
if rcvAck == c.inFlight[0].sequence {
c.ack, c.lastAck = rcvSeq, rcvSeq
c.alive, c.inFlight = true, nil
} else {
return nil, nil, shared.ErrBadExchange
}
} else { // if !c.wasInitial
if rcvAck == 0 {
// theirs is a syn packet
c.ack, c.lastAck = rcvSeq, rcvSeq
select {
case seq = <-c.sequence:
case <-ctx.Done():
return nil, nil, ctx.Err()
}
c.inFlight = []flightInfo{{time.Now(), seq}}
} else if len(c.inFlight) == 1 && rcvAck == c.inFlight[0].sequence {
c.alive, c.inFlight = true, nil
} else {
return nil, nil, shared.ErrBadExchange
}
}
// respond
b := make([]byte, 12)
binary.LittleEndian.PutUint32(b[0:4], c.ack)
binary.LittleEndian.PutUint32(b[8:12], seq)
return b, nil, nil
}
func (c *NewReno) Complete() bool {
return c.alive
}
func (c *NewReno) startSequenceLoop(ctx context.Context) {
if c.stopSequence != nil {
c.stopSequence()
}
var s uint32
for s == 0 {
s = rand.Uint32()
}
ctx, c.stopSequence = context.WithCancel(ctx)
go func() {
s := s
for {
if s == 0 {
s++
continue
}
select {
case c.sequence <- s:
case <-ctx.Done():
return
}
s++
}
}()
}

View File

@ -1,57 +0,0 @@
package newreno
import (
"context"
"encoding/binary"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestNewReno_InitialHandle(t *testing.T) {
t.Run("InitialAckNackAreZero", func(t *testing.T) {
// ASSIGN
a := NewNewReno()
ctx := context.Background()
// ACT
initial, err := a.Initial(ctx)
ack := binary.LittleEndian.Uint32(initial[0:4])
nack := binary.LittleEndian.Uint32(initial[4:8])
// ASSERT
require.Nil(t, err)
assert.Zero(t, ack)
assert.Zero(t, nack)
})
t.Run("InitialHandledWithAck", func(t *testing.T) {
// ASSIGN
a := NewNewReno()
b := NewNewReno()
ctx := context.Background()
// ACT
initial, err := a.Initial(ctx)
require.Nil(t, err)
initialSeq := binary.LittleEndian.Uint32(initial[8:12])
response, data, err := b.Handle(ctx, initial)
ack := binary.LittleEndian.Uint32(response[0:4])
nack := binary.LittleEndian.Uint32(response[4:8])
// ASSERT
require.Nil(t, err)
assert.Equal(t, initialSeq, ack)
assert.Zero(t, nack)
assert.Nil(t, data)
})
}

View File

@ -1,4 +1,4 @@
package newreno
package congestion
import (
"context"
@ -21,8 +21,8 @@ type newRenoTest struct {
halfRtt time.Duration
}
func newNewRenoTest(ctx context.Context, rtt time.Duration) *newRenoTest {
nr := &newRenoTest{
func newNewRenoTest(rtt time.Duration) *newRenoTest {
return &newRenoTest{
sideA: NewNewReno(),
sideB: NewNewReno(),
@ -34,15 +34,6 @@ func newNewRenoTest(ctx context.Context, rtt time.Duration) *newRenoTest {
halfRtt: rtt / 2,
}
p, _ := nr.sideA.Initial(ctx)
p, _, _ = nr.sideB.Handle(ctx, p)
p, _, _ = nr.sideA.Handle(ctx, p)
nr.sideB.ReceivedPacket(0, nr.sideA.NextAck(), nr.sideA.NextNack())
nr.sideA.ReceivedPacket(0, nr.sideB.NextAck(), nr.sideB.NextNack())
return nr
}
func (n *newRenoTest) Start(ctx context.Context) {
@ -160,14 +151,11 @@ func TestNewReno_Congestion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := newNewRenoTest(ctx, rtt)
c := newNewRenoTest(rtt)
c.Start(ctx)
c.RunSideA(ctx)
c.RunSideB(ctx)
sideAinitialAck := c.sideA.ack
sideBinitialAck := c.sideB.ack
// ACT
for i := 0; i < numPackets; i++ {
// sleep to simulate preparing packet
@ -187,10 +175,10 @@ func TestNewReno_Congestion(t *testing.T) {
// ASSERT
assert.Equal(t, uint32(0), c.sideA.nack)
assert.Equal(t, sideAinitialAck, c.sideA.ack)
assert.Equal(t, uint32(0), c.sideA.ack)
assert.Equal(t, uint32(0), c.sideB.nack)
assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
assert.Equal(t, uint32(numPackets), c.sideB.ack)
})
t.Run("SequenceLoss", func(t *testing.T) {
@ -201,21 +189,18 @@ func TestNewReno_Congestion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := newNewRenoTest(ctx, rtt)
c := newNewRenoTest(rtt)
c.Start(ctx)
c.RunSideA(ctx)
c.RunSideB(ctx)
sideAinitialAck := c.sideA.ack
sideBinitialAck := c.sideB.ack
// ACT
for i := 1; i <= numPackets; i++ {
for i := 0; i < numPackets; i++ {
// sleep to simulate preparing packet
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideA.Sequence(ctx)
if i == 20 {
if seq == 20 {
// Simulate packet loss of sequence 20
continue
}
@ -232,10 +217,10 @@ func TestNewReno_Congestion(t *testing.T) {
// ASSERT
assert.Equal(t, uint32(0), c.sideA.nack)
assert.Equal(t, sideAinitialAck, c.sideA.ack)
assert.Equal(t, uint32(0), c.sideA.ack)
assert.Equal(t, sideBinitialAck+uint32(20), c.sideB.nack)
assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
assert.Equal(t, uint32(20), c.sideB.nack)
assert.Equal(t, uint32(numPackets), c.sideB.ack)
})
})
@ -248,19 +233,16 @@ func TestNewReno_Congestion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := newNewRenoTest(ctx, rtt)
c := newNewRenoTest(rtt)
c.Start(ctx)
c.RunSideA(ctx)
c.RunSideB(ctx)
sideAinitialAck := c.sideA.ack
sideBinitialAck := c.sideB.ack
// ACT
done := make(chan struct{})
go func() {
for i := 1; i <= numPackets; i++ {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideA.Sequence(ctx)
@ -275,7 +257,7 @@ func TestNewReno_Congestion(t *testing.T) {
}()
go func() {
for i := 1; i <= numPackets; i++ {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideB.Sequence(ctx)
@ -297,10 +279,10 @@ func TestNewReno_Congestion(t *testing.T) {
// ASSERT
assert.Equal(t, uint32(0), c.sideA.nack)
assert.Equal(t, sideAinitialAck+uint32(numPackets), c.sideA.ack)
assert.Equal(t, uint32(numPackets), c.sideA.ack)
assert.Equal(t, uint32(0), c.sideB.nack)
assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
assert.Equal(t, uint32(numPackets), c.sideB.ack)
})
t.Run("SequenceLoss", func(t *testing.T) {
@ -311,23 +293,20 @@ func TestNewReno_Congestion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := newNewRenoTest(ctx, rtt)
c := newNewRenoTest(rtt)
c.Start(ctx)
c.RunSideA(ctx)
c.RunSideB(ctx)
sideAinitialAck := c.sideA.ack
sideBinitialAck := c.sideB.ack
// ACT
done := make(chan struct{})
go func() {
for i := 1; i <= numPackets; i++ {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideA.Sequence(ctx)
if i == 9 {
if seq == 9 {
// Simulate packet loss of sequence 9
continue
}
@ -343,11 +322,11 @@ func TestNewReno_Congestion(t *testing.T) {
}()
go func() {
for i := 1; i <= numPackets; i++ {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq, _ := c.sideB.Sequence(ctx)
if i == 13 {
if seq == 13 {
// Simulate packet loss of sequence 13
continue
}
@ -369,11 +348,11 @@ func TestNewReno_Congestion(t *testing.T) {
// ASSERT
assert.Equal(t, sideAinitialAck+uint32(13), c.sideA.nack)
assert.Equal(t, sideAinitialAck+uint32(numPackets), c.sideA.ack)
assert.Equal(t, uint32(13), c.sideA.nack)
assert.Equal(t, uint32(numPackets), c.sideA.ack)
assert.Equal(t, sideBinitialAck+uint32(9), c.sideB.nack)
assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
assert.Equal(t, uint32(9), c.sideB.nack)
assert.Equal(t, uint32(numPackets), c.sideB.ack)
})
})
}

View File

@ -7,6 +7,7 @@ import (
"mpbl3p/proxy"
"mpbl3p/shared"
"net"
"sync"
"time"
)
@ -18,19 +19,33 @@ type PacketWriter interface {
type PacketConn interface {
PacketWriter
SetReadDeadline(t time.Time) error
ReadFromUDP(b []byte) (int, *net.UDPAddr, error)
}
type InitiatedFlow struct {
Local func() string
Remote string
g proxy.MacGenerator
keepalive time.Duration
mu sync.RWMutex
Flow
}
func (f *InitiatedFlow) String() string {
return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote)
}
type Flow struct {
writer PacketWriter
raddr *net.UDPAddr
isAlive bool
startup bool
congestion Congestion
verifiers []proxy.MacVerifier
generators []proxy.MacGenerator
v proxy.MacVerifier
inboundDatagrams chan []byte
}
@ -39,20 +54,134 @@ func (f Flow) String() string {
return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr())
}
func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow {
func InitiateFlow(
local func() string,
remote string,
v proxy.MacVerifier,
g proxy.MacGenerator,
c Congestion,
keepalive time.Duration,
) (*InitiatedFlow, error) {
f := InitiatedFlow{
Local: local,
Remote: remote,
Flow: newFlow(c, v),
g: g,
keepalive: keepalive,
}
return &f, nil
}
func newFlow(c Congestion, v proxy.MacVerifier) Flow {
return Flow{
inboundDatagrams: make(chan []byte),
congestion: c,
verifiers: vs,
generators: gs,
v: v,
}
}
func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
f.mu.Lock()
defer f.mu.Unlock()
if f.isAlive {
return nil
}
localAddr, err := net.ResolveUDPAddr("udp", f.Local())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote)
if err != nil {
return err
}
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
if err != nil {
return err
}
f.writer = conn
f.startup = true
// prod the connection once a second until we get an ack, then consider it alive
go func() {
seq, err := f.congestion.Sequence(ctx)
if err != nil {
return
}
for !f.isAlive {
if ctx.Err() != nil {
return
}
p := Packet{
ack: 0,
nack: 0,
seq: seq,
data: proxy.SimplePacket(nil),
}
_ = f.sendPacket(p, f.g)
time.Sleep(1 * time.Second)
}
}()
go func() {
_, _ = f.produceInternal(ctx, f.v, false)
}()
go f.earlyUpdateLoop(ctx, f.g, f.keepalive)
if err := f.readQueuePacket(ctx, conn); err != nil {
return err
}
f.isAlive = true
f.startup = false
go func() {
lockedAccept := func() {
f.mu.RLock()
defer f.mu.RUnlock()
if err := f.readQueuePacket(ctx, conn); err != nil {
log.Println(err)
}
}
for f.isAlive {
log.Println("alive and listening for packets")
lockedAccept()
}
log.Println("no longer alive")
}()
return nil
}
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Consume(ctx, p, g)
}
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(ctx, v)
}
func (f *Flow) IsAlive() bool {
return f.isAlive
}
func (f *Flow) Consume(ctx context.Context, pp proxy.Packet) error {
func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error {
if !f.isAlive {
return shared.ErrDeadConnection
}
@ -75,18 +204,18 @@ func (f *Flow) Consume(ctx context.Context, pp proxy.Packet) error {
nack: f.congestion.NextNack(),
}
return f.sendPacket(p)
return f.sendPacket(p, g)
}
func (f *Flow) Produce(ctx context.Context) (proxy.Packet, error) {
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
if !f.isAlive {
return nil, shared.ErrDeadConnection
}
return f.produceInternal(ctx, true)
return f.produceInternal(ctx, v, true)
}
func (f *Flow) produceInternal(ctx context.Context, mustReturn bool) (proxy.Packet, error) {
func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
for once := true; mustReturn || once; once = false {
log.Println(f.congestion)
@ -97,17 +226,12 @@ func (f *Flow) produceInternal(ctx context.Context, mustReturn bool) (proxy.Pack
return nil, ctx.Err()
}
for i := range f.verifiers {
v := f.verifiers[len(f.verifiers)-i-1]
var err error
received, err = proxy.StripMac(received, v)
if err != nil {
return nil, err
}
b, err := proxy.StripMac(received, v)
if err != nil {
return nil, err
}
p, err := UnmarshalPacket(received)
p, err := UnmarshalPacket(b)
if err != nil {
return nil, err
}
@ -115,7 +239,6 @@ func (f *Flow) produceInternal(ctx context.Context, mustReturn bool) (proxy.Pack
// adjust congestion control based on this packet's congestion header
f.congestion.ReceivedPacket(p.seq, p.nack, p.ack)
// 12 bytes for header + the MAC + a timestamp
if len(p.Contents()) == 0 {
log.Println("handled keepalive/ack only packet")
continue
@ -136,12 +259,9 @@ func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
}
}
func (f *Flow) sendPacket(p proxy.Packet) error {
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
b := p.Marshal()
for _, g := range f.generators {
b = proxy.AppendMac(b, g)
}
b = proxy.AppendMac(b, g)
if f.raddr == nil {
_, err := f.writer.Write(b)
@ -152,7 +272,7 @@ func (f *Flow) sendPacket(p proxy.Packet) error {
}
}
func (f *Flow) earlyUpdateLoop(ctx context.Context, keepalive time.Duration) {
func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) {
for f.isAlive {
seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive)
if err != nil {
@ -166,31 +286,20 @@ func (f *Flow) earlyUpdateLoop(ctx context.Context, keepalive time.Duration) {
nack: f.congestion.NextNack(),
}
err = f.sendPacket(p)
err = f.sendPacket(p, g)
if err != nil {
fmt.Printf("error sending early update packet: `%v`\n", err)
}
}
}
func (f *Flow) readPacket(ctx context.Context, c PacketConn) ([]byte, error) {
func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error {
// TODO: Replace 6000 with MTU+header size
buf := make([]byte, 6000)
if d, ok := ctx.Deadline(); ok {
if err := c.SetReadDeadline(d); err != nil {
return nil, err
}
}
n, _, err := c.ReadFromUDP(buf)
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
if ctx.Err() != nil {
return nil, ctx.Err()
}
}
return nil, err
return err
}
return buf[:n], nil
return f.queueDatagram(ctx, buf[:n])
}

View File

@ -15,17 +15,17 @@ import (
func TestFlow_Consume(t *testing.T) {
testContent := []byte("A test string is the content of this packet.")
testPacket := proxy.SimplePacket(testContent)
testMac := mocks.AlmostUselessMac("abcd")
testMac := mocks.AlmostUselessMac{}
t.Run("SingleGeneratorLength", func(t *testing.T) {
t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiPacketConn(10)
flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac})
flowA := newFlow(congestion.NewNone(), testMac)
flowA.writer = testConn.SideB()
flowA.isAlive = true
err := flowA.Consume(context.Background(), testPacket)
err := flowA.Consume(context.Background(), testPacket, testMac)
require.Nil(t, err)
buf := make([]byte, 100)
@ -33,50 +33,7 @@ func TestFlow_Consume(t *testing.T) {
require.Nil(t, err)
// 12 header, 8 timestamp, 4 MAC
assert.Equal(t, len(testContent)+12+4, n)
})
t.Run("MultipleGeneratorsLength", func(t *testing.T) {
testMac2 := mocks.AlmostUselessMac("efgh")
testConn := mocks.NewMockPerfectBiPacketConn(10)
flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
flowA.writer = testConn.SideB()
flowA.isAlive = true
err := flowA.Consume(context.Background(), testPacket)
require.Nil(t, err)
buf := make([]byte, 100)
n, _, err := testConn.SideA().ReadFromUDP(buf)
require.Nil(t, err)
// 12 header, 8 timestamp, 4 MAC
assert.Equal(t, len(testContent)+12+4+4, n)
})
t.Run("MultipleGeneratorsOrder", func(t *testing.T) {
testMac2 := mocks.AlmostUselessMac("efgh")
testConn := mocks.NewMockPerfectBiPacketConn(10)
flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
flowA.writer = testConn.SideB()
flowA.isAlive = true
err := flowA.Consume(context.Background(), testPacket)
require.Nil(t, err)
buf := make([]byte, 100)
n, _, err := testConn.SideA().ReadFromUDP(buf)
require.Nil(t, err)
// 12 header, 8 timestamp, 4 MAC
require.Equal(t, len(testContent)+12+4+4, n)
macs := string(buf[n-8 : n])
assert.Equal(t, "abcdefgh", macs)
assert.Equal(t, len(testContent)+12+8+4, n)
})
}
@ -88,7 +45,7 @@ func TestFlow_Produce(t *testing.T) {
seq: 128,
data: proxy.SimplePacket(testContent),
}
testMac := mocks.AlmostUselessMac("abcd")
testMac := mocks.AlmostUselessMac{}
testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac)
@ -101,56 +58,16 @@ func TestFlow_Produce(t *testing.T) {
_, err := testConn.SideA().Write(testMarshalled)
require.Nil(t, err)
flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac})
flowA := newFlow(congestion.NewNone(), testMac)
flowA.writer = testConn.SideB()
flowA.isAlive = true
go func() {
p, err := flowA.readPacket(context.Background(), testConn.SideB())
assert.Nil(t, err)
err = flowA.queueDatagram(context.Background(), p)
err := flowA.readQueuePacket(context.Background(), testConn.SideB())
assert.Nil(t, err)
}()
p, err := flowA.Produce(context.Background())
require.Nil(t, err)
assert.Len(t, p.Contents(), len(testContent))
done <- struct{}{}
}()
timer := time.NewTimer(500 * time.Millisecond)
select {
case <-done:
case <-timer.C:
fmt.Println("timed out")
t.FailNow()
}
})
t.Run("MultipleVerifiersStrip", func(t *testing.T) {
done := make(chan struct{})
go func() {
testMac2 := mocks.AlmostUselessMac("efgh")
testConn := mocks.NewMockPerfectBiPacketConn(10)
_, err := testConn.SideA().Write(proxy.AppendMac(testMarshalled, testMac2))
require.Nil(t, err)
flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
flowA.writer = testConn.SideB()
flowA.isAlive = true
go func() {
p, err := flowA.readPacket(context.Background(), testConn.SideB())
assert.Nil(t, err)
err = flowA.queueDatagram(context.Background(), p)
assert.Nil(t, err)
}()
p, err := flowA.Produce(context.Background())
p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err)
assert.Len(t, p.Contents(), len(testContent))

View File

@ -1,142 +0,0 @@
package udp
import (
"context"
"log"
"mpbl3p/proxy"
"sync"
"time"
)
type InboundFlow struct {
inboundDatagrams chan []byte
mu sync.RWMutex
Flow
}
func newInboundFlow(f Flow) (*InboundFlow, error) {
fi := InboundFlow{
inboundDatagrams: make(chan []byte),
Flow: f,
}
return &fi, nil
}
func (f *InboundFlow) queueDatagram(ctx context.Context, p []byte) error {
select {
case f.inboundDatagrams <- p:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (f *InboundFlow) processPackets(ctx context.Context) {
for {
f.mu.Lock()
var err error
for once := true; err != nil || once; once = false {
if ctx.Err() != nil {
return
}
err = f.handleExchanges(ctx)
if err != nil {
log.Println(err)
}
}
f.mu.Unlock()
var p []byte
select {
case p = <-f.inboundDatagrams:
case <-ctx.Done():
return
}
// TODO: Check if p means redo exchanges
if false {
continue
}
select {
case f.Flow.inboundDatagrams <- p:
case <-ctx.Done():
return
}
}
}
func (f *InboundFlow) handleExchanges(ctx context.Context) error {
var exchanges []proxy.Exchange
if e, ok := f.congestion.(proxy.Exchange); ok {
exchanges = append(exchanges, e)
}
var exchangeData [][]byte
for _, e := range exchanges {
for once := true; !e.Complete() || once; once = false {
if err := func() (err error) {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
var recv []byte
select {
case recv = <-f.inboundDatagrams:
case <-ctx.Done():
return ctx.Err()
}
for i := range f.verifiers {
v := f.verifiers[len(f.verifiers)-i-1]
recv, err = proxy.StripMac(recv, v)
if err != nil {
return err
}
}
var resp, data []byte
if resp, data, err = e.Handle(ctx, recv); err != nil {
return err
}
if data != nil {
exchangeData = append(exchangeData, data)
}
if resp != nil {
if err = f.sendPacket(proxy.SimplePacket(resp)); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
}
return nil
}
func (f *InboundFlow) Consume(ctx context.Context, p proxy.Packet) error {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Consume(ctx, p)
}
func (f *InboundFlow) Produce(ctx context.Context) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(ctx)
}

View File

@ -16,7 +16,9 @@ type ComparableUdpAddress struct {
func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
var ip [16]byte
copy(ip[:], address.IP)
for i, b := range []byte(address.IP) {
ip[i] = b
}
return ComparableUdpAddress{
IP: ip,
@ -25,16 +27,7 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
}
}
func NewListener(
ctx context.Context,
p *proxy.Proxy,
local string,
vs []func() proxy.MacVerifier,
gs []func() proxy.MacGenerator,
c func() Congestion,
enableConsumers bool,
enableProducers bool,
) error {
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion, enableConsumers bool, enableProducers bool) error {
laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return err
@ -45,7 +38,12 @@ func NewListener(
return err
}
receivedConnections := make(map[ComparableUdpAddress]*InboundFlow)
err = pconn.SetWriteBuffer(0)
if err != nil {
panic(err)
}
receivedConnections := make(map[ComparableUdpAddress]*Flow)
go func() {
for ctx.Err() == nil {
@ -63,53 +61,39 @@ func NewListener(
}
raddr := fromUdpAddress(*addr)
if fi, exists := receivedConnections[raddr]; exists {
if f, exists := receivedConnections[raddr]; exists {
log.Println("existing flow. queuing...")
if err := fi.queueDatagram(ctx, buf[:n]); err != nil {
log.Println("error")
continue
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
}
log.Println("queued")
continue
}
var verifiers = make([]proxy.MacVerifier, len(vs))
for i, v := range vs {
verifiers[i] = v()
}
var generators = make([]proxy.MacGenerator, len(gs))
for i, g := range gs {
generators[i] = g()
}
v := v()
g := g()
f := newFlow(c(), verifiers, generators)
f := newFlow(c(), v)
f.writer = pconn
f.raddr = addr
f.isAlive = true
fi, err := newInboundFlow(f)
if err != nil {
log.Println(err)
continue
}
log.Printf("received new udp connection: %v\n", f)
go fi.processPackets(ctx)
go fi.earlyUpdateLoop(ctx, 0)
go f.earlyUpdateLoop(ctx, g, 0)
receivedConnections[raddr] = fi
receivedConnections[raddr] = &f
if enableConsumers {
p.AddConsumer(ctx, fi)
p.AddConsumer(ctx, &f, g)
}
if enableProducers {
p.AddProducer(ctx, fi)
p.AddProducer(ctx, &f, v)
}
log.Println("handling...")
if err := fi.queueDatagram(ctx, buf[:n]); err != nil {
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
return
}
log.Println("handled")

View File

@ -1,181 +0,0 @@
package udp
import (
"context"
"errors"
"fmt"
"log"
"mpbl3p/proxy"
"net"
"sync"
"time"
)
type OutboundFlow struct {
Local func() string
Remote string
g proxy.MacGenerator
keepalive time.Duration
mu sync.RWMutex
Flow
}
func InitiateFlow(
local func() string,
remote string,
vs []proxy.MacVerifier,
gs []proxy.MacGenerator,
c Congestion,
keepalive time.Duration,
) (*OutboundFlow, error) {
f := OutboundFlow{
Local: local,
Remote: remote,
Flow: newFlow(c, vs, gs),
keepalive: keepalive,
}
return &f, nil
}
func (f *OutboundFlow) String() string {
return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote)
}
func (f *OutboundFlow) Reconnect(ctx context.Context) error {
f.mu.Lock()
defer f.mu.Unlock()
if f.isAlive {
return nil
}
localAddr, err := net.ResolveUDPAddr("udp", f.Local())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote)
if err != nil {
return err
}
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
if err != nil {
return err
}
f.writer = conn
// prod the connection once a second until we get an ack, then consider it alive
var exchanges []proxy.Exchange
if e, ok := f.congestion.(proxy.Exchange); ok {
exchanges = append(exchanges, e)
}
var exchangeData [][]byte
for _, e := range exchanges {
i, err := e.Initial(ctx)
if err != nil {
return err
}
if err = f.sendPacket(proxy.SimplePacket(i)); err != nil {
return err
}
for once := true; !e.Complete() || once; once = false {
if err := func() error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
var recv []byte
if recv, err = f.readPacket(ctx, conn); err != nil {
return err
}
for i := range f.verifiers {
v := f.verifiers[len(f.verifiers)-i-1]
recv, err = proxy.StripMac(recv, v)
if err != nil {
return err
}
}
var resp, data []byte
if resp, data, err = e.Handle(ctx, recv); err != nil {
return err
}
if data != nil {
exchangeData = append(exchangeData, data)
}
if resp != nil {
if err = f.sendPacket(proxy.SimplePacket(resp)); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
}
go func() {
for _, d := range exchangeData {
if err := f.queueDatagram(ctx, d); err != nil {
return
}
}
lockedAccept := func() {
f.mu.RLock()
defer f.mu.RUnlock()
var p []byte
if p, err = f.readPacket(ctx, conn); err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
return
}
log.Println(err)
return
}
if err := f.queueDatagram(ctx, p); err != nil {
return
}
}
for f.isAlive {
log.Println("alive and listening for packets")
lockedAccept()
}
log.Println("no longer alive")
}()
f.isAlive = true
return nil
}
func (f *OutboundFlow) Consume(ctx context.Context, p proxy.Packet) error {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Consume(ctx, p)
}
func (f *OutboundFlow) Produce(ctx context.Context) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(ctx)
}