initial context propagation
Some checks failed
continuous-integration/drone/push Build is failing

This commit is contained in:
Jake Hillion 2021-03-30 20:57:53 +01:00
parent 2673f28e63
commit fad829803a
13 changed files with 200 additions and 129 deletions

View File

@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"mpbl3p/crypto" "mpbl3p/crypto"
@ -12,7 +13,7 @@ import (
"time" "time"
) )
func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) { func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
p := proxy.NewProxy(0) p := proxy.NewProxy(0)
var g func() proxy.MacGenerator var g func() proxy.MacGenerator
@ -46,11 +47,11 @@ func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy
for _, peer := range c.Peers { for _, peer := range c.Peers {
switch peer.Method { switch peer.Method {
case "TCP": case "TCP":
if err := buildTcp(p, peer, g, v); err != nil { if err := buildTcp(ctx, p, peer, g, v); err != nil {
return nil, err return nil, err
} }
case "UDP": case "UDP":
if err := buildUdp(p, peer, g, v); err != nil { if err := buildUdp(ctx, p, peer, g, v); err != nil {
return nil, err return nil, err
} }
} }
@ -59,7 +60,7 @@ func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy
return p, nil return p, nil
} }
func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
var laddr func() string var laddr func() string
if peer.LocalPort == 0 { if peer.LocalPort == 0 {
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) } laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
@ -74,13 +75,13 @@ func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
return err return err
} }
p.AddConsumer(f, g()) p.AddConsumer(ctx, f, g())
p.AddProducer(f, v()) p.AddProducer(ctx, f, v())
return nil return nil
} }
err := tcp.NewListener(p, laddr(), v, g) err := tcp.NewListener(ctx, p, laddr(), v, g)
if err != nil { if err != nil {
return err return err
} }
@ -88,7 +89,7 @@ func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
return nil return nil
} }
func buildUdp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
var laddr func() string var laddr func() string
if peer.LocalPort == 0 { if peer.LocalPort == 0 {
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) } laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
@ -120,13 +121,13 @@ func buildUdp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
return err return err
} }
p.AddConsumer(f, g()) p.AddConsumer(ctx, f, g())
p.AddProducer(f, v()) p.AddProducer(ctx, f, v())
return nil return nil
} }
err := udp.NewListener(p, laddr(), v, g, c) err := udp.NewListener(ctx, p, laddr(), v, g, c)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"log" "log"
@ -132,7 +133,10 @@ FOREGROUND:
}() }()
log.Println("building config...") log.Println("building config...")
p, err := c.Build(t, t) ctx, cancel := context.WithCancel(context.Background())
defer cancel()
p, err := c.Build(ctx, t, t)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -1,22 +1,24 @@
package proxy package proxy
import ( import (
"context"
"errors"
"log" "log"
"time" "time"
) )
type Producer interface { type Producer interface {
IsAlive() bool IsAlive() bool
Produce(MacVerifier) (Packet, error) Produce(context.Context, MacVerifier) (Packet, error)
} }
type Consumer interface { type Consumer interface {
IsAlive() bool IsAlive() bool
Consume(Packet, MacGenerator) error Consume(context.Context, Packet, MacGenerator) error
} }
type Reconnectable interface { type Reconnectable interface {
Reconnect() error Reconnect(context.Context) error
} }
type Source interface { type Source interface {
@ -65,7 +67,7 @@ func (p Proxy) Start() {
}() }()
} }
func (p Proxy) AddConsumer(c Consumer, g MacGenerator) { func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) {
go func() { go func() {
_, reconnectable := c.(Reconnectable) _, reconnectable := c.(Reconnectable)
@ -74,7 +76,11 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
var err error var err error
for once := true; err != nil || once; once = false { for once := true; err != nil || once; once = false {
log.Printf("attempting to connect consumer `%v`\n", c) log.Printf("attempting to connect consumer `%v`\n", c)
err = c.(Reconnectable).Reconnect() err = c.(Reconnectable).Reconnect(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed consumer `%v` (context)\n", c)
return
}
if !once { if !once {
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -83,18 +89,28 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
} }
for c.IsAlive() { for c.IsAlive() {
if err := c.Consume(<-p.proxyChan, g); err != nil { select {
case <-ctx.Done():
log.Printf("closed consumer `%v` (context)\n", c)
return
case packet := <-p.proxyChan:
if err := c.Consume(ctx, packet, g); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed consumer `%v` (context)\n", c)
return
}
log.Println(err) log.Println(err)
break break
} }
} }
} }
}
log.Printf("closed consumer `%v`\n", c) log.Printf("closed consumer `%v`\n", c)
}() }()
} }
func (p Proxy) AddProducer(pr Producer, v MacVerifier) { func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) {
go func() { go func() {
_, reconnectable := pr.(Reconnectable) _, reconnectable := pr.(Reconnectable)
@ -103,20 +119,37 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
var err error var err error
for once := true; err != nil || once; once = false { for once := true; err != nil || once; once = false {
log.Printf("attempting to connect producer `%v`\n", pr) log.Printf("attempting to connect producer `%v`\n", pr)
err = pr.(Reconnectable).Reconnect() err = pr.(Reconnectable).Reconnect(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed producer `%v` (context)\n", pr)
return
}
if !once { if !once {
time.Sleep(time.Second) time.Sleep(time.Second)
} }
if ctx.Err() != nil {
return
}
} }
log.Printf("connected producer `%v`\n", pr) log.Printf("connected producer `%v`\n", pr)
} }
for pr.IsAlive() { for pr.IsAlive() {
if packet, err := pr.Produce(v); err != nil { if packet, err := pr.Produce(ctx, v); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed producer `%v` (context)\n", pr)
return
}
log.Println(err) log.Println(err)
break break
} else { } else {
p.sinkChan <- packet select {
case <-ctx.Done():
log.Printf("closed producer `%v` (context)\n", pr)
return
case p.sinkChan <- packet:
}
} }
} }
} }

View File

@ -2,6 +2,7 @@ package tcp
import ( import (
"bufio" "bufio"
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
@ -124,21 +125,21 @@ func (f *InitiatedFlow) Reconnect() error {
return nil return nil
} }
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
return f.Flow.Consume(p, g) return f.Flow.Consume(ctx, p, g)
} }
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
return f.Flow.Produce(v) return f.Flow.Produce(ctx, v)
} }
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error { func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
if !f.isAlive { if !f.isAlive {
return shared.ErrDeadConnection return shared.ErrDeadConnection
} }
@ -157,11 +158,16 @@ func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
binary.LittleEndian.PutUint32(prefixedData, uint32(len(data))) binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
copy(prefixedData[4:], data) copy(prefixedData[4:], data)
f.toConsume <- prefixedData select {
case f.toConsume <- prefixedData:
case <-ctx.Done():
return ctx.Err()
}
return nil return nil
} }
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
if !f.isAlive { if !f.isAlive {
return nil, shared.ErrDeadConnection return nil, shared.ErrDeadConnection
} }
@ -169,6 +175,8 @@ func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
var data []byte var data []byte
select { select {
case <-ctx.Done():
return nil, ctx.Err()
case data = <-f.produced: case data = <-f.produced:
case err := <-f.produceErrors: case err := <-f.produceErrors:
f.isAlive = false f.isAlive = false

View File

@ -1,6 +1,7 @@
package tcp package tcp
import ( import (
"context"
"encoding/binary" "encoding/binary"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -19,7 +20,7 @@ func TestFlow_Consume(t *testing.T) {
flowA := NewFlowConn(testConn.SideA()) flowA := NewFlowConn(testConn.SideA())
err := flowA.Consume(testPacket, testMac) err := flowA.Consume(context.Background(), testPacket, testMac)
require.Nil(t, err) require.Nil(t, err)
buf := make([]byte, 100) buf := make([]byte, 100)
@ -46,7 +47,7 @@ func TestFlow_Produce(t *testing.T) {
_, err := testConn.SideB().Write(testMarshalled) _, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err) require.Nil(t, err)
p, err := flowA.Produce(testMac) p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, len(testContent), len(p.Contents())) assert.Equal(t, len(testContent), len(p.Contents()))
}) })
@ -59,7 +60,7 @@ func TestFlow_Produce(t *testing.T) {
_, err := testConn.SideB().Write(testMarshalled) _, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err) require.Nil(t, err)
p, err := flowA.Produce(testMac) p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, testContent, string(p.Contents())) assert.Equal(t, testContent, string(p.Contents()))
}) })

View File

@ -1,12 +1,13 @@
package tcp package tcp
import ( import (
"context"
"log" "log"
"mpbl3p/proxy" "mpbl3p/proxy"
"net" "net"
) )
func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error { func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error {
laddr, err := net.ResolveTCPAddr("tcp", local) laddr, err := net.ResolveTCPAddr("tcp", local)
if err != nil { if err != nil {
return err return err
@ -32,8 +33,8 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
log.Printf("received new tcp connection: %v\n", f) log.Printf("received new tcp connection: %v\n", f)
p.AddConsumer(&f, g()) p.AddConsumer(ctx, &f, g())
p.AddProducer(&f, v()) p.AddProducer(ctx, &f, v())
} }
}() }()

View File

@ -1,13 +1,16 @@
package udp package udp
import "time" import (
"context"
"time"
)
type Congestion interface { type Congestion interface {
Sequence() uint32 Sequence(ctx context.Context) (uint32, error)
NextAck() uint32 NextAck() uint32
NextNack() uint32 NextNack() uint32
ReceivedPacket(seq, nack, ack uint32) ReceivedPacket(seq, nack, ack uint32)
AwaitEarlyUpdate(keepalive time.Duration) uint32 AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error)
} }

View File

@ -1,6 +1,7 @@
package congestion package congestion
import ( import (
"context"
"fmt" "fmt"
"math" "math"
"sort" "sort"
@ -94,7 +95,7 @@ func (c *NewReno) ReceivedPacket(seq, nack, ack uint32) {
} }
} }
func (c *NewReno) Sequence() uint32 { func (c *NewReno) Sequence(ctx context.Context) (uint32, error) {
for len(c.inFlight) >= int(c.windowSize) { for len(c.inFlight) >= int(c.windowSize) {
<-c.windowNotifier <-c.windowNotifier
} }
@ -102,7 +103,13 @@ func (c *NewReno) Sequence() uint32 {
c.inFlightMu.Lock() c.inFlightMu.Lock()
defer c.inFlightMu.Unlock() defer c.inFlightMu.Unlock()
s := <-c.sequence var s uint32
select {
case s = <-c.sequence:
case <-ctx.Done():
return 0, ctx.Err()
}
t := time.Now() t := time.Now()
c.inFlight = append(c.inFlight, flightInfo{ c.inFlight = append(c.inFlight, flightInfo{
@ -111,7 +118,7 @@ func (c *NewReno) Sequence() uint32 {
}) })
c.lastSent = t c.lastSent = t
return s return s, nil
} }
func (c *NewReno) NextAck() uint32 { func (c *NewReno) NextAck() uint32 {
@ -126,7 +133,7 @@ func (c *NewReno) NextNack() uint32 {
return n return n
} }
func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 { func (c *NewReno) AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error) {
for { for {
rtt := time.Duration(math.Round(c.rttNanos)) rtt := time.Duration(math.Round(c.rttNanos))
time.Sleep(rtt / 2) time.Sleep(rtt / 2)
@ -136,12 +143,12 @@ func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 {
// CASE 1: waiting ACKs or NACKs and no message sent in the last half-RTT // CASE 1: waiting ACKs or NACKs and no message sent in the last half-RTT
// this targets arrival in 0.5+0.5 ± 0.5 RTTs (1±0.5 RTTs) // this targets arrival in 0.5+0.5 ± 0.5 RTTs (1±0.5 RTTs)
if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt/2)) { if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt/2)) {
return 0 // no ack needed return 0, nil // no ack needed
} }
// CASE 2: No message sent within the keepalive time // CASE 2: No message sent within the keepalive time
if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) { if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) {
return c.Sequence() // require an ack return c.Sequence(ctx) // require an ack
} }
} }
} }

View File

@ -89,11 +89,10 @@ func (n *newRenoTest) RunSideA(ctx context.Context) {
go func() { go func() {
for { for {
if ctx.Err() != nil { seq, err := n.sideA.AwaitEarlyUpdate(ctx, 500 * time.Millisecond)
if err != nil {
return return
} }
seq := n.sideA.AwaitEarlyUpdate(500 * time.Millisecond)
if seq != 0 { if seq != 0 {
// skip keepalive // skip keepalive
// required to ensure AwaitEarlyUpdate terminates // required to ensure AwaitEarlyUpdate terminates
@ -123,11 +122,10 @@ func (n *newRenoTest) RunSideB(ctx context.Context) {
go func() { go func() {
for { for {
if ctx.Err() != nil { seq, err := n.sideB.AwaitEarlyUpdate(ctx, 500 * time.Millisecond)
if err != nil {
return return
} }
seq := n.sideB.AwaitEarlyUpdate(500 * time.Millisecond)
if seq != 0 { if seq != 0 {
// skip keepalive // skip keepalive
// required to ensure AwaitEarlyUpdate terminates // required to ensure AwaitEarlyUpdate terminates
@ -162,7 +160,7 @@ func TestNewReno_Congestion(t *testing.T) {
for i := 0; i < numPackets; i++ { for i := 0; i < numPackets; i++ {
// sleep to simulate preparing packet // sleep to simulate preparing packet
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq := c.sideA.Sequence() seq, _ := c.sideA.Sequence(ctx)
c.aOutbound <- congestionPacket{ c.aOutbound <- congestionPacket{
seq: seq, seq: seq,
@ -200,7 +198,7 @@ func TestNewReno_Congestion(t *testing.T) {
for i := 0; i < numPackets; i++ { for i := 0; i < numPackets; i++ {
// sleep to simulate preparing packet // sleep to simulate preparing packet
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq := c.sideA.Sequence() seq, _ := c.sideA.Sequence(ctx)
if seq == 20 { if seq == 20 {
// Simulate packet loss of sequence 20 // Simulate packet loss of sequence 20
@ -246,7 +244,7 @@ func TestNewReno_Congestion(t *testing.T) {
go func() { go func() {
for i := 0; i < numPackets; i++ { for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq := c.sideA.Sequence() seq, _ := c.sideA.Sequence(ctx)
c.aOutbound <- congestionPacket{ c.aOutbound <- congestionPacket{
seq: seq, seq: seq,
@ -261,7 +259,7 @@ func TestNewReno_Congestion(t *testing.T) {
go func() { go func() {
for i := 0; i < numPackets; i++ { for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq := c.sideB.Sequence() seq, _ := c.sideB.Sequence(ctx)
c.bOutbound <- congestionPacket{ c.bOutbound <- congestionPacket{
seq: seq, seq: seq,
@ -306,7 +304,7 @@ func TestNewReno_Congestion(t *testing.T) {
go func() { go func() {
for i := 0; i < numPackets; i++ { for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq := c.sideA.Sequence() seq, _ := c.sideA.Sequence(ctx)
if seq == 9 { if seq == 9 {
// Simulate packet loss of sequence 9 // Simulate packet loss of sequence 9
@ -326,7 +324,7 @@ func TestNewReno_Congestion(t *testing.T) {
go func() { go func() {
for i := 0; i < numPackets; i++ { for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
seq := c.sideB.Sequence() seq, _ := c.sideB.Sequence(ctx)
if seq == 13 { if seq == 13 {
// Simulate packet loss of sequence 13 // Simulate packet loss of sequence 13

View File

@ -1,41 +1,26 @@
package congestion package congestion
import ( import (
"context"
"fmt" "fmt"
"time" "time"
) )
type None struct { type None struct {}
sequence chan uint32
func NewNone() None {
return None{}
} }
func NewNone() *None { func (c None) String() string {
c := None{
sequence: make(chan uint32),
}
go func() {
var s uint32
for {
if s == 0 {
s++
continue
}
c.sequence <- s
s++
}
}()
return &c
}
func (c *None) String() string {
return fmt.Sprintf("{None}") return fmt.Sprintf("{None}")
} }
func (c *None) ReceivedPacket(uint32, uint32, uint32) {} func (c None) ReceivedPacket(uint32, uint32, uint32) {}
func (c *None) NextNack() uint32 { return 0 } func (c None) NextNack() uint32 { return 0 }
func (c *None) NextAck() uint32 { return 0 } func (c None) NextAck() uint32 { return 0 }
func (c *None) AwaitEarlyUpdate(time.Duration) uint32 { select {} } func (c None) AwaitEarlyUpdate(ctx context.Context, _ time.Duration) (uint32, error) {
func (c *None) Sequence() uint32 { return <-c.sequence } <-ctx.Done()
return 0, ctx.Err()
}
func (c None) Sequence(context.Context) (uint32, error) { return 0, nil }

View File

@ -1,6 +1,7 @@
package udp package udp
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"mpbl3p/proxy" "mpbl3p/proxy"
@ -80,7 +81,7 @@ func newFlow(c Congestion, v proxy.MacVerifier) Flow {
} }
} }
func (f *InitiatedFlow) Reconnect() error { func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
@ -108,7 +109,10 @@ func (f *InitiatedFlow) Reconnect() error {
// prod the connection once a second until we get an ack, then consider it alive // prod the connection once a second until we get an ack, then consider it alive
go func() { go func() {
seq := f.congestion.Sequence() seq, err := f.congestion.Sequence(ctx)
if err != nil {
}
for !f.isAlive { for !f.isAlive {
p := Packet{ p := Packet{
@ -124,11 +128,11 @@ func (f *InitiatedFlow) Reconnect() error {
}() }()
go func() { go func() {
_, _ = f.produceInternal(f.v, false) _, _ = f.produceInternal(ctx, f.v, false)
}() }()
go f.earlyUpdateLoop(f.g, f.keepalive) go f.earlyUpdateLoop(ctx, f.g, f.keepalive)
if err := f.acceptPacket(conn); err != nil { if err := f.readQueuePacket(ctx, conn); err != nil {
return err return err
} }
@ -140,7 +144,7 @@ func (f *InitiatedFlow) Reconnect() error {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
if err := f.acceptPacket(conn); err != nil { if err := f.readQueuePacket(ctx, conn); err != nil {
log.Println(err) log.Println(err)
} }
} }
@ -155,25 +159,25 @@ func (f *InitiatedFlow) Reconnect() error {
return nil return nil
} }
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
return f.Flow.Consume(p, g) return f.Flow.Consume(ctx, p, g)
} }
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
return f.Flow.Produce(v) return f.Flow.Produce(ctx, v)
} }
func (f *Flow) IsAlive() bool { func (f *Flow) IsAlive() bool {
return f.isAlive return f.isAlive
} }
func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error { func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error {
if !f.isAlive { if !f.isAlive {
return shared.ErrDeadConnection return shared.ErrDeadConnection
} }
@ -182,32 +186,43 @@ func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error {
// Sequence is the congestion controllers opportunity to block // Sequence is the congestion controllers opportunity to block
log.Println("awaiting sequence") log.Println("awaiting sequence")
p := Packet{ seq, err := f.congestion.Sequence(ctx)
seq: f.congestion.Sequence(), if err != nil {
data: pp, return err
} }
log.Println("received sequence") log.Println("received sequence")
// Choose up to date ACK/NACK even after blocking // Choose up to date ACK/NACK even after blocking
p.ack = f.congestion.NextAck() p := Packet{
p.nack = f.congestion.NextNack() seq: seq,
data: pp,
ack: f.congestion.NextAck(),
nack: f.congestion.NextNack(),
}
return f.sendPacket(p, g) return f.sendPacket(p, g)
} }
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
if !f.isAlive { if !f.isAlive {
return nil, shared.ErrDeadConnection return nil, shared.ErrDeadConnection
} }
return f.produceInternal(v, true) return f.produceInternal(ctx, v, true)
} }
func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) { func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
for once := true; mustReturn || once; once = false { for once := true; mustReturn || once; once = false {
log.Println(f.congestion) log.Println(f.congestion)
b, err := proxy.StripMac(<-f.inboundDatagrams, v) var received []byte
select {
case received = <-f.inboundDatagrams:
case <-ctx.Done():
return nil, ctx.Err()
}
b, err := proxy.StripMac(received, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -232,8 +247,13 @@ func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Pack
return nil, nil return nil, nil
} }
func (f *Flow) handleDatagram(p []byte) { func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
f.inboundDatagrams <- p select {
case f.inboundDatagrams <- p:
return nil
case <-ctx.Done():
return ctx.Err()
}
} }
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
@ -249,27 +269,34 @@ func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
} }
} }
func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) { func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) {
for f.isAlive { for f.isAlive {
seq := f.congestion.AwaitEarlyUpdate(keepalive) seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive)
if err != nil {
fmt.Printf("terminating earlyupdateloop for `%v`\n", f)
return
}
p := Packet{ p := Packet{
ack: f.congestion.NextAck(),
nack: f.congestion.NextNack(),
seq: seq, seq: seq,
data: proxy.SimplePacket(nil), data: proxy.SimplePacket(nil),
ack: f.congestion.NextAck(),
nack: f.congestion.NextNack(),
} }
_ = f.sendPacket(p, g) err = f.sendPacket(p, g)
if err != nil {
fmt.Printf("error sending early update packet: `%v`\n", err)
}
} }
} }
func (f *Flow) acceptPacket(c PacketConn) error { func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error {
// TODO: Replace 6000 with MTU+header size
buf := make([]byte, 6000) buf := make([]byte, 6000)
n, _, err := c.ReadFromUDP(buf) n, _, err := c.ReadFromUDP(buf)
if err != nil { if err != nil {
return err return err
} }
f.handleDatagram(buf[:n]) return f.queueDatagram(ctx, buf[:n])
return nil
} }

View File

@ -1,6 +1,7 @@
package udp package udp
import ( import (
"context"
"fmt" "fmt"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -24,7 +25,7 @@ func TestFlow_Consume(t *testing.T) {
flowA.writer = testConn.SideB() flowA.writer = testConn.SideB()
flowA.isAlive = true flowA.isAlive = true
err := flowA.Consume(testPacket, testMac) err := flowA.Consume(context.Background(), testPacket, testMac)
require.Nil(t, err) require.Nil(t, err)
buf := make([]byte, 100) buf := make([]byte, 100)
@ -63,10 +64,10 @@ func TestFlow_Produce(t *testing.T) {
flowA.isAlive = true flowA.isAlive = true
go func() { go func() {
err := flowA.acceptPacket(testConn.SideB()) err := flowA.readQueuePacket(context.Background(), testConn.SideB())
assert.Nil(t, err) assert.Nil(t, err)
}() }()
p, err := flowA.Produce(testMac) p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err) require.Nil(t, err)
assert.Len(t, p.Contents(), len(testContent)) assert.Len(t, p.Contents(), len(testContent))

View File

@ -1,6 +1,7 @@
package udp package udp
import ( import (
"context"
"log" "log"
"mpbl3p/proxy" "mpbl3p/proxy"
"net" "net"
@ -25,7 +26,7 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
} }
} }
func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error { func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error {
laddr, err := net.ResolveUDPAddr("udp", local) laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil { if err != nil {
return err return err
@ -56,10 +57,11 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
raddr := fromUdpAddress(*addr) raddr := fromUdpAddress(*addr)
if f, exists := receivedConnections[raddr]; exists { if f, exists := receivedConnections[raddr]; exists {
log.Println("existing flow") log.Println("existing flow. queuing...")
log.Println("handling...") if err := f.queueDatagram(ctx, buf[:n]); err != nil {
f.handleDatagram(buf[:n])
log.Println("handled") }
log.Println("queued")
continue continue
} }
@ -74,15 +76,15 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
log.Printf("received new udp connection: %v\n", f) log.Printf("received new udp connection: %v\n", f)
go f.earlyUpdateLoop(g, 0) go f.earlyUpdateLoop(ctx, g, 0)
receivedConnections[raddr] = &f receivedConnections[raddr] = &f
p.AddConsumer(&f, g) p.AddConsumer(ctx, &f, g)
p.AddProducer(&f, v) p.AddProducer(ctx, &f, v)
log.Println("handling...") log.Println("handling...")
f.handleDatagram(buf[:n]) f.queueDatagram(ctx, buf[:n])
log.Println("handled") log.Println("handled")
} }
}() }()