From fad829803aa52b96dc3724ae9bd0c6d1ec8ae3ba Mon Sep 17 00:00:00 2001 From: Jake Hillion Date: Tue, 30 Mar 2021 20:57:53 +0100 Subject: [PATCH] initial context propagation --- config/builder.go | 23 ++++----- main.go | 6 ++- proxy/proxy.go | 57 +++++++++++++++++----- tcp/flow.go | 22 ++++++--- tcp/flow_test.go | 7 +-- tcp/listener.go | 7 +-- udp/congestion.go | 9 ++-- udp/congestion/newreno.go | 19 +++++--- udp/congestion/newreno_test.go | 22 ++++----- udp/congestion/none.go | 43 ++++++----------- udp/flow.go | 87 ++++++++++++++++++++++------------ udp/flow_test.go | 7 +-- udp/listener.go | 20 ++++---- 13 files changed, 200 insertions(+), 129 deletions(-) diff --git a/config/builder.go b/config/builder.go index 318073a..b83995c 100644 --- a/config/builder.go +++ b/config/builder.go @@ -1,6 +1,7 @@ package config import ( + "context" "encoding/base64" "fmt" "mpbl3p/crypto" @@ -12,7 +13,7 @@ import ( "time" ) -func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) { +func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) { p := proxy.NewProxy(0) var g func() proxy.MacGenerator @@ -46,11 +47,11 @@ func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy for _, peer := range c.Peers { switch peer.Method { case "TCP": - if err := buildTcp(p, peer, g, v); err != nil { + if err := buildTcp(ctx, p, peer, g, v); err != nil { return nil, err } case "UDP": - if err := buildUdp(p, peer, g, v); err != nil { + if err := buildUdp(ctx, p, peer, g, v); err != nil { return nil, err } } @@ -59,7 +60,7 @@ func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy return p, nil } -func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { +func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { var laddr func() string if peer.LocalPort == 0 { laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) } @@ -74,13 +75,13 @@ func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p return err } - p.AddConsumer(f, g()) - p.AddProducer(f, v()) + p.AddConsumer(ctx, f, g()) + p.AddProducer(ctx, f, v()) return nil } - err := tcp.NewListener(p, laddr(), v, g) + err := tcp.NewListener(ctx, p, laddr(), v, g) if err != nil { return err } @@ -88,7 +89,7 @@ func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p return nil } -func buildUdp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { +func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { var laddr func() string if peer.LocalPort == 0 { laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) } @@ -120,13 +121,13 @@ func buildUdp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p return err } - p.AddConsumer(f, g()) - p.AddProducer(f, v()) + p.AddConsumer(ctx, f, g()) + p.AddProducer(ctx, f, v()) return nil } - err := udp.NewListener(p, laddr(), v, g, c) + err := udp.NewListener(ctx, p, laddr(), v, g, c) if err != nil { return err } diff --git a/main.go b/main.go index c995e04..0c31e25 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "log" @@ -132,7 +133,10 @@ FOREGROUND: }() log.Println("building config...") - p, err := c.Build(t, t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p, err := c.Build(ctx, t, t) if err != nil { panic(err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index a82698b..593d65b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1,22 +1,24 @@ package proxy import ( + "context" + "errors" "log" "time" ) type Producer interface { IsAlive() bool - Produce(MacVerifier) (Packet, error) + Produce(context.Context, MacVerifier) (Packet, error) } type Consumer interface { IsAlive() bool - Consume(Packet, MacGenerator) error + Consume(context.Context, Packet, MacGenerator) error } type Reconnectable interface { - Reconnect() error + Reconnect(context.Context) error } type Source interface { @@ -65,7 +67,7 @@ func (p Proxy) Start() { }() } -func (p Proxy) AddConsumer(c Consumer, g MacGenerator) { +func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) { go func() { _, reconnectable := c.(Reconnectable) @@ -74,7 +76,11 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) { var err error for once := true; err != nil || once; once = false { log.Printf("attempting to connect consumer `%v`\n", c) - err = c.(Reconnectable).Reconnect() + err = c.(Reconnectable).Reconnect(ctx) + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("closed consumer `%v` (context)\n", c) + return + } if !once { time.Sleep(time.Second) } @@ -83,9 +89,19 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) { } for c.IsAlive() { - if err := c.Consume(<-p.proxyChan, g); err != nil { - log.Println(err) - break + select { + case <-ctx.Done(): + log.Printf("closed consumer `%v` (context)\n", c) + return + case packet := <-p.proxyChan: + if err := c.Consume(ctx, packet, g); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("closed consumer `%v` (context)\n", c) + return + } + log.Println(err) + break + } } } } @@ -94,7 +110,7 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) { }() } -func (p Proxy) AddProducer(pr Producer, v MacVerifier) { +func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) { go func() { _, reconnectable := pr.(Reconnectable) @@ -103,20 +119,37 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) { var err error for once := true; err != nil || once; once = false { log.Printf("attempting to connect producer `%v`\n", pr) - err = pr.(Reconnectable).Reconnect() + err = pr.(Reconnectable).Reconnect(ctx) + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("closed producer `%v` (context)\n", pr) + return + } if !once { time.Sleep(time.Second) } + if ctx.Err() != nil { + return + } + } log.Printf("connected producer `%v`\n", pr) } for pr.IsAlive() { - if packet, err := pr.Produce(v); err != nil { + if packet, err := pr.Produce(ctx, v); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("closed producer `%v` (context)\n", pr) + return + } log.Println(err) break } else { - p.sinkChan <- packet + select { + case <-ctx.Done(): + log.Printf("closed producer `%v` (context)\n", pr) + return + case p.sinkChan <- packet: + } } } } diff --git a/tcp/flow.go b/tcp/flow.go index b24a353..cf59fde 100644 --- a/tcp/flow.go +++ b/tcp/flow.go @@ -2,6 +2,7 @@ package tcp import ( "bufio" + "context" "encoding/binary" "fmt" "io" @@ -124,21 +125,21 @@ func (f *InitiatedFlow) Reconnect() error { return nil } -func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { +func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Consume(p, g) + return f.Flow.Consume(ctx, p, g) } -func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { +func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Produce(v) + return f.Flow.Produce(ctx, v) } -func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error { +func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { if !f.isAlive { return shared.ErrDeadConnection } @@ -157,11 +158,16 @@ func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error { binary.LittleEndian.PutUint32(prefixedData, uint32(len(data))) copy(prefixedData[4:], data) - f.toConsume <- prefixedData + select { + case f.toConsume <- prefixedData: + case <-ctx.Done(): + return ctx.Err() + } + return nil } -func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { +func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { if !f.isAlive { return nil, shared.ErrDeadConnection } @@ -169,6 +175,8 @@ func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { var data []byte select { + case <-ctx.Done(): + return nil, ctx.Err() case data = <-f.produced: case err := <-f.produceErrors: f.isAlive = false diff --git a/tcp/flow_test.go b/tcp/flow_test.go index f0f3e21..3d0b49c 100644 --- a/tcp/flow_test.go +++ b/tcp/flow_test.go @@ -1,6 +1,7 @@ package tcp import ( + "context" "encoding/binary" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,7 +20,7 @@ func TestFlow_Consume(t *testing.T) { flowA := NewFlowConn(testConn.SideA()) - err := flowA.Consume(testPacket, testMac) + err := flowA.Consume(context.Background(), testPacket, testMac) require.Nil(t, err) buf := make([]byte, 100) @@ -46,7 +47,7 @@ func TestFlow_Produce(t *testing.T) { _, err := testConn.SideB().Write(testMarshalled) require.Nil(t, err) - p, err := flowA.Produce(testMac) + p, err := flowA.Produce(context.Background(), testMac) require.Nil(t, err) assert.Equal(t, len(testContent), len(p.Contents())) }) @@ -59,7 +60,7 @@ func TestFlow_Produce(t *testing.T) { _, err := testConn.SideB().Write(testMarshalled) require.Nil(t, err) - p, err := flowA.Produce(testMac) + p, err := flowA.Produce(context.Background(), testMac) require.Nil(t, err) assert.Equal(t, testContent, string(p.Contents())) }) diff --git a/tcp/listener.go b/tcp/listener.go index 7f32a3a..aa4c789 100644 --- a/tcp/listener.go +++ b/tcp/listener.go @@ -1,12 +1,13 @@ package tcp import ( + "context" "log" "mpbl3p/proxy" "net" ) -func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error { +func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error { laddr, err := net.ResolveTCPAddr("tcp", local) if err != nil { return err @@ -32,8 +33,8 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun log.Printf("received new tcp connection: %v\n", f) - p.AddConsumer(&f, g()) - p.AddProducer(&f, v()) + p.AddConsumer(ctx, &f, g()) + p.AddProducer(ctx, &f, v()) } }() diff --git a/udp/congestion.go b/udp/congestion.go index ea2f7bb..239645b 100644 --- a/udp/congestion.go +++ b/udp/congestion.go @@ -1,13 +1,16 @@ package udp -import "time" +import ( + "context" + "time" +) type Congestion interface { - Sequence() uint32 + Sequence(ctx context.Context) (uint32, error) NextAck() uint32 NextNack() uint32 ReceivedPacket(seq, nack, ack uint32) - AwaitEarlyUpdate(keepalive time.Duration) uint32 + AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error) } diff --git a/udp/congestion/newreno.go b/udp/congestion/newreno.go index 4e94a41..d796e11 100644 --- a/udp/congestion/newreno.go +++ b/udp/congestion/newreno.go @@ -1,6 +1,7 @@ package congestion import ( + "context" "fmt" "math" "sort" @@ -94,7 +95,7 @@ func (c *NewReno) ReceivedPacket(seq, nack, ack uint32) { } } -func (c *NewReno) Sequence() uint32 { +func (c *NewReno) Sequence(ctx context.Context) (uint32, error) { for len(c.inFlight) >= int(c.windowSize) { <-c.windowNotifier } @@ -102,7 +103,13 @@ func (c *NewReno) Sequence() uint32 { c.inFlightMu.Lock() defer c.inFlightMu.Unlock() - s := <-c.sequence + var s uint32 + select { + case s = <-c.sequence: + case <-ctx.Done(): + return 0, ctx.Err() + } + t := time.Now() c.inFlight = append(c.inFlight, flightInfo{ @@ -111,7 +118,7 @@ func (c *NewReno) Sequence() uint32 { }) c.lastSent = t - return s + return s, nil } func (c *NewReno) NextAck() uint32 { @@ -126,7 +133,7 @@ func (c *NewReno) NextNack() uint32 { return n } -func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 { +func (c *NewReno) AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error) { for { rtt := time.Duration(math.Round(c.rttNanos)) time.Sleep(rtt / 2) @@ -136,12 +143,12 @@ func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 { // CASE 1: waiting ACKs or NACKs and no message sent in the last half-RTT // this targets arrival in 0.5+0.5 ± 0.5 RTTs (1±0.5 RTTs) if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt/2)) { - return 0 // no ack needed + return 0, nil // no ack needed } // CASE 2: No message sent within the keepalive time if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) { - return c.Sequence() // require an ack + return c.Sequence(ctx) // require an ack } } } diff --git a/udp/congestion/newreno_test.go b/udp/congestion/newreno_test.go index ed3b235..f9f572a 100644 --- a/udp/congestion/newreno_test.go +++ b/udp/congestion/newreno_test.go @@ -89,11 +89,10 @@ func (n *newRenoTest) RunSideA(ctx context.Context) { go func() { for { - if ctx.Err() != nil { + seq, err := n.sideA.AwaitEarlyUpdate(ctx, 500 * time.Millisecond) + if err != nil { return } - - seq := n.sideA.AwaitEarlyUpdate(500 * time.Millisecond) if seq != 0 { // skip keepalive // required to ensure AwaitEarlyUpdate terminates @@ -123,11 +122,10 @@ func (n *newRenoTest) RunSideB(ctx context.Context) { go func() { for { - if ctx.Err() != nil { + seq, err := n.sideB.AwaitEarlyUpdate(ctx, 500 * time.Millisecond) + if err != nil { return } - - seq := n.sideB.AwaitEarlyUpdate(500 * time.Millisecond) if seq != 0 { // skip keepalive // required to ensure AwaitEarlyUpdate terminates @@ -162,7 +160,7 @@ func TestNewReno_Congestion(t *testing.T) { for i := 0; i < numPackets; i++ { // sleep to simulate preparing packet time.Sleep(1 * time.Millisecond) - seq := c.sideA.Sequence() + seq, _ := c.sideA.Sequence(ctx) c.aOutbound <- congestionPacket{ seq: seq, @@ -200,7 +198,7 @@ func TestNewReno_Congestion(t *testing.T) { for i := 0; i < numPackets; i++ { // sleep to simulate preparing packet time.Sleep(1 * time.Millisecond) - seq := c.sideA.Sequence() + seq, _ := c.sideA.Sequence(ctx) if seq == 20 { // Simulate packet loss of sequence 20 @@ -246,7 +244,7 @@ func TestNewReno_Congestion(t *testing.T) { go func() { for i := 0; i < numPackets; i++ { time.Sleep(1 * time.Millisecond) - seq := c.sideA.Sequence() + seq, _ := c.sideA.Sequence(ctx) c.aOutbound <- congestionPacket{ seq: seq, @@ -261,7 +259,7 @@ func TestNewReno_Congestion(t *testing.T) { go func() { for i := 0; i < numPackets; i++ { time.Sleep(1 * time.Millisecond) - seq := c.sideB.Sequence() + seq, _ := c.sideB.Sequence(ctx) c.bOutbound <- congestionPacket{ seq: seq, @@ -306,7 +304,7 @@ func TestNewReno_Congestion(t *testing.T) { go func() { for i := 0; i < numPackets; i++ { time.Sleep(1 * time.Millisecond) - seq := c.sideA.Sequence() + seq, _ := c.sideA.Sequence(ctx) if seq == 9 { // Simulate packet loss of sequence 9 @@ -326,7 +324,7 @@ func TestNewReno_Congestion(t *testing.T) { go func() { for i := 0; i < numPackets; i++ { time.Sleep(1 * time.Millisecond) - seq := c.sideB.Sequence() + seq, _ := c.sideB.Sequence(ctx) if seq == 13 { // Simulate packet loss of sequence 13 diff --git a/udp/congestion/none.go b/udp/congestion/none.go index be7564f..b7ef6a6 100644 --- a/udp/congestion/none.go +++ b/udp/congestion/none.go @@ -1,41 +1,26 @@ package congestion import ( + "context" "fmt" "time" ) -type None struct { - sequence chan uint32 +type None struct {} + +func NewNone() None { + return None{} } -func NewNone() *None { - c := None{ - sequence: make(chan uint32), - } - - go func() { - var s uint32 - for { - if s == 0 { - s++ - continue - } - - c.sequence <- s - s++ - } - }() - - return &c -} - -func (c *None) String() string { +func (c None) String() string { return fmt.Sprintf("{None}") } -func (c *None) ReceivedPacket(uint32, uint32, uint32) {} -func (c *None) NextNack() uint32 { return 0 } -func (c *None) NextAck() uint32 { return 0 } -func (c *None) AwaitEarlyUpdate(time.Duration) uint32 { select {} } -func (c *None) Sequence() uint32 { return <-c.sequence } +func (c None) ReceivedPacket(uint32, uint32, uint32) {} +func (c None) NextNack() uint32 { return 0 } +func (c None) NextAck() uint32 { return 0 } +func (c None) AwaitEarlyUpdate(ctx context.Context, _ time.Duration) (uint32, error) { + <-ctx.Done() + return 0, ctx.Err() +} +func (c None) Sequence(context.Context) (uint32, error) { return 0, nil } diff --git a/udp/flow.go b/udp/flow.go index 6fd7219..81b67d3 100644 --- a/udp/flow.go +++ b/udp/flow.go @@ -1,6 +1,7 @@ package udp import ( + "context" "fmt" "log" "mpbl3p/proxy" @@ -80,7 +81,7 @@ func newFlow(c Congestion, v proxy.MacVerifier) Flow { } } -func (f *InitiatedFlow) Reconnect() error { +func (f *InitiatedFlow) Reconnect(ctx context.Context) error { f.mu.Lock() defer f.mu.Unlock() @@ -108,7 +109,10 @@ func (f *InitiatedFlow) Reconnect() error { // prod the connection once a second until we get an ack, then consider it alive go func() { - seq := f.congestion.Sequence() + seq, err := f.congestion.Sequence(ctx) + if err != nil { + + } for !f.isAlive { p := Packet{ @@ -124,11 +128,11 @@ func (f *InitiatedFlow) Reconnect() error { }() go func() { - _, _ = f.produceInternal(f.v, false) + _, _ = f.produceInternal(ctx, f.v, false) }() - go f.earlyUpdateLoop(f.g, f.keepalive) + go f.earlyUpdateLoop(ctx, f.g, f.keepalive) - if err := f.acceptPacket(conn); err != nil { + if err := f.readQueuePacket(ctx, conn); err != nil { return err } @@ -140,7 +144,7 @@ func (f *InitiatedFlow) Reconnect() error { f.mu.RLock() defer f.mu.RUnlock() - if err := f.acceptPacket(conn); err != nil { + if err := f.readQueuePacket(ctx, conn); err != nil { log.Println(err) } } @@ -155,25 +159,25 @@ func (f *InitiatedFlow) Reconnect() error { return nil } -func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { +func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Consume(p, g) + return f.Flow.Consume(ctx, p, g) } -func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { +func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Produce(v) + return f.Flow.Produce(ctx, v) } func (f *Flow) IsAlive() bool { return f.isAlive } -func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error { +func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error { if !f.isAlive { return shared.ErrDeadConnection } @@ -182,32 +186,43 @@ func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error { // Sequence is the congestion controllers opportunity to block log.Println("awaiting sequence") - p := Packet{ - seq: f.congestion.Sequence(), - data: pp, + seq, err := f.congestion.Sequence(ctx) + if err != nil { + return err } log.Println("received sequence") // Choose up to date ACK/NACK even after blocking - p.ack = f.congestion.NextAck() - p.nack = f.congestion.NextNack() + p := Packet{ + seq: seq, + data: pp, + ack: f.congestion.NextAck(), + nack: f.congestion.NextNack(), + } return f.sendPacket(p, g) } -func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { +func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { if !f.isAlive { return nil, shared.ErrDeadConnection } - return f.produceInternal(v, true) + return f.produceInternal(ctx, v, true) } -func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) { +func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) { for once := true; mustReturn || once; once = false { log.Println(f.congestion) - b, err := proxy.StripMac(<-f.inboundDatagrams, v) + var received []byte + select { + case received = <-f.inboundDatagrams: + case <-ctx.Done(): + return nil, ctx.Err() + } + + b, err := proxy.StripMac(received, v) if err != nil { return nil, err } @@ -232,8 +247,13 @@ func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Pack return nil, nil } -func (f *Flow) handleDatagram(p []byte) { - f.inboundDatagrams <- p +func (f *Flow) queueDatagram(ctx context.Context, p []byte) error { + select { + case f.inboundDatagrams <- p: + return nil + case <-ctx.Done(): + return ctx.Err() + } } func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { @@ -249,27 +269,34 @@ func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { } } -func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) { +func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) { for f.isAlive { - seq := f.congestion.AwaitEarlyUpdate(keepalive) + seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive) + if err != nil { + fmt.Printf("terminating earlyupdateloop for `%v`\n", f) + return + } p := Packet{ - ack: f.congestion.NextAck(), - nack: f.congestion.NextNack(), seq: seq, data: proxy.SimplePacket(nil), + ack: f.congestion.NextAck(), + nack: f.congestion.NextNack(), } - _ = f.sendPacket(p, g) + err = f.sendPacket(p, g) + if err != nil { + fmt.Printf("error sending early update packet: `%v`\n", err) + } } } -func (f *Flow) acceptPacket(c PacketConn) error { +func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error { + // TODO: Replace 6000 with MTU+header size buf := make([]byte, 6000) n, _, err := c.ReadFromUDP(buf) if err != nil { return err } - f.handleDatagram(buf[:n]) - return nil + return f.queueDatagram(ctx, buf[:n]) } diff --git a/udp/flow_test.go b/udp/flow_test.go index b2f20c2..d044477 100644 --- a/udp/flow_test.go +++ b/udp/flow_test.go @@ -1,6 +1,7 @@ package udp import ( + "context" "fmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -24,7 +25,7 @@ func TestFlow_Consume(t *testing.T) { flowA.writer = testConn.SideB() flowA.isAlive = true - err := flowA.Consume(testPacket, testMac) + err := flowA.Consume(context.Background(), testPacket, testMac) require.Nil(t, err) buf := make([]byte, 100) @@ -63,10 +64,10 @@ func TestFlow_Produce(t *testing.T) { flowA.isAlive = true go func() { - err := flowA.acceptPacket(testConn.SideB()) + err := flowA.readQueuePacket(context.Background(), testConn.SideB()) assert.Nil(t, err) }() - p, err := flowA.Produce(testMac) + p, err := flowA.Produce(context.Background(), testMac) require.Nil(t, err) assert.Len(t, p.Contents(), len(testContent)) diff --git a/udp/listener.go b/udp/listener.go index 09cc19a..f2b9e0d 100644 --- a/udp/listener.go +++ b/udp/listener.go @@ -1,6 +1,7 @@ package udp import ( + "context" "log" "mpbl3p/proxy" "net" @@ -25,7 +26,7 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress { } } -func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error { +func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error { laddr, err := net.ResolveUDPAddr("udp", local) if err != nil { return err @@ -56,10 +57,11 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun raddr := fromUdpAddress(*addr) if f, exists := receivedConnections[raddr]; exists { - log.Println("existing flow") - log.Println("handling...") - f.handleDatagram(buf[:n]) - log.Println("handled") + log.Println("existing flow. queuing...") + if err := f.queueDatagram(ctx, buf[:n]); err != nil { + + } + log.Println("queued") continue } @@ -74,15 +76,15 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun log.Printf("received new udp connection: %v\n", f) - go f.earlyUpdateLoop(g, 0) + go f.earlyUpdateLoop(ctx, g, 0) receivedConnections[raddr] = &f - p.AddConsumer(&f, g) - p.AddProducer(&f, v) + p.AddConsumer(ctx, &f, g) + p.AddProducer(ctx, &f, v) log.Println("handling...") - f.handleDatagram(buf[:n]) + f.queueDatagram(ctx, buf[:n]) log.Println("handled") } }()