This commit is contained in:
parent
2673f28e63
commit
fad829803a
@ -1,6 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"mpbl3p/crypto"
|
"mpbl3p/crypto"
|
||||||
@ -12,7 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
|
func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
|
||||||
p := proxy.NewProxy(0)
|
p := proxy.NewProxy(0)
|
||||||
|
|
||||||
var g func() proxy.MacGenerator
|
var g func() proxy.MacGenerator
|
||||||
@ -46,11 +47,11 @@ func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy
|
|||||||
for _, peer := range c.Peers {
|
for _, peer := range c.Peers {
|
||||||
switch peer.Method {
|
switch peer.Method {
|
||||||
case "TCP":
|
case "TCP":
|
||||||
if err := buildTcp(p, peer, g, v); err != nil {
|
if err := buildTcp(ctx, p, peer, g, v); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case "UDP":
|
case "UDP":
|
||||||
if err := buildUdp(p, peer, g, v); err != nil {
|
if err := buildUdp(ctx, p, peer, g, v); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -59,7 +60,7 @@ func (c Configuration) Build(source proxy.Source, sink proxy.Sink) (*proxy.Proxy
|
|||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
|
func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
|
||||||
var laddr func() string
|
var laddr func() string
|
||||||
if peer.LocalPort == 0 {
|
if peer.LocalPort == 0 {
|
||||||
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
|
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
|
||||||
@ -74,13 +75,13 @@ func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.AddConsumer(f, g())
|
p.AddConsumer(ctx, f, g())
|
||||||
p.AddProducer(f, v())
|
p.AddProducer(ctx, f, v())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := tcp.NewListener(p, laddr(), v, g)
|
err := tcp.NewListener(ctx, p, laddr(), v, g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -88,7 +89,7 @@ func buildTcp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildUdp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
|
func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error {
|
||||||
var laddr func() string
|
var laddr func() string
|
||||||
if peer.LocalPort == 0 {
|
if peer.LocalPort == 0 {
|
||||||
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
|
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
|
||||||
@ -120,13 +121,13 @@ func buildUdp(p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() p
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.AddConsumer(f, g())
|
p.AddConsumer(ctx, f, g())
|
||||||
p.AddProducer(f, v())
|
p.AddProducer(ctx, f, v())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := udp.NewListener(p, laddr(), v, g, c)
|
err := udp.NewListener(ctx, p, laddr(), v, g, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
6
main.go
6
main.go
@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
@ -132,7 +133,10 @@ FOREGROUND:
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
log.Println("building config...")
|
log.Println("building config...")
|
||||||
p, err := c.Build(t, t)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
p, err := c.Build(ctx, t, t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -1,22 +1,24 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Producer interface {
|
type Producer interface {
|
||||||
IsAlive() bool
|
IsAlive() bool
|
||||||
Produce(MacVerifier) (Packet, error)
|
Produce(context.Context, MacVerifier) (Packet, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Consumer interface {
|
type Consumer interface {
|
||||||
IsAlive() bool
|
IsAlive() bool
|
||||||
Consume(Packet, MacGenerator) error
|
Consume(context.Context, Packet, MacGenerator) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Reconnectable interface {
|
type Reconnectable interface {
|
||||||
Reconnect() error
|
Reconnect(context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Source interface {
|
type Source interface {
|
||||||
@ -65,7 +67,7 @@ func (p Proxy) Start() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
|
func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) {
|
||||||
go func() {
|
go func() {
|
||||||
_, reconnectable := c.(Reconnectable)
|
_, reconnectable := c.(Reconnectable)
|
||||||
|
|
||||||
@ -74,7 +76,11 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
|
|||||||
var err error
|
var err error
|
||||||
for once := true; err != nil || once; once = false {
|
for once := true; err != nil || once; once = false {
|
||||||
log.Printf("attempting to connect consumer `%v`\n", c)
|
log.Printf("attempting to connect consumer `%v`\n", c)
|
||||||
err = c.(Reconnectable).Reconnect()
|
err = c.(Reconnectable).Reconnect(ctx)
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("closed consumer `%v` (context)\n", c)
|
||||||
|
return
|
||||||
|
}
|
||||||
if !once {
|
if !once {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
@ -83,9 +89,19 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for c.IsAlive() {
|
for c.IsAlive() {
|
||||||
if err := c.Consume(<-p.proxyChan, g); err != nil {
|
select {
|
||||||
log.Println(err)
|
case <-ctx.Done():
|
||||||
break
|
log.Printf("closed consumer `%v` (context)\n", c)
|
||||||
|
return
|
||||||
|
case packet := <-p.proxyChan:
|
||||||
|
if err := c.Consume(ctx, packet, g); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("closed consumer `%v` (context)\n", c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println(err)
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -94,7 +110,7 @@ func (p Proxy) AddConsumer(c Consumer, g MacGenerator) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
|
func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) {
|
||||||
go func() {
|
go func() {
|
||||||
_, reconnectable := pr.(Reconnectable)
|
_, reconnectable := pr.(Reconnectable)
|
||||||
|
|
||||||
@ -103,20 +119,37 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) {
|
|||||||
var err error
|
var err error
|
||||||
for once := true; err != nil || once; once = false {
|
for once := true; err != nil || once; once = false {
|
||||||
log.Printf("attempting to connect producer `%v`\n", pr)
|
log.Printf("attempting to connect producer `%v`\n", pr)
|
||||||
err = pr.(Reconnectable).Reconnect()
|
err = pr.(Reconnectable).Reconnect(ctx)
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("closed producer `%v` (context)\n", pr)
|
||||||
|
return
|
||||||
|
}
|
||||||
if !once {
|
if !once {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
log.Printf("connected producer `%v`\n", pr)
|
log.Printf("connected producer `%v`\n", pr)
|
||||||
}
|
}
|
||||||
|
|
||||||
for pr.IsAlive() {
|
for pr.IsAlive() {
|
||||||
if packet, err := pr.Produce(v); err != nil {
|
if packet, err := pr.Produce(ctx, v); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("closed producer `%v` (context)\n", pr)
|
||||||
|
return
|
||||||
|
}
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
p.sinkChan <- packet
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Printf("closed producer `%v` (context)\n", pr)
|
||||||
|
return
|
||||||
|
case p.sinkChan <- packet:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
22
tcp/flow.go
22
tcp/flow.go
@ -2,6 +2,7 @@ package tcp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -124,21 +125,21 @@ func (f *InitiatedFlow) Reconnect() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
|
||||||
f.mu.RLock()
|
f.mu.RLock()
|
||||||
defer f.mu.RUnlock()
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
return f.Flow.Consume(p, g)
|
return f.Flow.Consume(ctx, p, g)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
f.mu.RLock()
|
f.mu.RLock()
|
||||||
defer f.mu.RUnlock()
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
return f.Flow.Produce(v)
|
return f.Flow.Produce(ctx, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
|
||||||
if !f.isAlive {
|
if !f.isAlive {
|
||||||
return shared.ErrDeadConnection
|
return shared.ErrDeadConnection
|
||||||
}
|
}
|
||||||
@ -157,11 +158,16 @@ func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
|||||||
binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
|
binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
|
||||||
copy(prefixedData[4:], data)
|
copy(prefixedData[4:], data)
|
||||||
|
|
||||||
f.toConsume <- prefixedData
|
select {
|
||||||
|
case f.toConsume <- prefixedData:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
if !f.isAlive {
|
if !f.isAlive {
|
||||||
return nil, shared.ErrDeadConnection
|
return nil, shared.ErrDeadConnection
|
||||||
}
|
}
|
||||||
@ -169,6 +175,8 @@ func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
|||||||
var data []byte
|
var data []byte
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
case data = <-f.produced:
|
case data = <-f.produced:
|
||||||
case err := <-f.produceErrors:
|
case err := <-f.produceErrors:
|
||||||
f.isAlive = false
|
f.isAlive = false
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package tcp
|
package tcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -19,7 +20,7 @@ func TestFlow_Consume(t *testing.T) {
|
|||||||
|
|
||||||
flowA := NewFlowConn(testConn.SideA())
|
flowA := NewFlowConn(testConn.SideA())
|
||||||
|
|
||||||
err := flowA.Consume(testPacket, testMac)
|
err := flowA.Consume(context.Background(), testPacket, testMac)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
buf := make([]byte, 100)
|
buf := make([]byte, 100)
|
||||||
@ -46,7 +47,7 @@ func TestFlow_Produce(t *testing.T) {
|
|||||||
_, err := testConn.SideB().Write(testMarshalled)
|
_, err := testConn.SideB().Write(testMarshalled)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
p, err := flowA.Produce(testMac)
|
p, err := flowA.Produce(context.Background(), testMac)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, len(testContent), len(p.Contents()))
|
assert.Equal(t, len(testContent), len(p.Contents()))
|
||||||
})
|
})
|
||||||
@ -59,7 +60,7 @@ func TestFlow_Produce(t *testing.T) {
|
|||||||
_, err := testConn.SideB().Write(testMarshalled)
|
_, err := testConn.SideB().Write(testMarshalled)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
p, err := flowA.Produce(testMac)
|
p, err := flowA.Produce(context.Background(), testMac)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, testContent, string(p.Contents()))
|
assert.Equal(t, testContent, string(p.Contents()))
|
||||||
})
|
})
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
package tcp
|
package tcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log"
|
"log"
|
||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error {
|
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error {
|
||||||
laddr, err := net.ResolveTCPAddr("tcp", local)
|
laddr, err := net.ResolveTCPAddr("tcp", local)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -32,8 +33,8 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
|
|||||||
|
|
||||||
log.Printf("received new tcp connection: %v\n", f)
|
log.Printf("received new tcp connection: %v\n", f)
|
||||||
|
|
||||||
p.AddConsumer(&f, g())
|
p.AddConsumer(ctx, &f, g())
|
||||||
p.AddProducer(&f, v())
|
p.AddProducer(ctx, &f, v())
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -1,13 +1,16 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type Congestion interface {
|
type Congestion interface {
|
||||||
Sequence() uint32
|
Sequence(ctx context.Context) (uint32, error)
|
||||||
NextAck() uint32
|
NextAck() uint32
|
||||||
NextNack() uint32
|
NextNack() uint32
|
||||||
|
|
||||||
ReceivedPacket(seq, nack, ack uint32)
|
ReceivedPacket(seq, nack, ack uint32)
|
||||||
|
|
||||||
AwaitEarlyUpdate(keepalive time.Duration) uint32
|
AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package congestion
|
package congestion
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"sort"
|
"sort"
|
||||||
@ -94,7 +95,7 @@ func (c *NewReno) ReceivedPacket(seq, nack, ack uint32) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *NewReno) Sequence() uint32 {
|
func (c *NewReno) Sequence(ctx context.Context) (uint32, error) {
|
||||||
for len(c.inFlight) >= int(c.windowSize) {
|
for len(c.inFlight) >= int(c.windowSize) {
|
||||||
<-c.windowNotifier
|
<-c.windowNotifier
|
||||||
}
|
}
|
||||||
@ -102,7 +103,13 @@ func (c *NewReno) Sequence() uint32 {
|
|||||||
c.inFlightMu.Lock()
|
c.inFlightMu.Lock()
|
||||||
defer c.inFlightMu.Unlock()
|
defer c.inFlightMu.Unlock()
|
||||||
|
|
||||||
s := <-c.sequence
|
var s uint32
|
||||||
|
select {
|
||||||
|
case s = <-c.sequence:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return 0, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
|
|
||||||
c.inFlight = append(c.inFlight, flightInfo{
|
c.inFlight = append(c.inFlight, flightInfo{
|
||||||
@ -111,7 +118,7 @@ func (c *NewReno) Sequence() uint32 {
|
|||||||
})
|
})
|
||||||
c.lastSent = t
|
c.lastSent = t
|
||||||
|
|
||||||
return s
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *NewReno) NextAck() uint32 {
|
func (c *NewReno) NextAck() uint32 {
|
||||||
@ -126,7 +133,7 @@ func (c *NewReno) NextNack() uint32 {
|
|||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 {
|
func (c *NewReno) AwaitEarlyUpdate(ctx context.Context, keepalive time.Duration) (uint32, error) {
|
||||||
for {
|
for {
|
||||||
rtt := time.Duration(math.Round(c.rttNanos))
|
rtt := time.Duration(math.Round(c.rttNanos))
|
||||||
time.Sleep(rtt / 2)
|
time.Sleep(rtt / 2)
|
||||||
@ -136,12 +143,12 @@ func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 {
|
|||||||
// CASE 1: waiting ACKs or NACKs and no message sent in the last half-RTT
|
// CASE 1: waiting ACKs or NACKs and no message sent in the last half-RTT
|
||||||
// this targets arrival in 0.5+0.5 ± 0.5 RTTs (1±0.5 RTTs)
|
// this targets arrival in 0.5+0.5 ± 0.5 RTTs (1±0.5 RTTs)
|
||||||
if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt/2)) {
|
if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt/2)) {
|
||||||
return 0 // no ack needed
|
return 0, nil // no ack needed
|
||||||
}
|
}
|
||||||
|
|
||||||
// CASE 2: No message sent within the keepalive time
|
// CASE 2: No message sent within the keepalive time
|
||||||
if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) {
|
if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) {
|
||||||
return c.Sequence() // require an ack
|
return c.Sequence(ctx) // require an ack
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -89,11 +89,10 @@ func (n *newRenoTest) RunSideA(ctx context.Context) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
if ctx.Err() != nil {
|
seq, err := n.sideA.AwaitEarlyUpdate(ctx, 500 * time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
seq := n.sideA.AwaitEarlyUpdate(500 * time.Millisecond)
|
|
||||||
if seq != 0 {
|
if seq != 0 {
|
||||||
// skip keepalive
|
// skip keepalive
|
||||||
// required to ensure AwaitEarlyUpdate terminates
|
// required to ensure AwaitEarlyUpdate terminates
|
||||||
@ -123,11 +122,10 @@ func (n *newRenoTest) RunSideB(ctx context.Context) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
if ctx.Err() != nil {
|
seq, err := n.sideB.AwaitEarlyUpdate(ctx, 500 * time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
seq := n.sideB.AwaitEarlyUpdate(500 * time.Millisecond)
|
|
||||||
if seq != 0 {
|
if seq != 0 {
|
||||||
// skip keepalive
|
// skip keepalive
|
||||||
// required to ensure AwaitEarlyUpdate terminates
|
// required to ensure AwaitEarlyUpdate terminates
|
||||||
@ -162,7 +160,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
for i := 0; i < numPackets; i++ {
|
for i := 0; i < numPackets; i++ {
|
||||||
// sleep to simulate preparing packet
|
// sleep to simulate preparing packet
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq := c.sideA.Sequence()
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
c.aOutbound <- congestionPacket{
|
c.aOutbound <- congestionPacket{
|
||||||
seq: seq,
|
seq: seq,
|
||||||
@ -200,7 +198,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
for i := 0; i < numPackets; i++ {
|
for i := 0; i < numPackets; i++ {
|
||||||
// sleep to simulate preparing packet
|
// sleep to simulate preparing packet
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq := c.sideA.Sequence()
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
if seq == 20 {
|
if seq == 20 {
|
||||||
// Simulate packet loss of sequence 20
|
// Simulate packet loss of sequence 20
|
||||||
@ -246,7 +244,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < numPackets; i++ {
|
for i := 0; i < numPackets; i++ {
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq := c.sideA.Sequence()
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
c.aOutbound <- congestionPacket{
|
c.aOutbound <- congestionPacket{
|
||||||
seq: seq,
|
seq: seq,
|
||||||
@ -261,7 +259,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < numPackets; i++ {
|
for i := 0; i < numPackets; i++ {
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq := c.sideB.Sequence()
|
seq, _ := c.sideB.Sequence(ctx)
|
||||||
|
|
||||||
c.bOutbound <- congestionPacket{
|
c.bOutbound <- congestionPacket{
|
||||||
seq: seq,
|
seq: seq,
|
||||||
@ -306,7 +304,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < numPackets; i++ {
|
for i := 0; i < numPackets; i++ {
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq := c.sideA.Sequence()
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
if seq == 9 {
|
if seq == 9 {
|
||||||
// Simulate packet loss of sequence 9
|
// Simulate packet loss of sequence 9
|
||||||
@ -326,7 +324,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < numPackets; i++ {
|
for i := 0; i < numPackets; i++ {
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq := c.sideB.Sequence()
|
seq, _ := c.sideB.Sequence(ctx)
|
||||||
|
|
||||||
if seq == 13 {
|
if seq == 13 {
|
||||||
// Simulate packet loss of sequence 13
|
// Simulate packet loss of sequence 13
|
||||||
|
@ -1,41 +1,26 @@
|
|||||||
package congestion
|
package congestion
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type None struct {
|
type None struct {}
|
||||||
sequence chan uint32
|
|
||||||
|
func NewNone() None {
|
||||||
|
return None{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNone() *None {
|
func (c None) String() string {
|
||||||
c := None{
|
|
||||||
sequence: make(chan uint32),
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
var s uint32
|
|
||||||
for {
|
|
||||||
if s == 0 {
|
|
||||||
s++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
c.sequence <- s
|
|
||||||
s++
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return &c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *None) String() string {
|
|
||||||
return fmt.Sprintf("{None}")
|
return fmt.Sprintf("{None}")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *None) ReceivedPacket(uint32, uint32, uint32) {}
|
func (c None) ReceivedPacket(uint32, uint32, uint32) {}
|
||||||
func (c *None) NextNack() uint32 { return 0 }
|
func (c None) NextNack() uint32 { return 0 }
|
||||||
func (c *None) NextAck() uint32 { return 0 }
|
func (c None) NextAck() uint32 { return 0 }
|
||||||
func (c *None) AwaitEarlyUpdate(time.Duration) uint32 { select {} }
|
func (c None) AwaitEarlyUpdate(ctx context.Context, _ time.Duration) (uint32, error) {
|
||||||
func (c *None) Sequence() uint32 { return <-c.sequence }
|
<-ctx.Done()
|
||||||
|
return 0, ctx.Err()
|
||||||
|
}
|
||||||
|
func (c None) Sequence(context.Context) (uint32, error) { return 0, nil }
|
||||||
|
87
udp/flow.go
87
udp/flow.go
@ -1,6 +1,7 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
@ -80,7 +81,7 @@ func newFlow(c Congestion, v proxy.MacVerifier) Flow {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *InitiatedFlow) Reconnect() error {
|
func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
|
||||||
f.mu.Lock()
|
f.mu.Lock()
|
||||||
defer f.mu.Unlock()
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
@ -108,7 +109,10 @@ func (f *InitiatedFlow) Reconnect() error {
|
|||||||
|
|
||||||
// prod the connection once a second until we get an ack, then consider it alive
|
// prod the connection once a second until we get an ack, then consider it alive
|
||||||
go func() {
|
go func() {
|
||||||
seq := f.congestion.Sequence()
|
seq, err := f.congestion.Sequence(ctx)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
for !f.isAlive {
|
for !f.isAlive {
|
||||||
p := Packet{
|
p := Packet{
|
||||||
@ -124,11 +128,11 @@ func (f *InitiatedFlow) Reconnect() error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_, _ = f.produceInternal(f.v, false)
|
_, _ = f.produceInternal(ctx, f.v, false)
|
||||||
}()
|
}()
|
||||||
go f.earlyUpdateLoop(f.g, f.keepalive)
|
go f.earlyUpdateLoop(ctx, f.g, f.keepalive)
|
||||||
|
|
||||||
if err := f.acceptPacket(conn); err != nil {
|
if err := f.readQueuePacket(ctx, conn); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -140,7 +144,7 @@ func (f *InitiatedFlow) Reconnect() error {
|
|||||||
f.mu.RLock()
|
f.mu.RLock()
|
||||||
defer f.mu.RUnlock()
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
if err := f.acceptPacket(conn); err != nil {
|
if err := f.readQueuePacket(ctx, conn); err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -155,25 +159,25 @@ func (f *InitiatedFlow) Reconnect() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
|
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error {
|
||||||
f.mu.RLock()
|
f.mu.RLock()
|
||||||
defer f.mu.RUnlock()
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
return f.Flow.Consume(p, g)
|
return f.Flow.Consume(ctx, p, g)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
f.mu.RLock()
|
f.mu.RLock()
|
||||||
defer f.mu.RUnlock()
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
return f.Flow.Produce(v)
|
return f.Flow.Produce(ctx, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) IsAlive() bool {
|
func (f *Flow) IsAlive() bool {
|
||||||
return f.isAlive
|
return f.isAlive
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error {
|
func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error {
|
||||||
if !f.isAlive {
|
if !f.isAlive {
|
||||||
return shared.ErrDeadConnection
|
return shared.ErrDeadConnection
|
||||||
}
|
}
|
||||||
@ -182,32 +186,43 @@ func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error {
|
|||||||
|
|
||||||
// Sequence is the congestion controllers opportunity to block
|
// Sequence is the congestion controllers opportunity to block
|
||||||
log.Println("awaiting sequence")
|
log.Println("awaiting sequence")
|
||||||
p := Packet{
|
seq, err := f.congestion.Sequence(ctx)
|
||||||
seq: f.congestion.Sequence(),
|
if err != nil {
|
||||||
data: pp,
|
return err
|
||||||
}
|
}
|
||||||
log.Println("received sequence")
|
log.Println("received sequence")
|
||||||
|
|
||||||
// Choose up to date ACK/NACK even after blocking
|
// Choose up to date ACK/NACK even after blocking
|
||||||
p.ack = f.congestion.NextAck()
|
p := Packet{
|
||||||
p.nack = f.congestion.NextNack()
|
seq: seq,
|
||||||
|
data: pp,
|
||||||
|
ack: f.congestion.NextAck(),
|
||||||
|
nack: f.congestion.NextNack(),
|
||||||
|
}
|
||||||
|
|
||||||
return f.sendPacket(p, g)
|
return f.sendPacket(p, g)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
|
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) {
|
||||||
if !f.isAlive {
|
if !f.isAlive {
|
||||||
return nil, shared.ErrDeadConnection
|
return nil, shared.ErrDeadConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
return f.produceInternal(v, true)
|
return f.produceInternal(ctx, v, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
|
func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) {
|
||||||
for once := true; mustReturn || once; once = false {
|
for once := true; mustReturn || once; once = false {
|
||||||
log.Println(f.congestion)
|
log.Println(f.congestion)
|
||||||
|
|
||||||
b, err := proxy.StripMac(<-f.inboundDatagrams, v)
|
var received []byte
|
||||||
|
select {
|
||||||
|
case received = <-f.inboundDatagrams:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := proxy.StripMac(received, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -232,8 +247,13 @@ func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Pack
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) handleDatagram(p []byte) {
|
func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
|
||||||
f.inboundDatagrams <- p
|
select {
|
||||||
|
case f.inboundDatagrams <- p:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
|
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
|
||||||
@ -249,27 +269,34 @@ func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) {
|
func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) {
|
||||||
for f.isAlive {
|
for f.isAlive {
|
||||||
seq := f.congestion.AwaitEarlyUpdate(keepalive)
|
seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("terminating earlyupdateloop for `%v`\n", f)
|
||||||
|
return
|
||||||
|
}
|
||||||
p := Packet{
|
p := Packet{
|
||||||
ack: f.congestion.NextAck(),
|
|
||||||
nack: f.congestion.NextNack(),
|
|
||||||
seq: seq,
|
seq: seq,
|
||||||
data: proxy.SimplePacket(nil),
|
data: proxy.SimplePacket(nil),
|
||||||
|
ack: f.congestion.NextAck(),
|
||||||
|
nack: f.congestion.NextNack(),
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = f.sendPacket(p, g)
|
err = f.sendPacket(p, g)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("error sending early update packet: `%v`\n", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) acceptPacket(c PacketConn) error {
|
func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error {
|
||||||
|
// TODO: Replace 6000 with MTU+header size
|
||||||
buf := make([]byte, 6000)
|
buf := make([]byte, 6000)
|
||||||
n, _, err := c.ReadFromUDP(buf)
|
n, _, err := c.ReadFromUDP(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
f.handleDatagram(buf[:n])
|
return f.queueDatagram(ctx, buf[:n])
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -24,7 +25,7 @@ func TestFlow_Consume(t *testing.T) {
|
|||||||
flowA.writer = testConn.SideB()
|
flowA.writer = testConn.SideB()
|
||||||
flowA.isAlive = true
|
flowA.isAlive = true
|
||||||
|
|
||||||
err := flowA.Consume(testPacket, testMac)
|
err := flowA.Consume(context.Background(), testPacket, testMac)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
buf := make([]byte, 100)
|
buf := make([]byte, 100)
|
||||||
@ -63,10 +64,10 @@ func TestFlow_Produce(t *testing.T) {
|
|||||||
flowA.isAlive = true
|
flowA.isAlive = true
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := flowA.acceptPacket(testConn.SideB())
|
err := flowA.readQueuePacket(context.Background(), testConn.SideB())
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
}()
|
}()
|
||||||
p, err := flowA.Produce(testMac)
|
p, err := flowA.Produce(context.Background(), testMac)
|
||||||
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Len(t, p.Contents(), len(testContent))
|
assert.Len(t, p.Contents(), len(testContent))
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log"
|
"log"
|
||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
"net"
|
"net"
|
||||||
@ -25,7 +26,7 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error {
|
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error {
|
||||||
laddr, err := net.ResolveUDPAddr("udp", local)
|
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -56,10 +57,11 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
|
|||||||
|
|
||||||
raddr := fromUdpAddress(*addr)
|
raddr := fromUdpAddress(*addr)
|
||||||
if f, exists := receivedConnections[raddr]; exists {
|
if f, exists := receivedConnections[raddr]; exists {
|
||||||
log.Println("existing flow")
|
log.Println("existing flow. queuing...")
|
||||||
log.Println("handling...")
|
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
|
||||||
f.handleDatagram(buf[:n])
|
|
||||||
log.Println("handled")
|
}
|
||||||
|
log.Println("queued")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,15 +76,15 @@ func NewListener(p *proxy.Proxy, local string, v func() proxy.MacVerifier, g fun
|
|||||||
|
|
||||||
log.Printf("received new udp connection: %v\n", f)
|
log.Printf("received new udp connection: %v\n", f)
|
||||||
|
|
||||||
go f.earlyUpdateLoop(g, 0)
|
go f.earlyUpdateLoop(ctx, g, 0)
|
||||||
|
|
||||||
receivedConnections[raddr] = &f
|
receivedConnections[raddr] = &f
|
||||||
|
|
||||||
p.AddConsumer(&f, g)
|
p.AddConsumer(ctx, &f, g)
|
||||||
p.AddProducer(&f, v)
|
p.AddProducer(ctx, &f, v)
|
||||||
|
|
||||||
log.Println("handling...")
|
log.Println("handling...")
|
||||||
f.handleDatagram(buf[:n])
|
f.queueDatagram(ctx, buf[:n])
|
||||||
log.Println("handled")
|
log.Println("handled")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
Loading…
Reference in New Issue
Block a user