cleanexit #18

Merged
JakeHillion merged 6 commits from cleanexit into develop 2021-04-06 15:53:04 +01:00
14 changed files with 221 additions and 135 deletions

View File

@ -1,6 +1,7 @@
package config
import (
"context"
"encoding/base64"
"fmt"
"mpbl3p/crypto"
@ -12,7 +13,7 @@ import (
"time"
)
func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
p := proxy.NewProxy(0)
var g func() proxy.MacGenerator
@ -46,11 +47,11 @@ func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy
for _, peer := range c.Peers {
switch peer.Method {
case "TCP":
if err := buildTcp(p, peer, g, v); err != nil {
if err := buildTcp(ctx, p, peer, g, v); err != nil {
return nil, err
}
case "UDP":
if err := buildUdp(p, peer, g, v); err != nil {
if err := buildUdp(ctx, p, peer, g, v); err != nil {
return nil, err
}
}
@ -59,7 +60,7 @@ func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy
return p, nil
}
func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
var laddr func() string
if peer.LocalPort == 0 {
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
@ -74,13 +75,13 @@ func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
return err
}
p.AddConsumer(f, g())
p.AddProducer(f, v())
p.AddConsumer(ctx, f, g())
p.AddProducer(ctx, f, v())
return nil
}
err := tcp.NewListener(p, laddr(), v, g)
err := tcp.NewListener(ctx, p, laddr(), v, g)
if err != nil {
return err
}
@ -88,7 +89,7 @@ func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
return nil
}
func buildUdp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
var laddr func() string
if peer.LocalPort == 0 {
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
@ -120,13 +121,13 @@ func buildUdp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
return err
}
p.AddConsumer(f, g())
p.AddProducer(f, v())
p.AddConsumer(ctx, f, g())
p.AddProducer(ctx, f, v())
return nil
}
err := udp.NewListener(p, laddr(), v, g, c)
err := udp.NewListener(ctx, p, laddr(), v, g, c)
if err != nil {
return err
}

View File

@ -12,7 +12,7 @@ var v = validator.New()
func init() {
if err := v.RegisterValidation("iface", func(fl validator.FieldLevel) bool {
name, ok := fl.Field().Interface().(string)
if ok {
if ok && name != "" {
ifaces, err := net.Interfaces()
if err != nil {
log.Printf("error getting interfaces: %v", err)
@ -60,6 +60,10 @@ type Peer struct {
}
func (p Peer) GetLocalHost() string {
if p.LocalHost == "" {
return ""
}
if err := v.Var(p.LocalHost, "ip"); err == nil {
return p.LocalHost
}

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"errors"
"fmt"
"log"
@ -132,7 +133,10 @@ FOREGROUND:
}()
log.Println("building config...")
p, err := c.Build(t, t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
p, err := c.Build(ctx, t, t)
if err != nil {
panic(err)
}

View File

@ -1,22 +1,24 @@
package proxy
import (
"context"
"errors"
"log"
"time"
)
type Producer interface {
IsAlive() bool
Produce(MacVerifier) (Packet, error)
Produce(context.Context, MacVerifier) (Packet, error)
}
type Consumer interface {
IsAlive() bool
Consume(Packet, MacGenerator) error
Consume(context.Context, Packet, MacGenerator) error
}
type Reconnectable interface {
Reconnect() error
Reconnect(context.Context) error
}
type Source interface {
@ -65,7 +67,7 @@ func (p Proxy) Start() {
}()
}
func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) {
go func() {
_, reconnectable := c.(Reconnectable)
@ -74,7 +76,11 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
var err error
for once := true; err != nil || once; once = false {
log.Printf("attempting to connect consumer `%v`\n", c)
err = c.(Reconnectable).Reconnect()
err = c.(Reconnectable).Reconnect(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed consumer `%v` (context)\n", c)
return
}
if !once {
time.Sleep(time.Second)
}
@ -83,9 +89,19 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
}
for c.IsAlive() {
if err := c.Consume(<-p.proxyChan, g); err != nil {
log.Println(err)
break
select {
case <-ctx.Done():
log.Printf("closed consumer `%v` (context)\n", c)
return
case packet := <-p.proxyChan:
if err := c.Consume(ctx, packet, g); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed consumer `%v` (context)\n", c)
return
}
log.Println(err)
break
}
}
}
}
@ -94,7 +110,7 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
}()
}
func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) {
go func() {
_, reconnectable := pr.(Reconnectable)
@ -103,20 +119,37 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
var err error
for once := true; err != nil || once; once = false {
log.Printf("attempting to connect producer `%v`\n", pr)
err = pr.(Reconnectable).Reconnect()
err = pr.(Reconnectable).Reconnect(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed producer `%v` (context)\n", pr)
return
}
if !once {
time.Sleep(time.Second)
}
if ctx.Err() != nil {
return
}
}
log.Printf("connected producer `%v`\n", pr)
}
for pr.IsAlive() {
if packet, err := pr.Produce(v); err != nil {
if packet, err := pr.Produce(ctx, v); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed producer `%v` (context)\n", pr)
return
}
log.Println(err)
break
} else {
p.sinkChan <- packet
select {
case <-ctx.Done():
log.Printf("closed producer `%v` (context)\n", pr)
return
case p.sinkChan <- packet:
}
}
}
}

View File

@ -2,6 +2,7 @@ package tcp
import (
"bufio"
"context"
"encoding/binary"
"fmt"
"io"
@ -124,21 +125,21 @@ func (f *InitiatedFlow) Reconnect() error {
return nil
}
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Consume(p, g)
return f.Flow.Consume(ctx, p, g)
}
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(v)
return f.Flow.Produce(ctx, v)
}
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
if !f.isAlive {
return shared.ErrDeadConnection
}
@ -157,11 +158,16 @@ func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
copy(prefixedData[4:], data)
f.toConsume <- prefixedData
select {
case f.toConsume <- prefixedData:
case <-ctx.Done():
return ctx.Err()
}
return nil
}
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
if !f.isAlive {
return nil, shared.ErrDeadConnection
}
@ -169,6 +175,8 @@ func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
var data []byte
select {
case <-ctx.Done():
return nil, ctx.Err()
case data = <-f.produced:
case err := <-f.produceErrors:
f.isAlive = false

View File

@ -1,6 +1,7 @@
package tcp
import (
"context"
"encoding/binary"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -19,7 +20,7 @@ func TestFlow_Consume(t *testing.T) {
flowA := NewFlowConn(testConn.SideA())
err := flowA.Consume(testPacket, testMac)
err := flowA.Consume(context.Background(), testPacket, testMac)
require.Nil(t, err)
buf := make([]byte, 100)
@ -46,7 +47,7 @@ func TestFlow_Produce(t *testing.T) {
_, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err)
p, err := flowA.Produce(testMac)
p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err)
assert.Equal(t, len(testContent), len(p.Contents()))
})
@ -59,7 +60,7 @@ func TestFlow_Produce(t *testing.T) {
_, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err)
p, err := flowA.Produce(testMac)
p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err)
assert.Equal(t, testContent, string(p.Contents()))
})

View File

@ -1,12 +1,13 @@
package tcp
import (
"context"
"log"
"mpbl3p/proxy"
"net"
)
func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error {
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error {
laddr, err := net.ResolveTCPAddr("tcp", local)
if err != nil {
return err
@ -32,8 +33,8 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
log.Printf("received new tcp connection: %v\n", f)
p.AddConsumer(&f, g())
p.AddProducer(&f, v())
p.AddConsumer(ctx, &f, g())
p.AddProducer(ctx, &f, v())
}
}()

View File

@ -1,13 +1,16 @@
package udp
import "time"
import (
"context"
"time"
)
type Congestion interface {
Sequence() uint32
Sequence(ctx context.Context) (uint32, error)
NextAck() uint32
NextNack() uint32
ReceivedPacket(seq, nack, ack uint32)
AwaitEarlyUpdate(keepalive time.Duration) uint32
AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error)
}

View File

@ -1,6 +1,7 @@
package congestion
import (
"context"
"fmt"
"math"
"sort"
@ -94,7 +95,7 @@ func (c *NewReno) ReceivedPacket(seq, nack, ack uint32) {
}
}
func (c *NewReno) Sequence() uint32 {
func (c *NewReno) Sequence(ctx context.Context) (uint32, error) {
for len(c.inFlight) >= int(c.windowSize) {
<-c.windowNotifier
}
@ -102,7 +103,13 @@ func (c *NewReno) Sequence() uint32 {
c.inFlightMu.Lock()
defer c.inFlightMu.Unlock()
s := <-c.sequence
var s uint32
select {
case s = <-c.sequence:
case <-ctx.Done():
return 0, ctx.Err()
}
t := time.Now()
c.inFlight = append(c.inFlight, flightInfo{
@ -111,7 +118,7 @@ func (c *NewReno) Sequence() uint32 {
})
c.lastSent = t
return s
return s, nil
}
func (c *NewReno) NextAck() uint32 {
@ -126,7 +133,7 @@ func (c *NewReno) NextNack() uint32 {
return n
}
func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 {
func (c *NewReno) AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error) {
for {
rtt := time.Duration(math.Round(c.rttNanos))
time.Sleep(rtt / 2)
@ -136,12 +143,12 @@ func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 {
// CASE 1: waiting ACKs or NACKs and no message sent in the last half-RTT
// this targets arrival in 0.5+0.5 ± 0.5 RTTs (1±0.5 RTTs)
if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt/2)) {
return 0 // no ack needed
return 0, nil // no ack needed
}
// CASE 2: No message sent within the keepalive time
if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) {
return c.Sequence() // require an ack
return c.Sequence(ctx) // require an ack
}
}
}

View File

@ -89,11 +89,10 @@ func (n *newRenoTest) RunSideA(ctx context.Context) {
go func() {
for {
if ctx.Err() != nil {
seq, err := n.sideA.AwaitEarlyUpdate(ctx, 500*time.Millisecond)
if err != nil {
return
}
seq := n.sideA.AwaitEarlyUpdate(500 * time.Millisecond)
if seq != 0 {
// skip keepalive
// required to ensure AwaitEarlyUpdate terminates
@ -123,11 +122,10 @@ func (n *newRenoTest) RunSideB(ctx context.Context) {
go func() {
for {
if ctx.Err() != nil {
seq, err := n.sideB.AwaitEarlyUpdate(ctx, 500*time.Millisecond)
if err != nil {
return
}
seq := n.sideB.AwaitEarlyUpdate(500 * time.Millisecond)
if seq != 0 {
// skip keepalive
// required to ensure AwaitEarlyUpdate terminates
@ -162,7 +160,7 @@ func TestNewReno_Congestion(t *testing.T) {
for i := 0; i < numPackets; i++ {
// sleep to simulate preparing packet
time.Sleep(1 * time.Millisecond)
seq := c.sideA.Sequence()
seq, _ := c.sideA.Sequence(ctx)
c.aOutbound <- congestionPacket{
seq: seq,
@ -200,7 +198,7 @@ func TestNewReno_Congestion(t *testing.T) {
for i := 0; i < numPackets; i++ {
// sleep to simulate preparing packet
time.Sleep(1 * time.Millisecond)
seq := c.sideA.Sequence()
seq, _ := c.sideA.Sequence(ctx)
if seq == 20 {
// Simulate packet loss of sequence 20
@ -246,7 +244,7 @@ func TestNewReno_Congestion(t *testing.T) {
go func() {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq := c.sideA.Sequence()
seq, _ := c.sideA.Sequence(ctx)
c.aOutbound <- congestionPacket{
seq: seq,
@ -261,7 +259,7 @@ func TestNewReno_Congestion(t *testing.T) {
go func() {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq := c.sideB.Sequence()
seq, _ := c.sideB.Sequence(ctx)
c.bOutbound <- congestionPacket{
seq: seq,
@ -306,7 +304,7 @@ func TestNewReno_Congestion(t *testing.T) {
go func() {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq := c.sideA.Sequence()
seq, _ := c.sideA.Sequence(ctx)
if seq == 9 {
// Simulate packet loss of sequence 9
@ -326,7 +324,7 @@ func TestNewReno_Congestion(t *testing.T) {
go func() {
for i := 0; i < numPackets; i++ {
time.Sleep(1 * time.Millisecond)
seq := c.sideB.Sequence()
seq, _ := c.sideB.Sequence(ctx)
if seq == 13 {
// Simulate packet loss of sequence 13

View File

@ -1,41 +1,26 @@
package congestion
import (
"context"
"fmt"
"time"
)
type None struct {
sequence chan uint32
type None struct{}
func NewNone() None {
return None{}
}
func NewNone() *None {
c := None{
sequence: make(chan uint32),
}
go func() {
var s uint32
for {
if s == 0 {
s++
continue
}
c.sequence <- s
s++
}
}()
return &c
}
func (c *None) String() string {
func (c None) String() string {
return fmt.Sprintf("{None}")
}
func (c *None) ReceivedPacket(uint32, uint32, uint32) {}
func (c *None) NextNack() uint32 { return 0 }
func (c *None) NextAck() uint32 { return 0 }
func (c *None) AwaitEarlyUpdate(time.Duration) uint32 { select {} }
func (c *None) Sequence() uint32 { return <-c.sequence }
func (c None) ReceivedPacket(uint32, uint32, uint32) {}
func (c None) NextNack() uint32 { return 0 }
func (c None) NextAck() uint32 { return 0 }
func (c None) AwaitEarlyUpdate(ctx context.Context, _ time.Duration) (uint32, error) {
<-ctx.Done()
return 0, ctx.Err()
}
func (c None) Sequence(context.Context) (uint32, error) { return 0, nil }

View File

@ -1,6 +1,7 @@
package udp
import (
"context"
"fmt"
"log"
"mpbl3p/proxy"
@ -80,7 +81,7 @@ func newFlow(c Congestion, v proxy.MacVerifier) Flow {
}
}
func (f *InitiatedFlow) Reconnect() error {
func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
f.mu.Lock()
defer f.mu.Unlock()
@ -108,9 +109,16 @@ func (f *InitiatedFlow) Reconnect() error {
// prod the connection once a second until we get an ack, then consider it alive
go func() {
seq := f.congestion.Sequence()
seq, err := f.congestion.Sequence(ctx)
if err != nil {
return
}
for !f.isAlive {
if ctx.Err() != nil {
return
}
p := Packet{
ack: 0,
nack: 0,
@ -124,11 +132,11 @@ func (f *InitiatedFlow) Reconnect() error {
}()
go func() {
_, _ = f.produceInternal(f.v, false)
_, _ = f.produceInternal(ctx, f.v, false)
}()
go f.earlyUpdateLoop(f.g, f.keepalive)
go f.earlyUpdateLoop(ctx, f.g, f.keepalive)
if err := f.acceptPacket(conn); err != nil {
if err := f.readQueuePacket(ctx, conn); err != nil {
return err
}
@ -140,7 +148,7 @@ func (f *InitiatedFlow) Reconnect() error {
f.mu.RLock()
defer f.mu.RUnlock()
if err := f.acceptPacket(conn); err != nil {
if err := f.readQueuePacket(ctx, conn); err != nil {
log.Println(err)
}
}
@ -155,25 +163,25 @@ func (f *InitiatedFlow) Reconnect() error {
return nil
}
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Consume(p, g)
return f.Flow.Consume(ctx, p, g)
}
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.Flow.Produce(v)
return f.Flow.Produce(ctx, v)
}
func (f *Flow) IsAlive() bool {
return f.isAlive
}
func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error {
func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error {
if !f.isAlive {
return shared.ErrDeadConnection
}
@ -182,32 +190,43 @@ func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error {
// Sequence is the congestion controllers opportunity to block
log.Println("awaiting sequence")
p := Packet{
seq: f.congestion.Sequence(),
data: pp,
seq, err := f.congestion.Sequence(ctx)
if err != nil {
return err
}
log.Println("received sequence")
// Choose up to date ACK/NACK even after blocking
p.ack = f.congestion.NextAck()
p.nack = f.congestion.NextNack()
p := Packet{
seq: seq,
data: pp,
ack: f.congestion.NextAck(),
nack: f.congestion.NextNack(),
}
return f.sendPacket(p, g)
}
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
if !f.isAlive {
return nil, shared.ErrDeadConnection
}
return f.produceInternal(v, true)
return f.produceInternal(ctx, v, true)
}
func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
for once := true; mustReturn || once; once = false {
log.Println(f.congestion)
b, err := proxy.StripMac(<-f.inboundDatagrams, v)
var received []byte
select {
case received = <-f.inboundDatagrams:
case <-ctx.Done():
return nil, ctx.Err()
}
b, err := proxy.StripMac(received, v)
if err != nil {
return nil, err
}
@ -232,8 +251,13 @@ func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Pack
return nil, nil
}
func (f *Flow) handleDatagram(p []byte) {
f.inboundDatagrams <- p
func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
select {
case f.inboundDatagrams <- p:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
@ -249,27 +273,34 @@ func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
}
}
func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) {
func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) {
for f.isAlive {
seq := f.congestion.AwaitEarlyUpdate(keepalive)
seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive)
if err != nil {
fmt.Printf("terminating earlyupdateloop for `%v`\n", f)
return
}
p := Packet{
ack: f.congestion.NextAck(),
nack: f.congestion.NextNack(),
seq: seq,
data: proxy.SimplePacket(nil),
ack: f.congestion.NextAck(),
nack: f.congestion.NextNack(),
}
_ = f.sendPacket(p, g)
err = f.sendPacket(p, g)
if err != nil {
fmt.Printf("error sending early update packet: `%v`\n", err)
}
}
}
func (f *Flow) acceptPacket(c PacketConn) error {
func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error {
// TODO: Replace 6000 with MTU+header size
buf := make([]byte, 6000)
n, _, err := c.ReadFromUDP(buf)
if err != nil {
return err
}
f.handleDatagram(buf[:n])
return nil
return f.queueDatagram(ctx, buf[:n])
}

View File

@ -1,6 +1,7 @@
package udp
import (
"context"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -24,7 +25,7 @@ func TestFlow_Consume(t *testing.T) {
flowA.writer = testConn.SideB()
flowA.isAlive = true
err := flowA.Consume(testPacket, testMac)
err := flowA.Consume(context.Background(), testPacket, testMac)
require.Nil(t, err)
buf := make([]byte, 100)
@ -63,10 +64,10 @@ func TestFlow_Produce(t *testing.T) {
flowA.isAlive = true
go func() {
err := flowA.acceptPacket(testConn.SideB())
err := flowA.readQueuePacket(context.Background(), testConn.SideB())
assert.Nil(t, err)
}()
p, err := flowA.Produce(testMac)
p, err := flowA.Produce(context.Background(), testMac)
require.Nil(t, err)
assert.Len(t, p.Contents(), len(testContent))

View File

@ -1,9 +1,11 @@
package udp
import (
"context"
"log"
"mpbl3p/proxy"
"net"
"time"
)
type ComparableUdpAddress struct {
@ -25,7 +27,7 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
}
}
func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error {
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error {
laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return err
@ -44,22 +46,27 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
receivedConnections := make(map[ComparableUdpAddress]*Flow)
go func() {
for {
for ctx.Err() == nil {
buf := make([]byte, 6000)
log.Println("listening...")
n, addr, err := pconn.ReadFromUDP(buf)
if err != nil {
if err := pconn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
panic(err)
}
n, addr, err := pconn.ReadFromUDP(buf)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
panic(err)
}
log.Println("listened")
raddr := fromUdpAddress(*addr)
if f, exists := receivedConnections[raddr]; exists {
log.Println("existing flow")
log.Println("handling...")
f.handleDatagram(buf[:n])
log.Println("handled")
log.Println("existing flow. queuing...")
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
}
log.Println("queued")
continue
}
@ -74,15 +81,17 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
log.Printf("received new udp connection: %v\n", f)
go f.earlyUpdateLoop(g, 0)
go f.earlyUpdateLoop(ctx, g, 0)
receivedConnections[raddr] = &f
p.AddConsumer(&f, g)
p.AddProducer(&f, v)
p.AddConsumer(ctx, &f, g)
p.AddProducer(ctx, &f, v)
log.Println("handling...")
f.handleDatagram(buf[:n])
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
return
}
log.Println("handled")
}
}()