Merge pull request 'exchanges' (#20) from exchanges into develop
All checks were successful
continuous-integration/drone/push Build is passing

Reviewed-on: #20
This commit is contained in:
JakeHillion 2021-05-14 07:10:45 +00:00
commit 1b4288e9db
16 changed files with 626 additions and 208 deletions

View File

@ -11,6 +11,7 @@ import (
"mpbl3p/tcp" "mpbl3p/tcp"
"mpbl3p/udp" "mpbl3p/udp"
"mpbl3p/udp/congestion" "mpbl3p/udp/congestion"
"mpbl3p/udp/congestion/newreno"
"time" "time"
) )
@ -127,7 +128,7 @@ func buildUdp(
default: default:
fallthrough fallthrough
case "NewReno": case "NewReno":
c = func() udp.Congestion { return congestion.NewNewReno() } c = func() udp.Congestion { return newreno.NewNewReno() }
} }
if peer.RemoteHost != "" { if peer.RemoteHost != "" {

View File

@ -1,14 +1,12 @@
package flags package flags
import ( import (
"errors"
"fmt" "fmt"
goflags "github.com/jessevdk/go-flags" goflags "github.com/jessevdk/go-flags"
"os" "os"
) )
var PrintedHelpErr = goflags.ErrHelp var PrintedHelpErr = goflags.ErrHelp
var NotEnoughArgs = errors.New("not enough arguments")
type Options struct { type Options struct {
Foreground bool `short:"f" long:"foreground" description:"Run in the foreground"` Foreground bool `short:"f" long:"foreground" description:"Run in the foreground"`

4
flags/locs_darwin.go Normal file
View File

@ -0,0 +1,4 @@
package flags
const DefaultConfigFile = "/usr/local/etc/netcombiner/%s"
const DefaultPidFile = "/var/run/netcombiner/%s.pid"

View File

@ -1,6 +1,9 @@
package mocks package mocks
import "net" import (
"net"
"time"
)
type MockPerfectBiPacketConn struct { type MockPerfectBiPacketConn struct {
directionA chan []byte directionA chan []byte
@ -44,6 +47,10 @@ func (c MockPerfectPacketConn) LocalAddr() net.Addr {
} }
} }
func (c MockPerfectPacketConn) SetReadDeadline(time.Time) error {
return nil
}
func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
p := <-c.inbound p := <-c.inbound
return copy(b, p), &net.UDPAddr{ return copy(b, p), &net.UDPAddr{

9
proxy/exchange.go Normal file
View File

@ -0,0 +1,9 @@
package proxy
import "context"
type Exchange interface {
Initial(ctx context.Context) (out []byte, err error)
Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error)
Complete() bool
}

View File

@ -75,12 +75,13 @@ func (p Proxy) AddConsumer(ctx context.Context, c Consumer) {
if reconnectable { if reconnectable {
var err error var err error
for once := true; err != nil || once; once = false { for once := true; err != nil || once; once = false {
log.Printf("attempting to connect consumer `%v`\n", c) if err := ctx.Err(); err != nil {
err = c.(Reconnectable).Reconnect(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed consumer `%v` (context)\n", c) log.Printf("closed consumer `%v` (context)\n", c)
return return
} }
log.Printf("attempting to connect consumer `%v`\n", c)
err = c.(Reconnectable).Reconnect(ctx)
if !once { if !once {
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -118,12 +119,13 @@ func (p Proxy) AddProducer(ctx context.Context, pr Producer) {
if reconnectable { if reconnectable {
var err error var err error
for once := true; err != nil || once; once = false { for once := true; err != nil || once; once = false {
log.Printf("attempting to connect producer `%v`\n", pr) if err := ctx.Err(); err != nil {
err = pr.(Reconnectable).Reconnect(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed producer `%v` (context)\n", pr) log.Printf("closed producer `%v` (context)\n", pr)
return return
} }
log.Printf("attempting to connect producer `%v`\n", pr)
err = pr.(Reconnectable).Reconnect(ctx)
if !once { if !once {
time.Sleep(time.Second) time.Sleep(time.Second)
} }

View File

@ -5,3 +5,4 @@ import "errors"
var ErrBadChecksum = errors.New("the packet had a bad checksum") var ErrBadChecksum = errors.New("the packet had a bad checksum")
var ErrDeadConnection = errors.New("the connection is dead") var ErrDeadConnection = errors.New("the connection is dead")
var ErrNotEnoughBytes = errors.New("not enough bytes") var ErrNotEnoughBytes = errors.New("not enough bytes")
var ErrBadExchange = errors.New("bad exchange")

View File

@ -0,0 +1,119 @@
package newreno
import (
"context"
"encoding/binary"
"math/rand"
"mpbl3p/shared"
"time"
)
func (c *NewReno) Initial(ctx context.Context) (out []byte, err error) {
c.alive = false
c.wasInitial = true
c.startSequenceLoop(ctx)
var s uint32
select {
case s = <-c.sequence:
case <-ctx.Done():
return nil, ctx.Err()
}
b := make([]byte, 12)
binary.LittleEndian.PutUint32(b[8:12], s)
c.inFlight = []flightInfo{{time.Now(), s}}
return b, nil
}
func (c *NewReno) Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error) {
if c.alive || c.stopSequence == nil {
// reset
c.alive = false
c.startSequenceLoop(ctx)
}
// receive
if len(in) != 12 {
return nil, nil, shared.ErrBadExchange
}
rcvAck := binary.LittleEndian.Uint32(in[0:4])
rcvNack := binary.LittleEndian.Uint32(in[4:8])
rcvSeq := binary.LittleEndian.Uint32(in[8:12])
// verify
if rcvNack != 0 {
return nil, nil, shared.ErrBadExchange
}
var seq uint32
if c.wasInitial {
if rcvAck == c.inFlight[0].sequence {
c.ack, c.lastAck = rcvSeq, rcvSeq
c.alive, c.inFlight = true, nil
} else {
return nil, nil, shared.ErrBadExchange
}
} else { // if !c.wasInitial
if rcvAck == 0 {
// theirs is a syn packet
c.ack, c.lastAck = rcvSeq, rcvSeq
select {
case seq = <-c.sequence:
case <-ctx.Done():
return nil, nil, ctx.Err()
}
c.inFlight = []flightInfo{{time.Now(), seq}}
} else if len(c.inFlight) == 1 && rcvAck == c.inFlight[0].sequence {
c.alive, c.inFlight = true, nil
} else {
return nil, nil, shared.ErrBadExchange
}
}
// respond
b := make([]byte, 12)
binary.LittleEndian.PutUint32(b[0:4], c.ack)
binary.LittleEndian.PutUint32(b[8:12], seq)
return b, nil, nil
}
func (c *NewReno) Complete() bool {
return c.alive
}
func (c *NewReno) startSequenceLoop(ctx context.Context) {
if c.stopSequence != nil {
c.stopSequence()
}
var s uint32
for s == 0 {
s = rand.Uint32()
}
ctx, c.stopSequence = context.WithCancel(ctx)
go func() {
s := s
for {
if s == 0 {
s++
continue
}
select {
case c.sequence <- s:
case <-ctx.Done():
return
}
s++
}
}()
}

View File

@ -0,0 +1,57 @@
package newreno
import (
"context"
"encoding/binary"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestNewReno_InitialHandle(t *testing.T) {
t.Run("InitialAckNackAreZero", func(t *testing.T) {
// ASSIGN
a := NewNewReno()
ctx := context.Background()
// ACT
initial, err := a.Initial(ctx)
ack := binary.LittleEndian.Uint32(initial[0:4])
nack := binary.LittleEndian.Uint32(initial[4:8])
// ASSERT
require.Nil(t, err)
assert.Zero(t, ack)
assert.Zero(t, nack)
})
t.Run("InitialHandledWithAck", func(t *testing.T) {
// ASSIGN
a := NewNewReno()
b := NewNewReno()
ctx := context.Background()
// ACT
initial, err := a.Initial(ctx)
require.Nil(t, err)
initialSeq := binary.LittleEndian.Uint32(initial[8:12])
response, data, err := b.Handle(ctx, initial)
ack := binary.LittleEndian.Uint32(response[0:4])
nack := binary.LittleEndian.Uint32(response[4:8])
// ASSERT
require.Nil(t, err)
assert.Equal(t, initialSeq, ack)
assert.Zero(t, nack)
assert.Nil(t, data)
})
}

View File

@ -1,4 +1,4 @@
package congestion package newreno
import ( import (
"context" "context"
@ -11,10 +11,12 @@ import (
) )
const RttExponentialFactor = 0.1 const RttExponentialFactor = 0.1
const RttLossDelay = 1.5 const RttLossDelay = 0.5
type NewReno struct { type NewReno struct {
sequence chan uint32 sequence chan uint32
stopSequence context.CancelFunc
wasInitial, alive bool
inFlight []flightInfo inFlight []flightInfo
lastSent time.Time lastSent time.Time
@ -64,19 +66,6 @@ func NewNewReno() *NewReno {
slowStart: true, slowStart: true,
} }
go func() {
var s uint32
for {
if s == 0 {
s++
continue
}
c.sequence <- s
s++
}
}()
return &c return &c
} }

View File

@ -1,4 +1,4 @@
package congestion package newreno
import ( import (
"context" "context"
@ -21,8 +21,8 @@ type newRenoTest struct {
halfRtt time.Duration halfRtt time.Duration
} }
func newNewRenoTest(rtt time.Duration) *newRenoTest { func newNewRenoTest(ctx context.Context, rtt time.Duration) *newRenoTest {
return &newRenoTest{ nr := &newRenoTest{
sideA: NewNewReno(), sideA: NewNewReno(),
sideB: NewNewReno(), sideB: NewNewReno(),
@ -34,6 +34,15 @@ func newNewRenoTest(rtt time.Duration) *newRenoTest {
halfRtt: rtt / 2, halfRtt: rtt / 2,
} }
p, _ := nr.sideA.Initial(ctx)
p, _, _ = nr.sideB.Handle(ctx, p)
p, _, _ = nr.sideA.Handle(ctx, p)
nr.sideB.ReceivedPacket(0, nr.sideA.NextAck(), nr.sideA.NextNack())
nr.sideA.ReceivedPacket(0, nr.sideB.NextAck(), nr.sideB.NextNack())
return nr
} }
func (n *newRenoTest) Start(ctx context.Context) { func (n *newRenoTest) Start(ctx context.Context) {
@ -151,11 +160,14 @@ func TestNewReno_Congestion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
c := newNewRenoTest(rtt) c := newNewRenoTest(ctx, rtt)
c.Start(ctx) c.Start(ctx)
c.RunSideA(ctx) c.RunSideA(ctx)
c.RunSideB(ctx) c.RunSideB(ctx)
sideAinitialAck := c.sideA.ack
sideBinitialAck := c.sideB.ack
// ACT // ACT
for i := 0; i < numPackets; i++ { for i := 0; i < numPackets; i++ {
// sleep to simulate preparing packet // sleep to simulate preparing packet
@ -175,10 +187,10 @@ func TestNewReno_Congestion(t *testing.T) {
// ASSERT // ASSERT
assert.Equal(t, uint32(0), c.sideA.nack) assert.Equal(t, uint32(0), c.sideA.nack)
assert.Equal(t, uint32(0), c.sideA.ack) assert.Equal(t, sideAinitialAck, c.sideA.ack)
assert.Equal(t, uint32(0), c.sideB.nack) assert.Equal(t, uint32(0), c.sideB.nack)
assert.Equal(t, uint32(numPackets), c.sideB.ack) assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
}) })
t.Run("SequenceLoss", func(t *testing.T) { t.Run("SequenceLoss", func(t *testing.T) {
@ -189,18 +201,21 @@ func TestNewReno_Congestion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
c := newNewRenoTest(rtt) c := newNewRenoTest(ctx, rtt)
c.Start(ctx) c.Start(ctx)
c.RunSideA(ctx) c.RunSideA(ctx)
c.RunSideB(ctx) c.RunSideB(ctx)
sideAinitialAck := c.sideA.ack
sideBinitialAck := c.sideB.ack
// ACT // ACT
for i := 0; i < numPackets; i++ { for i := 1; i <= numPackets; i++ {
// sleep to simulate preparing packet // sleep to simulate preparing packet
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq, _ := c.sideA.Sequence(ctx) seq, _ := c.sideA.Sequence(ctx)
if seq == 20 { if i == 20 {
// Simulate packet loss of sequence 20 // Simulate packet loss of sequence 20
continue continue
} }
@ -217,10 +232,10 @@ func TestNewReno_Congestion(t *testing.T) {
// ASSERT // ASSERT
assert.Equal(t, uint32(0), c.sideA.nack) assert.Equal(t, uint32(0), c.sideA.nack)
assert.Equal(t, uint32(0), c.sideA.ack) assert.Equal(t, sideAinitialAck, c.sideA.ack)
assert.Equal(t, uint32(20), c.sideB.nack) assert.Equal(t, sideBinitialAck+uint32(20), c.sideB.nack)
assert.Equal(t, uint32(numPackets), c.sideB.ack) assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
}) })
}) })
@ -233,16 +248,19 @@ func TestNewReno_Congestion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
c := newNewRenoTest(rtt) c := newNewRenoTest(ctx, rtt)
c.Start(ctx) c.Start(ctx)
c.RunSideA(ctx) c.RunSideA(ctx)
c.RunSideB(ctx) c.RunSideB(ctx)
sideAinitialAck := c.sideA.ack
sideBinitialAck := c.sideB.ack
// ACT // ACT
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
for i := 0; i < numPackets; i++ { for i := 1; i <= numPackets; i++ {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq, _ := c.sideA.Sequence(ctx) seq, _ := c.sideA.Sequence(ctx)
@ -257,7 +275,7 @@ func TestNewReno_Congestion(t *testing.T) {
}() }()
go func() { go func() {
for i := 0; i < numPackets; i++ { for i := 1; i <= numPackets; i++ {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq, _ := c.sideB.Sequence(ctx) seq, _ := c.sideB.Sequence(ctx)
@ -279,10 +297,10 @@ func TestNewReno_Congestion(t *testing.T) {
// ASSERT // ASSERT
assert.Equal(t, uint32(0), c.sideA.nack) assert.Equal(t, uint32(0), c.sideA.nack)
assert.Equal(t, uint32(numPackets), c.sideA.ack) assert.Equal(t, sideAinitialAck+uint32(numPackets), c.sideA.ack)
assert.Equal(t, uint32(0), c.sideB.nack) assert.Equal(t, uint32(0), c.sideB.nack)
assert.Equal(t, uint32(numPackets), c.sideB.ack) assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
}) })
t.Run("SequenceLoss", func(t *testing.T) { t.Run("SequenceLoss", func(t *testing.T) {
@ -293,20 +311,23 @@ func TestNewReno_Congestion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
c := newNewRenoTest(rtt) c := newNewRenoTest(ctx, rtt)
c.Start(ctx) c.Start(ctx)
c.RunSideA(ctx) c.RunSideA(ctx)
c.RunSideB(ctx) c.RunSideB(ctx)
sideAinitialAck := c.sideA.ack
sideBinitialAck := c.sideB.ack
// ACT // ACT
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
for i := 0; i < numPackets; i++ { for i := 1; i <= numPackets; i++ {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq, _ := c.sideA.Sequence(ctx) seq, _ := c.sideA.Sequence(ctx)
if seq == 9 { if i == 9 {
// Simulate packet loss of sequence 9 // Simulate packet loss of sequence 9
continue continue
} }
@ -322,11 +343,11 @@ func TestNewReno_Congestion(t *testing.T) {
}() }()
go func() { go func() {
for i := 0; i < numPackets; i++ { for i := 1; i <= numPackets; i++ {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq, _ := c.sideB.Sequence(ctx) seq, _ := c.sideB.Sequence(ctx)
if seq == 13 { if i == 13 {
// Simulate packet loss of sequence 13 // Simulate packet loss of sequence 13
continue continue
} }
@ -348,11 +369,11 @@ func TestNewReno_Congestion(t *testing.T) {
// ASSERT // ASSERT
assert.Equal(t, uint32(13), c.sideA.nack) assert.Equal(t, sideAinitialAck+uint32(13), c.sideA.nack)
assert.Equal(t, uint32(numPackets), c.sideA.ack) assert.Equal(t, sideAinitialAck+uint32(numPackets), c.sideA.ack)
assert.Equal(t, uint32(9), c.sideB.nack) assert.Equal(t, sideBinitialAck+uint32(9), c.sideB.nack)
assert.Equal(t, uint32(numPackets), c.sideB.ack) assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
}) })
}) })
} }

View File

@ -7,7 +7,6 @@ import (
"mpbl3p/proxy" "mpbl3p/proxy"
"mpbl3p/shared" "mpbl3p/shared"
"net" "net"
"sync"
"time" "time"
) )
@ -19,29 +18,15 @@ type PacketWriter interface {
type PacketConn interface { type PacketConn interface {
PacketWriter PacketWriter
SetReadDeadline(t time.Time) error
ReadFromUDP(b []byte) (int, *net.UDPAddr, error) ReadFromUDP(b []byte) (int, *net.UDPAddr, error)
} }
type InitiatedFlow struct {
Local func() string
Remote string
keepalive time.Duration
mu sync.RWMutex
Flow
}
func (f *InitiatedFlow) String() string {
return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote)
}
type Flow struct { type Flow struct {
writer PacketWriter writer PacketWriter
raddr *net.UDPAddr raddr *net.UDPAddr
isAlive bool isAlive bool
startup bool
congestion Congestion congestion Congestion
verifiers []proxy.MacVerifier verifiers []proxy.MacVerifier
@ -54,24 +39,6 @@ func (f Flow) String() string {
return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr()) return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr())
} }
func InitiateFlow(
local func() string,
remote string,
vs []proxy.MacVerifier,
gs []proxy.MacGenerator,
c Congestion,
keepalive time.Duration,
) (*InitiatedFlow, error) {
f := InitiatedFlow{
Local: local,
Remote: remote,
Flow: newFlow(c, vs, gs),
keepalive: keepalive,
}
return &f, nil
}
func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow { func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow {
return Flow{ return Flow{
inboundDatagrams: make(chan []byte), inboundDatagrams: make(chan []byte),
@ -81,102 +48,6 @@ func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow
} }
} }
func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
f.mu.Lock()
defer f.mu.Unlock()
if f.isAlive {
return nil
}
localAddr, err := net.ResolveUDPAddr("udp", f.Local())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote)
if err != nil {
return err
}
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
if err != nil {
return err
}
f.writer = conn
f.startup = true
// prod the connection once a second until we get an ack, then consider it alive
go func() {
seq, err := f.congestion.Sequence(ctx)
if err != nil {
return
}
for !f.isAlive {
if ctx.Err() != nil {
return
}
p := Packet{
ack: 0,
nack: 0,
seq: seq,
data: proxy.SimplePacket(nil),
}
_ = f.sendPacket(p)
time.Sleep(1 * time.Second)
}
}()
go func() {
_, _ = f.produceInternal(ctx, false)
}()
go f.earlyUpdateLoop(ctx, f.keepalive)
if err := f.readQueuePacket(ctx, conn); err != nil {
return err
}
f.isAlive = true
f.startup = false
go func() {
lockedAccept := func() {
f.mu.RLock()
defer f.mu.RUnlock()
if err := f.readQueuePacket(ctx, conn); err != nil {
log.Println(err)
}
}
for f.isAlive {
log.Println("alive and listening for packets")
lockedAccept()
}
log.Println("no longer alive")
}()
return nil
}
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet) error {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Consume(ctx, p)
}
func (f *InitiatedFlow) Produce(ctx context.Context) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(ctx)
}
func (f *Flow) IsAlive() bool { func (f *Flow) IsAlive() bool {
return f.isAlive return f.isAlive
} }
@ -265,7 +136,7 @@ func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
} }
} }
func (f *Flow) sendPacket(p Packet) error { func (f *Flow) sendPacket(p proxy.Packet) error {
b := p.Marshal() b := p.Marshal()
for _, g := range f.generators { for _, g := range f.generators {
@ -302,13 +173,24 @@ func (f *Flow) earlyUpdateLoop(ctx context.Context, keepalive time.Duration) {
} }
} }
func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error { func (f *Flow) readPacket(ctx context.Context, c PacketConn) ([]byte, error) {
// TODO: Replace 6000 with MTU+header size
buf := make([]byte, 6000) buf := make([]byte, 6000)
n, _, err := c.ReadFromUDP(buf)
if err != nil { if d, ok := ctx.Deadline(); ok {
return err if err := c.SetReadDeadline(d); err != nil {
return nil, err
}
} }
return f.queueDatagram(ctx, buf[:n]) n, _, err := c.ReadFromUDP(buf)
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
if ctx.Err() != nil {
return nil, ctx.Err()
}
}
return nil, err
}
return buf[:n], nil
} }

View File

@ -107,7 +107,9 @@ func TestFlow_Produce(t *testing.T) {
flowA.isAlive = true flowA.isAlive = true
go func() { go func() {
err := flowA.readQueuePacket(context.Background(), testConn.SideB()) p, err := flowA.readPacket(context.Background(), testConn.SideB())
assert.Nil(t, err)
err = flowA.queueDatagram(context.Background(), p)
assert.Nil(t, err) assert.Nil(t, err)
}() }()
p, err := flowA.Produce(context.Background()) p, err := flowA.Produce(context.Background())
@ -143,7 +145,9 @@ func TestFlow_Produce(t *testing.T) {
flowA.isAlive = true flowA.isAlive = true
go func() { go func() {
err := flowA.readQueuePacket(context.Background(), testConn.SideB()) p, err := flowA.readPacket(context.Background(), testConn.SideB())
assert.Nil(t, err)
err = flowA.queueDatagram(context.Background(), p)
assert.Nil(t, err) assert.Nil(t, err)
}() }()
p, err := flowA.Produce(context.Background()) p, err := flowA.Produce(context.Background())

142
udp/inbound_flow.go Normal file
View File

@ -0,0 +1,142 @@
package udp
import (
"context"
"log"
"mpbl3p/proxy"
"sync"
"time"
)
type InboundFlow struct {
inboundDatagrams chan []byte
mu sync.RWMutex
Flow
}
func newInboundFlow(f Flow) (*InboundFlow, error) {
fi := InboundFlow{
inboundDatagrams: make(chan []byte),
Flow: f,
}
return &fi, nil
}
func (f *InboundFlow) queueDatagram(ctx context.Context, p []byte) error {
select {
case f.inboundDatagrams <- p:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (f *InboundFlow) processPackets(ctx context.Context) {
for {
f.mu.Lock()
var err error
for once := true; err != nil || once; once = false {
if ctx.Err() != nil {
return
}
err = f.handleExchanges(ctx)
if err != nil {
log.Println(err)
}
}
f.mu.Unlock()
var p []byte
select {
case p = <-f.inboundDatagrams:
case <-ctx.Done():
return
}
// TODO: Check if p means redo exchanges
if false {
continue
}
select {
case f.Flow.inboundDatagrams <- p:
case <-ctx.Done():
return
}
}
}
func (f *InboundFlow) handleExchanges(ctx context.Context) error {
var exchanges []proxy.Exchange
if e, ok := f.congestion.(proxy.Exchange); ok {
exchanges = append(exchanges, e)
}
var exchangeData [][]byte
for _, e := range exchanges {
for once := true; !e.Complete() || once; once = false {
if err := func() (err error) {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
var recv []byte
select {
case recv = <-f.inboundDatagrams:
case <-ctx.Done():
return ctx.Err()
}
for i := range f.verifiers {
v := f.verifiers[len(f.verifiers)-i-1]
recv, err = proxy.StripMac(recv, v)
if err != nil {
return err
}
}
var resp, data []byte
if resp, data, err = e.Handle(ctx, recv); err != nil {
return err
}
if data != nil {
exchangeData = append(exchangeData, data)
}
if resp != nil {
if err = f.sendPacket(proxy.SimplePacket(resp)); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
}
return nil
}
func (f *InboundFlow) Consume(ctx context.Context, p proxy.Packet) error {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Consume(ctx, p)
}
func (f *InboundFlow) Produce(ctx context.Context) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(ctx)
}

View File

@ -16,9 +16,7 @@ type ComparableUdpAddress struct {
func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress { func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
var ip [16]byte var ip [16]byte
for i, b := range []byte(address.IP) { copy(ip[:], address.IP)
ip[i] = b
}
return ComparableUdpAddress{ return ComparableUdpAddress{
IP: ip, IP: ip,
@ -47,12 +45,7 @@ func NewListener(
return err return err
} }
err = pconn.SetWriteBuffer(0) receivedConnections := make(map[ComparableUdpAddress]*InboundFlow)
if err != nil {
panic(err)
}
receivedConnections := make(map[ComparableUdpAddress]*Flow)
go func() { go func() {
for ctx.Err() == nil { for ctx.Err() == nil {
@ -70,10 +63,11 @@ func NewListener(
} }
raddr := fromUdpAddress(*addr) raddr := fromUdpAddress(*addr)
if f, exists := receivedConnections[raddr]; exists { if fi, exists := receivedConnections[raddr]; exists {
log.Println("existing flow. queuing...") log.Println("existing flow. queuing...")
if err := f.queueDatagram(ctx, buf[:n]); err != nil { if err := fi.queueDatagram(ctx, buf[:n]); err != nil {
log.Println("error")
continue
} }
log.Println("queued") log.Println("queued")
continue continue
@ -94,21 +88,28 @@ func NewListener(
f.raddr = addr f.raddr = addr
f.isAlive = true f.isAlive = true
fi, err := newInboundFlow(f)
if err != nil {
log.Println(err)
continue
}
log.Printf("received new udp connection: %v\n", f) log.Printf("received new udp connection: %v\n", f)
go f.earlyUpdateLoop(ctx, 0) go fi.processPackets(ctx)
go fi.earlyUpdateLoop(ctx, 0)
receivedConnections[raddr] = &f receivedConnections[raddr] = fi
if enableConsumers { if enableConsumers {
p.AddConsumer(ctx, &f) p.AddConsumer(ctx, fi)
} }
if enableProducers { if enableProducers {
p.AddProducer(ctx, &f) p.AddProducer(ctx, fi)
} }
log.Println("handling...") log.Println("handling...")
if err := f.queueDatagram(ctx, buf[:n]); err != nil { if err := fi.queueDatagram(ctx, buf[:n]); err != nil {
return return
} }
log.Println("handled") log.Println("handled")

181
udp/outbound_flow.go Normal file
View File

@ -0,0 +1,181 @@
package udp
import (
"context"
"errors"
"fmt"
"log"
"mpbl3p/proxy"
"net"
"sync"
"time"
)
type OutboundFlow struct {
Local func() string
Remote string
g proxy.MacGenerator
keepalive time.Duration
mu sync.RWMutex
Flow
}
func InitiateFlow(
local func() string,
remote string,
vs []proxy.MacVerifier,
gs []proxy.MacGenerator,
c Congestion,
keepalive time.Duration,
) (*OutboundFlow, error) {
f := OutboundFlow{
Local: local,
Remote: remote,
Flow: newFlow(c, vs, gs),
keepalive: keepalive,
}
return &f, nil
}
func (f *OutboundFlow) String() string {
return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote)
}
func (f *OutboundFlow) Reconnect(ctx context.Context) error {
f.mu.Lock()
defer f.mu.Unlock()
if f.isAlive {
return nil
}
localAddr, err := net.ResolveUDPAddr("udp", f.Local())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote)
if err != nil {
return err
}
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
if err != nil {
return err
}
f.writer = conn
// prod the connection once a second until we get an ack, then consider it alive
var exchanges []proxy.Exchange
if e, ok := f.congestion.(proxy.Exchange); ok {
exchanges = append(exchanges, e)
}
var exchangeData [][]byte
for _, e := range exchanges {
i, err := e.Initial(ctx)
if err != nil {
return err
}
if err = f.sendPacket(proxy.SimplePacket(i)); err != nil {
return err
}
for once := true; !e.Complete() || once; once = false {
if err := func() error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
var recv []byte
if recv, err = f.readPacket(ctx, conn); err != nil {
return err
}
for i := range f.verifiers {
v := f.verifiers[len(f.verifiers)-i-1]
recv, err = proxy.StripMac(recv, v)
if err != nil {
return err
}
}
var resp, data []byte
if resp, data, err = e.Handle(ctx, recv); err != nil {
return err
}
if data != nil {
exchangeData = append(exchangeData, data)
}
if resp != nil {
if err = f.sendPacket(proxy.SimplePacket(resp)); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
}
go func() {
for _, d := range exchangeData {
if err := f.queueDatagram(ctx, d); err != nil {
return
}
}
lockedAccept := func() {
f.mu.RLock()
defer f.mu.RUnlock()
var p []byte
if p, err = f.readPacket(ctx, conn); err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
return
}
log.Println(err)
return
}
if err := f.queueDatagram(ctx, p); err != nil {
return
}
}
for f.isAlive {
log.Println("alive and listening for packets")
lockedAccept()
}
log.Println("no longer alive")
}()
f.isAlive = true
return nil
}
func (f *OutboundFlow) Consume(ctx context.Context, p proxy.Packet) error {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Consume(ctx, p)
}
func (f *OutboundFlow) Produce(ctx context.Context) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(ctx)
}