This commit is contained in:
parent
2673f28e63
commit
fad829803a
@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"mpbl3p/crypto"
|
||||
@ -12,7 +13,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
|
||||
func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
|
||||
p := proxy.NewProxy(0)
|
||||
|
||||
var g func() proxy.MacGenerator
|
||||
@ -46,11 +47,11 @@ func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy
|
||||
for _, peer := range c.Peers {
|
||||
switch peer.Method {
|
||||
case "TCP":
|
||||
if err := buildTcp(p, peer, g, v); err != nil {
|
||||
if err := buildTcp(ctx, p, peer, g, v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "UDP":
|
||||
if err := buildUdp(p, peer, g, v); err != nil {
|
||||
if err := buildUdp(ctx, p, peer, g, v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@ -59,7 +60,7 @@ func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
|
||||
func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
|
||||
var laddr func() string
|
||||
if peer.LocalPort == 0 {
|
||||
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
|
||||
@ -74,13 +75,13 @@ func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
|
||||
return err
|
||||
}
|
||||
|
||||
p.AddConsumer(f, g())
|
||||
p.AddProducer(f, v())
|
||||
p.AddConsumer(ctx, f, g())
|
||||
p.AddProducer(ctx, f, v())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := tcp.NewListener(p, laddr(), v, g)
|
||||
err := tcp.NewListener(ctx, p, laddr(), v, g)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -88,7 +89,7 @@ func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildUdp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
|
||||
func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
|
||||
var laddr func() string
|
||||
if peer.LocalPort == 0 {
|
||||
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
|
||||
@ -120,13 +121,13 @@ func buildUdp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
|
||||
return err
|
||||
}
|
||||
|
||||
p.AddConsumer(f, g())
|
||||
p.AddProducer(f, v())
|
||||
p.AddConsumer(ctx, f, g())
|
||||
p.AddProducer(ctx, f, v())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := udp.NewListener(p, laddr(), v, g, c)
|
||||
err := udp.NewListener(ctx, p, laddr(), v, g, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
6
main.go
6
main.go
@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@ -132,7 +133,10 @@ FOREGROUND:
|
||||
}()
|
||||
|
||||
log.Println("building config...")
|
||||
p, err := c.Build(t, t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
p, err := c.Build(ctx, t, t)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -1,22 +1,24 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Producer interface {
|
||||
IsAlive() bool
|
||||
Produce(MacVerifier) (Packet, error)
|
||||
Produce(context.Context, MacVerifier) (Packet, error)
|
||||
}
|
||||
|
||||
type Consumer interface {
|
||||
IsAlive() bool
|
||||
Consume(Packet, MacGenerator) error
|
||||
Consume(context.Context, Packet, MacGenerator) error
|
||||
}
|
||||
|
||||
type Reconnectable interface {
|
||||
Reconnect() error
|
||||
Reconnect(context.Context) error
|
||||
}
|
||||
|
||||
type Source interface {
|
||||
@ -65,7 +67,7 @@ func (p Proxy) Start() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
|
||||
func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) {
|
||||
go func() {
|
||||
_, reconnectable := c.(Reconnectable)
|
||||
|
||||
@ -74,7 +76,11 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
|
||||
var err error
|
||||
for once := true; err != nil || once; once = false {
|
||||
log.Printf("attempting to connect consumer `%v`\n", c)
|
||||
err = c.(Reconnectable).Reconnect()
|
||||
err = c.(Reconnectable).Reconnect(ctx)
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Printf("closed consumer `%v` (context)\n", c)
|
||||
return
|
||||
}
|
||||
if !once {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
@ -83,9 +89,19 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
|
||||
}
|
||||
|
||||
for c.IsAlive() {
|
||||
if err := c.Consume(<-p.proxyChan, g); err != nil {
|
||||
log.Println(err)
|
||||
break
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("closed consumer `%v` (context)\n", c)
|
||||
return
|
||||
case packet := <-p.proxyChan:
|
||||
if err := c.Consume(ctx, packet, g); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Printf("closed consumer `%v` (context)\n", c)
|
||||
return
|
||||
}
|
||||
log.Println(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -94,7 +110,7 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
|
||||
func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) {
|
||||
go func() {
|
||||
_, reconnectable := pr.(Reconnectable)
|
||||
|
||||
@ -103,20 +119,37 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
|
||||
var err error
|
||||
for once := true; err != nil || once; once = false {
|
||||
log.Printf("attempting to connect producer `%v`\n", pr)
|
||||
err = pr.(Reconnectable).Reconnect()
|
||||
err = pr.(Reconnectable).Reconnect(ctx)
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Printf("closed producer `%v` (context)\n", pr)
|
||||
return
|
||||
}
|
||||
if !once {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
log.Printf("connected producer `%v`\n", pr)
|
||||
}
|
||||
|
||||
for pr.IsAlive() {
|
||||
if packet, err := pr.Produce(v); err != nil {
|
||||
if packet, err := pr.Produce(ctx, v); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Printf("closed producer `%v` (context)\n", pr)
|
||||
return
|
||||
}
|
||||
log.Println(err)
|
||||
break
|
||||
} else {
|
||||
p.sinkChan <- packet
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("closed producer `%v` (context)\n", pr)
|
||||
return
|
||||
case p.sinkChan <- packet:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
22
tcp/flow.go
22
tcp/flow.go
@ -2,6 +2,7 @@ package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -124,21 +125,21 @@ func (f *InitiatedFlow) Reconnect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
||||
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
return f.Flow.Consume(p, g)
|
||||
return f.Flow.Consume(ctx, p, g)
|
||||
}
|
||||
|
||||
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
||||
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
return f.Flow.Produce(v)
|
||||
return f.Flow.Produce(ctx, v)
|
||||
}
|
||||
|
||||
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
||||
func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
|
||||
if !f.isAlive {
|
||||
return shared.ErrDeadConnection
|
||||
}
|
||||
@ -157,11 +158,16 @@ func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
||||
binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
|
||||
copy(prefixedData[4:], data)
|
||||
|
||||
f.toConsume <- prefixedData
|
||||
select {
|
||||
case f.toConsume <- prefixedData:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
||||
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||
if !f.isAlive {
|
||||
return nil, shared.ErrDeadConnection
|
||||
}
|
||||
@ -169,6 +175,8 @@ func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
||||
var data []byte
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case data = <-f.produced:
|
||||
case err := <-f.produceErrors:
|
||||
f.isAlive = false
|
||||
|
@ -1,6 +1,7 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -19,7 +20,7 @@ func TestFlow_Consume(t *testing.T) {
|
||||
|
||||
flowA := NewFlowConn(testConn.SideA())
|
||||
|
||||
err := flowA.Consume(testPacket, testMac)
|
||||
err := flowA.Consume(context.Background(), testPacket, testMac)
|
||||
require.Nil(t, err)
|
||||
|
||||
buf := make([]byte, 100)
|
||||
@ -46,7 +47,7 @@ func TestFlow_Produce(t *testing.T) {
|
||||
_, err := testConn.SideB().Write(testMarshalled)
|
||||
require.Nil(t, err)
|
||||
|
||||
p, err := flowA.Produce(testMac)
|
||||
p, err := flowA.Produce(context.Background(), testMac)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, len(testContent), len(p.Contents()))
|
||||
})
|
||||
@ -59,7 +60,7 @@ func TestFlow_Produce(t *testing.T) {
|
||||
_, err := testConn.SideB().Write(testMarshalled)
|
||||
require.Nil(t, err)
|
||||
|
||||
p, err := flowA.Produce(testMac)
|
||||
p, err := flowA.Produce(context.Background(), testMac)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, testContent, string(p.Contents()))
|
||||
})
|
||||
|
@ -1,12 +1,13 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"mpbl3p/proxy"
|
||||
"net"
|
||||
)
|
||||
|
||||
func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error {
|
||||
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error {
|
||||
laddr, err := net.ResolveTCPAddr("tcp", local)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -32,8 +33,8 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
|
||||
|
||||
log.Printf("received new tcp connection: %v\n", f)
|
||||
|
||||
p.AddConsumer(&f, g())
|
||||
p.AddProducer(&f, v())
|
||||
p.AddConsumer(ctx, &f, g())
|
||||
p.AddProducer(ctx, &f, v())
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -1,13 +1,16 @@
|
||||
package udp
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Congestion interface {
|
||||
Sequence() uint32
|
||||
Sequence(ctx context.Context) (uint32, error)
|
||||
NextAck() uint32
|
||||
NextNack() uint32
|
||||
|
||||
ReceivedPacket(seq, nack, ack uint32)
|
||||
|
||||
AwaitEarlyUpdate(keepalive time.Duration) uint32
|
||||
AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
@ -94,7 +95,7 @@ func (c *NewReno) ReceivedPacket(seq, nack, ack uint32) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NewReno) Sequence() uint32 {
|
||||
func (c *NewReno) Sequence(ctx context.Context) (uint32, error) {
|
||||
for len(c.inFlight) >= int(c.windowSize) {
|
||||
<-c.windowNotifier
|
||||
}
|
||||
@ -102,7 +103,13 @@ func (c *NewReno) Sequence() uint32 {
|
||||
c.inFlightMu.Lock()
|
||||
defer c.inFlightMu.Unlock()
|
||||
|
||||
s := <-c.sequence
|
||||
var s uint32
|
||||
select {
|
||||
case s = <-c.sequence:
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
|
||||
c.inFlight = append(c.inFlight, flightInfo{
|
||||
@ -111,7 +118,7 @@ func (c *NewReno) Sequence() uint32 {
|
||||
})
|
||||
c.lastSent = t
|
||||
|
||||
return s
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *NewReno) NextAck() uint32 {
|
||||
@ -126,7 +133,7 @@ func (c *NewReno) NextNack() uint32 {
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 {
|
||||
func (c *NewReno) AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error) {
|
||||
for {
|
||||
rtt := time.Duration(math.Round(c.rttNanos))
|
||||
time.Sleep(rtt / 2)
|
||||
@ -136,12 +143,12 @@ func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 {
|
||||
// CASE 1: waiting ACKs or NACKs and no message sent in the last half-RTT
|
||||
// this targets arrival in 0.5+0.5 ± 0.5 RTTs (1±0.5 RTTs)
|
||||
if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt/2)) {
|
||||
return 0 // no ack needed
|
||||
return 0, nil // no ack needed
|
||||
}
|
||||
|
||||
// CASE 2: No message sent within the keepalive time
|
||||
if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) {
|
||||
return c.Sequence() // require an ack
|
||||
return c.Sequence(ctx) // require an ack
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -89,11 +89,10 @@ func (n *newRenoTest) RunSideA(ctx context.Context) {
|
||||
|
||||
go func() {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
seq, err := n.sideA.AwaitEarlyUpdate(ctx, 500 * time.Millisecond)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
seq := n.sideA.AwaitEarlyUpdate(500 * time.Millisecond)
|
||||
if seq != 0 {
|
||||
// skip keepalive
|
||||
// required to ensure AwaitEarlyUpdate terminates
|
||||
@ -123,11 +122,10 @@ func (n *newRenoTest) RunSideB(ctx context.Context) {
|
||||
|
||||
go func() {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
seq, err := n.sideB.AwaitEarlyUpdate(ctx, 500 * time.Millisecond)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
seq := n.sideB.AwaitEarlyUpdate(500 * time.Millisecond)
|
||||
if seq != 0 {
|
||||
// skip keepalive
|
||||
// required to ensure AwaitEarlyUpdate terminates
|
||||
@ -162,7 +160,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
||||
for i := 0; i < numPackets; i++ {
|
||||
// sleep to simulate preparing packet
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
seq := c.sideA.Sequence()
|
||||
seq, _ := c.sideA.Sequence(ctx)
|
||||
|
||||
c.aOutbound <- congestionPacket{
|
||||
seq: seq,
|
||||
@ -200,7 +198,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
||||
for i := 0; i < numPackets; i++ {
|
||||
// sleep to simulate preparing packet
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
seq := c.sideA.Sequence()
|
||||
seq, _ := c.sideA.Sequence(ctx)
|
||||
|
||||
if seq == 20 {
|
||||
// Simulate packet loss of sequence 20
|
||||
@ -246,7 +244,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
||||
go func() {
|
||||
for i := 0; i < numPackets; i++ {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
seq := c.sideA.Sequence()
|
||||
seq, _ := c.sideA.Sequence(ctx)
|
||||
|
||||
c.aOutbound <- congestionPacket{
|
||||
seq: seq,
|
||||
@ -261,7 +259,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
||||
go func() {
|
||||
for i := 0; i < numPackets; i++ {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
seq := c.sideB.Sequence()
|
||||
seq, _ := c.sideB.Sequence(ctx)
|
||||
|
||||
c.bOutbound <- congestionPacket{
|
||||
seq: seq,
|
||||
@ -306,7 +304,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
||||
go func() {
|
||||
for i := 0; i < numPackets; i++ {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
seq := c.sideA.Sequence()
|
||||
seq, _ := c.sideA.Sequence(ctx)
|
||||
|
||||
if seq == 9 {
|
||||
// Simulate packet loss of sequence 9
|
||||
@ -326,7 +324,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
||||
go func() {
|
||||
for i := 0; i < numPackets; i++ {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
seq := c.sideB.Sequence()
|
||||
seq, _ := c.sideB.Sequence(ctx)
|
||||
|
||||
if seq == 13 {
|
||||
// Simulate packet loss of sequence 13
|
||||
|
@ -1,41 +1,26 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type None struct {
|
||||
sequence chan uint32
|
||||
type None struct {}
|
||||
|
||||
func NewNone() None {
|
||||
return None{}
|
||||
}
|
||||
|
||||
func NewNone() *None {
|
||||
c := None{
|
||||
sequence: make(chan uint32),
|
||||
}
|
||||
|
||||
go func() {
|
||||
var s uint32
|
||||
for {
|
||||
if s == 0 {
|
||||
s++
|
||||
continue
|
||||
}
|
||||
|
||||
c.sequence <- s
|
||||
s++
|
||||
}
|
||||
}()
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
func (c *None) String() string {
|
||||
func (c None) String() string {
|
||||
return fmt.Sprintf("{None}")
|
||||
}
|
||||
|
||||
func (c *None) ReceivedPacket(uint32, uint32, uint32) {}
|
||||
func (c *None) NextNack() uint32 { return 0 }
|
||||
func (c *None) NextAck() uint32 { return 0 }
|
||||
func (c *None) AwaitEarlyUpdate(time.Duration) uint32 { select {} }
|
||||
func (c *None) Sequence() uint32 { return <-c.sequence }
|
||||
func (c None) ReceivedPacket(uint32, uint32, uint32) {}
|
||||
func (c None) NextNack() uint32 { return 0 }
|
||||
func (c None) NextAck() uint32 { return 0 }
|
||||
func (c None) AwaitEarlyUpdate(ctx context.Context, _ time.Duration) (uint32, error) {
|
||||
<-ctx.Done()
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
func (c None) Sequence(context.Context) (uint32, error) { return 0, nil }
|
||||
|
87
udp/flow.go
87
udp/flow.go
@ -1,6 +1,7 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"mpbl3p/proxy"
|
||||
@ -80,7 +81,7 @@ func newFlow(c Congestion, v proxy.MacVerifier) Flow {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *InitiatedFlow) Reconnect() error {
|
||||
func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
@ -108,7 +109,10 @@ func (f *InitiatedFlow) Reconnect() error {
|
||||
|
||||
// prod the connection once a second until we get an ack, then consider it alive
|
||||
go func() {
|
||||
seq := f.congestion.Sequence()
|
||||
seq, err := f.congestion.Sequence(ctx)
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
|
||||
for !f.isAlive {
|
||||
p := Packet{
|
||||
@ -124,11 +128,11 @@ func (f *InitiatedFlow) Reconnect() error {
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, _ = f.produceInternal(f.v, false)
|
||||
_, _ = f.produceInternal(ctx, f.v, false)
|
||||
}()
|
||||
go f.earlyUpdateLoop(f.g, f.keepalive)
|
||||
go f.earlyUpdateLoop(ctx, f.g, f.keepalive)
|
||||
|
||||
if err := f.acceptPacket(conn); err != nil {
|
||||
if err := f.readQueuePacket(ctx, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -140,7 +144,7 @@ func (f *InitiatedFlow) Reconnect() error {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
if err := f.acceptPacket(conn); err != nil {
|
||||
if err := f.readQueuePacket(ctx, conn); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
@ -155,25 +159,25 @@ func (f *InitiatedFlow) Reconnect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
||||
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
return f.Flow.Consume(p, g)
|
||||
return f.Flow.Consume(ctx, p, g)
|
||||
}
|
||||
|
||||
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
||||
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
return f.Flow.Produce(v)
|
||||
return f.Flow.Produce(ctx, v)
|
||||
}
|
||||
|
||||
func (f *Flow) IsAlive() bool {
|
||||
return f.isAlive
|
||||
}
|
||||
|
||||
func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error {
|
||||
func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error {
|
||||
if !f.isAlive {
|
||||
return shared.ErrDeadConnection
|
||||
}
|
||||
@ -182,32 +186,43 @@ func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error {
|
||||
|
||||
// Sequence is the congestion controllers opportunity to block
|
||||
log.Println("awaiting sequence")
|
||||
p := Packet{
|
||||
seq: f.congestion.Sequence(),
|
||||
data: pp,
|
||||
seq, err := f.congestion.Sequence(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Println("received sequence")
|
||||
|
||||
// Choose up to date ACK/NACK even after blocking
|
||||
p.ack = f.congestion.NextAck()
|
||||
p.nack = f.congestion.NextNack()
|
||||
p := Packet{
|
||||
seq: seq,
|
||||
data: pp,
|
||||
ack: f.congestion.NextAck(),
|
||||
nack: f.congestion.NextNack(),
|
||||
}
|
||||
|
||||
return f.sendPacket(p, g)
|
||||
}
|
||||
|
||||
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
||||
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||
if !f.isAlive {
|
||||
return nil, shared.ErrDeadConnection
|
||||
}
|
||||
|
||||
return f.produceInternal(v, true)
|
||||
return f.produceInternal(ctx, v, true)
|
||||
}
|
||||
|
||||
func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
|
||||
func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
|
||||
for once := true; mustReturn || once; once = false {
|
||||
log.Println(f.congestion)
|
||||
|
||||
b, err := proxy.StripMac(<-f.inboundDatagrams, v)
|
||||
var received []byte
|
||||
select {
|
||||
case received = <-f.inboundDatagrams:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
b, err := proxy.StripMac(received, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -232,8 +247,13 @@ func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Pack
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *Flow) handleDatagram(p []byte) {
|
||||
f.inboundDatagrams <- p
|
||||
func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
|
||||
select {
|
||||
case f.inboundDatagrams <- p:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
|
||||
@ -249,27 +269,34 @@ func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) {
|
||||
func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) {
|
||||
for f.isAlive {
|
||||
seq := f.congestion.AwaitEarlyUpdate(keepalive)
|
||||
seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive)
|
||||
if err != nil {
|
||||
fmt.Printf("terminating earlyupdateloop for `%v`\n", f)
|
||||
return
|
||||
}
|
||||
p := Packet{
|
||||
ack: f.congestion.NextAck(),
|
||||
nack: f.congestion.NextNack(),
|
||||
seq: seq,
|
||||
data: proxy.SimplePacket(nil),
|
||||
ack: f.congestion.NextAck(),
|
||||
nack: f.congestion.NextNack(),
|
||||
}
|
||||
|
||||
_ = f.sendPacket(p, g)
|
||||
err = f.sendPacket(p, g)
|
||||
if err != nil {
|
||||
fmt.Printf("error sending early update packet: `%v`\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Flow) acceptPacket(c PacketConn) error {
|
||||
func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error {
|
||||
// TODO: Replace 6000 with MTU+header size
|
||||
buf := make([]byte, 6000)
|
||||
n, _, err := c.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f.handleDatagram(buf[:n])
|
||||
return nil
|
||||
return f.queueDatagram(ctx, buf[:n])
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -24,7 +25,7 @@ func TestFlow_Consume(t *testing.T) {
|
||||
flowA.writer = testConn.SideB()
|
||||
flowA.isAlive = true
|
||||
|
||||
err := flowA.Consume(testPacket, testMac)
|
||||
err := flowA.Consume(context.Background(), testPacket, testMac)
|
||||
require.Nil(t, err)
|
||||
|
||||
buf := make([]byte, 100)
|
||||
@ -63,10 +64,10 @@ func TestFlow_Produce(t *testing.T) {
|
||||
flowA.isAlive = true
|
||||
|
||||
go func() {
|
||||
err := flowA.acceptPacket(testConn.SideB())
|
||||
err := flowA.readQueuePacket(context.Background(), testConn.SideB())
|
||||
assert.Nil(t, err)
|
||||
}()
|
||||
p, err := flowA.Produce(testMac)
|
||||
p, err := flowA.Produce(context.Background(), testMac)
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Len(t, p.Contents(), len(testContent))
|
||||
|
@ -1,6 +1,7 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"mpbl3p/proxy"
|
||||
"net"
|
||||
@ -25,7 +26,7 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
|
||||
}
|
||||
}
|
||||
|
||||
func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error {
|
||||
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error {
|
||||
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -56,10 +57,11 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
|
||||
|
||||
raddr := fromUdpAddress(*addr)
|
||||
if f, exists := receivedConnections[raddr]; exists {
|
||||
log.Println("existing flow")
|
||||
log.Println("handling...")
|
||||
f.handleDatagram(buf[:n])
|
||||
log.Println("handled")
|
||||
log.Println("existing flow. queuing...")
|
||||
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
|
||||
|
||||
}
|
||||
log.Println("queued")
|
||||
continue
|
||||
}
|
||||
|
||||
@ -74,15 +76,15 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
|
||||
|
||||
log.Printf("received new udp connection: %v\n", f)
|
||||
|
||||
go f.earlyUpdateLoop(g, 0)
|
||||
go f.earlyUpdateLoop(ctx, g, 0)
|
||||
|
||||
receivedConnections[raddr] = &f
|
||||
|
||||
p.AddConsumer(&f, g)
|
||||
p.AddProducer(&f, v)
|
||||
p.AddConsumer(ctx, &f, g)
|
||||
p.AddProducer(ctx, &f, v)
|
||||
|
||||
log.Println("handling...")
|
||||
f.handleDatagram(buf[:n])
|
||||
f.queueDatagram(ctx, buf[:n])
|
||||
log.Println("handled")
|
||||
}
|
||||
}()
|
||||
|
Loading…
Reference in New Issue
Block a user