exchanges #20
@ -11,6 +11,7 @@ import (
|
|||||||
"mpbl3p/tcp"
|
"mpbl3p/tcp"
|
||||||
"mpbl3p/udp"
|
"mpbl3p/udp"
|
||||||
"mpbl3p/udp/congestion"
|
"mpbl3p/udp/congestion"
|
||||||
|
"mpbl3p/udp/congestion/newreno"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -127,7 +128,7 @@ func buildUdp(
|
|||||||
default:
|
default:
|
||||||
fallthrough
|
fallthrough
|
||||||
case "NewReno":
|
case "NewReno":
|
||||||
c = func() udp.Congestion { return congestion.NewNewReno() }
|
c = func() udp.Congestion { return newreno.NewNewReno() }
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.RemoteHost != "" {
|
if peer.RemoteHost != "" {
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
package flags
|
package flags
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
goflags "github.com/jessevdk/go-flags"
|
goflags "github.com/jessevdk/go-flags"
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
var PrintedHelpErr = goflags.ErrHelp
|
var PrintedHelpErr = goflags.ErrHelp
|
||||||
var NotEnoughArgs = errors.New("not enough arguments")
|
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
Foreground bool `short:"f" long:"foreground" description:"Run in the foreground"`
|
Foreground bool `short:"f" long:"foreground" description:"Run in the foreground"`
|
||||||
|
4
flags/locs_darwin.go
Normal file
4
flags/locs_darwin.go
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
package flags
|
||||||
|
|
||||||
|
const DefaultConfigFile = "/usr/local/etc/netcombiner/%s"
|
||||||
|
const DefaultPidFile = "/var/run/netcombiner/%s.pid"
|
@ -1,6 +1,9 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import "net"
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type MockPerfectBiPacketConn struct {
|
type MockPerfectBiPacketConn struct {
|
||||||
directionA chan []byte
|
directionA chan []byte
|
||||||
@ -44,6 +47,10 @@ func (c MockPerfectPacketConn) LocalAddr() net.Addr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c MockPerfectPacketConn) SetReadDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
|
func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
|
||||||
p := <-c.inbound
|
p := <-c.inbound
|
||||||
return copy(b, p), &net.UDPAddr{
|
return copy(b, p), &net.UDPAddr{
|
||||||
|
9
proxy/exchange.go
Normal file
9
proxy/exchange.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type Exchange interface {
|
||||||
|
Initial(ctx context.Context) (out []byte, err error)
|
||||||
|
Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error)
|
||||||
|
Complete() bool
|
||||||
|
}
|
@ -75,12 +75,13 @@ func (p Proxy) AddConsumer(ctx context.Context, c Consumer) {
|
|||||||
if reconnectable {
|
if reconnectable {
|
||||||
var err error
|
var err error
|
||||||
for once := true; err != nil || once; once = false {
|
for once := true; err != nil || once; once = false {
|
||||||
log.Printf("attempting to connect consumer `%v`\n", c)
|
if err := ctx.Err(); err != nil {
|
||||||
err = c.(Reconnectable).Reconnect(ctx)
|
|
||||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
log.Printf("closed consumer `%v` (context)\n", c)
|
log.Printf("closed consumer `%v` (context)\n", c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("attempting to connect consumer `%v`\n", c)
|
||||||
|
err = c.(Reconnectable).Reconnect(ctx)
|
||||||
if !once {
|
if !once {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
@ -118,12 +119,13 @@ func (p Proxy) AddProducer(ctx context.Context, pr Producer) {
|
|||||||
if reconnectable {
|
if reconnectable {
|
||||||
var err error
|
var err error
|
||||||
for once := true; err != nil || once; once = false {
|
for once := true; err != nil || once; once = false {
|
||||||
log.Printf("attempting to connect producer `%v`\n", pr)
|
if err := ctx.Err(); err != nil {
|
||||||
err = pr.(Reconnectable).Reconnect(ctx)
|
|
||||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
log.Printf("closed producer `%v` (context)\n", pr)
|
log.Printf("closed producer `%v` (context)\n", pr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("attempting to connect producer `%v`\n", pr)
|
||||||
|
err = pr.(Reconnectable).Reconnect(ctx)
|
||||||
if !once {
|
if !once {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
|
@ -5,3 +5,4 @@ import "errors"
|
|||||||
var ErrBadChecksum = errors.New("the packet had a bad checksum")
|
var ErrBadChecksum = errors.New("the packet had a bad checksum")
|
||||||
var ErrDeadConnection = errors.New("the connection is dead")
|
var ErrDeadConnection = errors.New("the connection is dead")
|
||||||
var ErrNotEnoughBytes = errors.New("not enough bytes")
|
var ErrNotEnoughBytes = errors.New("not enough bytes")
|
||||||
|
var ErrBadExchange = errors.New("bad exchange")
|
||||||
|
119
udp/congestion/newreno/exchange.go
Normal file
119
udp/congestion/newreno/exchange.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package newreno
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"math/rand"
|
||||||
|
"mpbl3p/shared"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *NewReno) Initial(ctx context.Context) (out []byte, err error) {
|
||||||
|
c.alive = false
|
||||||
|
c.wasInitial = true
|
||||||
|
c.startSequenceLoop(ctx)
|
||||||
|
|
||||||
|
var s uint32
|
||||||
|
select {
|
||||||
|
case s = <-c.sequence:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
b := make([]byte, 12)
|
||||||
|
binary.LittleEndian.PutUint32(b[8:12], s)
|
||||||
|
|
||||||
|
c.inFlight = []flightInfo{{time.Now(), s}}
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error) {
|
||||||
|
if c.alive || c.stopSequence == nil {
|
||||||
|
// reset
|
||||||
|
c.alive = false
|
||||||
|
c.startSequenceLoop(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// receive
|
||||||
|
if len(in) != 12 {
|
||||||
|
return nil, nil, shared.ErrBadExchange
|
||||||
|
}
|
||||||
|
|
||||||
|
rcvAck := binary.LittleEndian.Uint32(in[0:4])
|
||||||
|
rcvNack := binary.LittleEndian.Uint32(in[4:8])
|
||||||
|
rcvSeq := binary.LittleEndian.Uint32(in[8:12])
|
||||||
|
|
||||||
|
// verify
|
||||||
|
if rcvNack != 0 {
|
||||||
|
return nil, nil, shared.ErrBadExchange
|
||||||
|
}
|
||||||
|
|
||||||
|
var seq uint32
|
||||||
|
|
||||||
|
if c.wasInitial {
|
||||||
|
if rcvAck == c.inFlight[0].sequence {
|
||||||
|
c.ack, c.lastAck = rcvSeq, rcvSeq
|
||||||
|
c.alive, c.inFlight = true, nil
|
||||||
|
} else {
|
||||||
|
return nil, nil, shared.ErrBadExchange
|
||||||
|
}
|
||||||
|
} else { // if !c.wasInitial
|
||||||
|
if rcvAck == 0 {
|
||||||
|
// theirs is a syn packet
|
||||||
|
c.ack, c.lastAck = rcvSeq, rcvSeq
|
||||||
|
|
||||||
|
select {
|
||||||
|
case seq = <-c.sequence:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.inFlight = []flightInfo{{time.Now(), seq}}
|
||||||
|
} else if len(c.inFlight) == 1 && rcvAck == c.inFlight[0].sequence {
|
||||||
|
c.alive, c.inFlight = true, nil
|
||||||
|
} else {
|
||||||
|
return nil, nil, shared.ErrBadExchange
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// respond
|
||||||
|
b := make([]byte, 12)
|
||||||
|
binary.LittleEndian.PutUint32(b[0:4], c.ack)
|
||||||
|
binary.LittleEndian.PutUint32(b[8:12], seq)
|
||||||
|
|
||||||
|
return b, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) Complete() bool {
|
||||||
|
return c.alive
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NewReno) startSequenceLoop(ctx context.Context) {
|
||||||
|
if c.stopSequence != nil {
|
||||||
|
c.stopSequence()
|
||||||
|
}
|
||||||
|
|
||||||
|
var s uint32
|
||||||
|
for s == 0 {
|
||||||
|
s = rand.Uint32()
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, c.stopSequence = context.WithCancel(ctx)
|
||||||
|
go func() {
|
||||||
|
s := s
|
||||||
|
for {
|
||||||
|
if s == 0 {
|
||||||
|
s++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.sequence <- s:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s++
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
57
udp/congestion/newreno/exchange_test.go
Normal file
57
udp/congestion/newreno/exchange_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package newreno
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewReno_InitialHandle(t *testing.T) {
|
||||||
|
t.Run("InitialAckNackAreZero", func(t *testing.T) {
|
||||||
|
// ASSIGN
|
||||||
|
a := NewNewReno()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
initial, err := a.Initial(ctx)
|
||||||
|
|
||||||
|
ack := binary.LittleEndian.Uint32(initial[0:4])
|
||||||
|
nack := binary.LittleEndian.Uint32(initial[4:8])
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Zero(t, ack)
|
||||||
|
assert.Zero(t, nack)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("InitialHandledWithAck", func(t *testing.T) {
|
||||||
|
// ASSIGN
|
||||||
|
a := NewNewReno()
|
||||||
|
b := NewNewReno()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
initial, err := a.Initial(ctx)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
initialSeq := binary.LittleEndian.Uint32(initial[8:12])
|
||||||
|
|
||||||
|
response, data, err := b.Handle(ctx, initial)
|
||||||
|
|
||||||
|
ack := binary.LittleEndian.Uint32(response[0:4])
|
||||||
|
nack := binary.LittleEndian.Uint32(response[4:8])
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, initialSeq, ack)
|
||||||
|
assert.Zero(t, nack)
|
||||||
|
|
||||||
|
assert.Nil(t, data)
|
||||||
|
})
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package congestion
|
package newreno
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -11,10 +11,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const RttExponentialFactor = 0.1
|
const RttExponentialFactor = 0.1
|
||||||
const RttLossDelay = 1.5
|
const RttLossDelay = 0.5
|
||||||
|
|
||||||
type NewReno struct {
|
type NewReno struct {
|
||||||
sequence chan uint32
|
sequence chan uint32
|
||||||
|
stopSequence context.CancelFunc
|
||||||
|
wasInitial, alive bool
|
||||||
|
|
||||||
inFlight []flightInfo
|
inFlight []flightInfo
|
||||||
lastSent time.Time
|
lastSent time.Time
|
||||||
@ -64,19 +66,6 @@ func NewNewReno() *NewReno {
|
|||||||
slowStart: true,
|
slowStart: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
|
||||||
var s uint32
|
|
||||||
for {
|
|
||||||
if s == 0 {
|
|
||||||
s++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
c.sequence <- s
|
|
||||||
s++
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
@ -1,4 +1,4 @@
|
|||||||
package congestion
|
package newreno
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -21,8 +21,8 @@ type newRenoTest struct {
|
|||||||
halfRtt time.Duration
|
halfRtt time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNewRenoTest(rtt time.Duration) *newRenoTest {
|
func newNewRenoTest(ctx context.Context, rtt time.Duration) *newRenoTest {
|
||||||
return &newRenoTest{
|
nr := &newRenoTest{
|
||||||
sideA: NewNewReno(),
|
sideA: NewNewReno(),
|
||||||
sideB: NewNewReno(),
|
sideB: NewNewReno(),
|
||||||
|
|
||||||
@ -34,6 +34,15 @@ func newNewRenoTest(rtt time.Duration) *newRenoTest {
|
|||||||
|
|
||||||
halfRtt: rtt / 2,
|
halfRtt: rtt / 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p, _ := nr.sideA.Initial(ctx)
|
||||||
|
p, _, _ = nr.sideB.Handle(ctx, p)
|
||||||
|
p, _, _ = nr.sideA.Handle(ctx, p)
|
||||||
|
|
||||||
|
nr.sideB.ReceivedPacket(0, nr.sideA.NextAck(), nr.sideA.NextNack())
|
||||||
|
nr.sideA.ReceivedPacket(0, nr.sideB.NextAck(), nr.sideB.NextNack())
|
||||||
|
|
||||||
|
return nr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *newRenoTest) Start(ctx context.Context) {
|
func (n *newRenoTest) Start(ctx context.Context) {
|
||||||
@ -151,11 +160,14 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
c := newNewRenoTest(rtt)
|
c := newNewRenoTest(ctx, rtt)
|
||||||
c.Start(ctx)
|
c.Start(ctx)
|
||||||
c.RunSideA(ctx)
|
c.RunSideA(ctx)
|
||||||
c.RunSideB(ctx)
|
c.RunSideB(ctx)
|
||||||
|
|
||||||
|
sideAinitialAck := c.sideA.ack
|
||||||
|
sideBinitialAck := c.sideB.ack
|
||||||
|
|
||||||
// ACT
|
// ACT
|
||||||
for i := 0; i < numPackets; i++ {
|
for i := 0; i < numPackets; i++ {
|
||||||
// sleep to simulate preparing packet
|
// sleep to simulate preparing packet
|
||||||
@ -175,10 +187,10 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
// ASSERT
|
// ASSERT
|
||||||
|
|
||||||
assert.Equal(t, uint32(0), c.sideA.nack)
|
assert.Equal(t, uint32(0), c.sideA.nack)
|
||||||
assert.Equal(t, uint32(0), c.sideA.ack)
|
assert.Equal(t, sideAinitialAck, c.sideA.ack)
|
||||||
|
|
||||||
assert.Equal(t, uint32(0), c.sideB.nack)
|
assert.Equal(t, uint32(0), c.sideB.nack)
|
||||||
assert.Equal(t, uint32(numPackets), c.sideB.ack)
|
assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("SequenceLoss", func(t *testing.T) {
|
t.Run("SequenceLoss", func(t *testing.T) {
|
||||||
@ -189,18 +201,21 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
c := newNewRenoTest(rtt)
|
c := newNewRenoTest(ctx, rtt)
|
||||||
c.Start(ctx)
|
c.Start(ctx)
|
||||||
c.RunSideA(ctx)
|
c.RunSideA(ctx)
|
||||||
c.RunSideB(ctx)
|
c.RunSideB(ctx)
|
||||||
|
|
||||||
|
sideAinitialAck := c.sideA.ack
|
||||||
|
sideBinitialAck := c.sideB.ack
|
||||||
|
|
||||||
// ACT
|
// ACT
|
||||||
for i := 0; i < numPackets; i++ {
|
for i := 1; i <= numPackets; i++ {
|
||||||
// sleep to simulate preparing packet
|
// sleep to simulate preparing packet
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq, _ := c.sideA.Sequence(ctx)
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
if seq == 20 {
|
if i == 20 {
|
||||||
// Simulate packet loss of sequence 20
|
// Simulate packet loss of sequence 20
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -217,10 +232,10 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
// ASSERT
|
// ASSERT
|
||||||
|
|
||||||
assert.Equal(t, uint32(0), c.sideA.nack)
|
assert.Equal(t, uint32(0), c.sideA.nack)
|
||||||
assert.Equal(t, uint32(0), c.sideA.ack)
|
assert.Equal(t, sideAinitialAck, c.sideA.ack)
|
||||||
|
|
||||||
assert.Equal(t, uint32(20), c.sideB.nack)
|
assert.Equal(t, sideBinitialAck+uint32(20), c.sideB.nack)
|
||||||
assert.Equal(t, uint32(numPackets), c.sideB.ack)
|
assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -233,16 +248,19 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
c := newNewRenoTest(rtt)
|
c := newNewRenoTest(ctx, rtt)
|
||||||
c.Start(ctx)
|
c.Start(ctx)
|
||||||
c.RunSideA(ctx)
|
c.RunSideA(ctx)
|
||||||
c.RunSideB(ctx)
|
c.RunSideB(ctx)
|
||||||
|
|
||||||
|
sideAinitialAck := c.sideA.ack
|
||||||
|
sideBinitialAck := c.sideB.ack
|
||||||
|
|
||||||
// ACT
|
// ACT
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < numPackets; i++ {
|
for i := 1; i <= numPackets; i++ {
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq, _ := c.sideA.Sequence(ctx)
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
@ -257,7 +275,7 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < numPackets; i++ {
|
for i := 1; i <= numPackets; i++ {
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq, _ := c.sideB.Sequence(ctx)
|
seq, _ := c.sideB.Sequence(ctx)
|
||||||
|
|
||||||
@ -279,10 +297,10 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
// ASSERT
|
// ASSERT
|
||||||
|
|
||||||
assert.Equal(t, uint32(0), c.sideA.nack)
|
assert.Equal(t, uint32(0), c.sideA.nack)
|
||||||
assert.Equal(t, uint32(numPackets), c.sideA.ack)
|
assert.Equal(t, sideAinitialAck+uint32(numPackets), c.sideA.ack)
|
||||||
|
|
||||||
assert.Equal(t, uint32(0), c.sideB.nack)
|
assert.Equal(t, uint32(0), c.sideB.nack)
|
||||||
assert.Equal(t, uint32(numPackets), c.sideB.ack)
|
assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("SequenceLoss", func(t *testing.T) {
|
t.Run("SequenceLoss", func(t *testing.T) {
|
||||||
@ -293,20 +311,23 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
c := newNewRenoTest(rtt)
|
c := newNewRenoTest(ctx, rtt)
|
||||||
c.Start(ctx)
|
c.Start(ctx)
|
||||||
c.RunSideA(ctx)
|
c.RunSideA(ctx)
|
||||||
c.RunSideB(ctx)
|
c.RunSideB(ctx)
|
||||||
|
|
||||||
|
sideAinitialAck := c.sideA.ack
|
||||||
|
sideBinitialAck := c.sideB.ack
|
||||||
|
|
||||||
// ACT
|
// ACT
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < numPackets; i++ {
|
for i := 1; i <= numPackets; i++ {
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq, _ := c.sideA.Sequence(ctx)
|
seq, _ := c.sideA.Sequence(ctx)
|
||||||
|
|
||||||
if seq == 9 {
|
if i == 9 {
|
||||||
// Simulate packet loss of sequence 9
|
// Simulate packet loss of sequence 9
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -322,11 +343,11 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < numPackets; i++ {
|
for i := 1; i <= numPackets; i++ {
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
seq, _ := c.sideB.Sequence(ctx)
|
seq, _ := c.sideB.Sequence(ctx)
|
||||||
|
|
||||||
if seq == 13 {
|
if i == 13 {
|
||||||
// Simulate packet loss of sequence 13
|
// Simulate packet loss of sequence 13
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -348,11 +369,11 @@ func TestNewReno_Congestion(t *testing.T) {
|
|||||||
|
|
||||||
// ASSERT
|
// ASSERT
|
||||||
|
|
||||||
assert.Equal(t, uint32(13), c.sideA.nack)
|
assert.Equal(t, sideAinitialAck+uint32(13), c.sideA.nack)
|
||||||
assert.Equal(t, uint32(numPackets), c.sideA.ack)
|
assert.Equal(t, sideAinitialAck+uint32(numPackets), c.sideA.ack)
|
||||||
|
|
||||||
assert.Equal(t, uint32(9), c.sideB.nack)
|
assert.Equal(t, sideBinitialAck+uint32(9), c.sideB.nack)
|
||||||
assert.Equal(t, uint32(numPackets), c.sideB.ack)
|
assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
156
udp/flow.go
156
udp/flow.go
@ -7,7 +7,6 @@ import (
|
|||||||
"mpbl3p/proxy"
|
"mpbl3p/proxy"
|
||||||
"mpbl3p/shared"
|
"mpbl3p/shared"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,29 +18,15 @@ type PacketWriter interface {
|
|||||||
|
|
||||||
type PacketConn interface {
|
type PacketConn interface {
|
||||||
PacketWriter
|
PacketWriter
|
||||||
|
SetReadDeadline(t time.Time) error
|
||||||
ReadFromUDP(b []byte) (int, *net.UDPAddr, error)
|
ReadFromUDP(b []byte) (int, *net.UDPAddr, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type InitiatedFlow struct {
|
|
||||||
Local func() string
|
|
||||||
Remote string
|
|
||||||
|
|
||||||
keepalive time.Duration
|
|
||||||
|
|
||||||
mu sync.RWMutex
|
|
||||||
Flow
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *InitiatedFlow) String() string {
|
|
||||||
return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Flow struct {
|
type Flow struct {
|
||||||
writer PacketWriter
|
writer PacketWriter
|
||||||
raddr *net.UDPAddr
|
raddr *net.UDPAddr
|
||||||
|
|
||||||
isAlive bool
|
isAlive bool
|
||||||
startup bool
|
|
||||||
congestion Congestion
|
congestion Congestion
|
||||||
|
|
||||||
verifiers []proxy.MacVerifier
|
verifiers []proxy.MacVerifier
|
||||||
@ -54,24 +39,6 @@ func (f Flow) String() string {
|
|||||||
return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr())
|
return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitiateFlow(
|
|
||||||
local func() string,
|
|
||||||
remote string,
|
|
||||||
vs []proxy.MacVerifier,
|
|
||||||
gs []proxy.MacGenerator,
|
|
||||||
c Congestion,
|
|
||||||
keepalive time.Duration,
|
|
||||||
) (*InitiatedFlow, error) {
|
|
||||||
f := InitiatedFlow{
|
|
||||||
Local: local,
|
|
||||||
Remote: remote,
|
|
||||||
Flow: newFlow(c, vs, gs),
|
|
||||||
keepalive: keepalive,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &f, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow {
|
func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow {
|
||||||
return Flow{
|
return Flow{
|
||||||
inboundDatagrams: make(chan []byte),
|
inboundDatagrams: make(chan []byte),
|
||||||
@ -81,102 +48,6 @@ func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
|
|
||||||
if f.isAlive {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
localAddr, err := net.ResolveUDPAddr("udp", f.Local())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
f.writer = conn
|
|
||||||
f.startup = true
|
|
||||||
|
|
||||||
// prod the connection once a second until we get an ack, then consider it alive
|
|
||||||
go func() {
|
|
||||||
seq, err := f.congestion.Sequence(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for !f.isAlive {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
p := Packet{
|
|
||||||
ack: 0,
|
|
||||||
nack: 0,
|
|
||||||
seq: seq,
|
|
||||||
data: proxy.SimplePacket(nil),
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = f.sendPacket(p)
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, _ = f.produceInternal(ctx, false)
|
|
||||||
}()
|
|
||||||
go f.earlyUpdateLoop(ctx, f.keepalive)
|
|
||||||
|
|
||||||
if err := f.readQueuePacket(ctx, conn); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
f.isAlive = true
|
|
||||||
f.startup = false
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
lockedAccept := func() {
|
|
||||||
f.mu.RLock()
|
|
||||||
defer f.mu.RUnlock()
|
|
||||||
|
|
||||||
if err := f.readQueuePacket(ctx, conn); err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for f.isAlive {
|
|
||||||
log.Println("alive and listening for packets")
|
|
||||||
lockedAccept()
|
|
||||||
}
|
|
||||||
log.Println("no longer alive")
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet) error {
|
|
||||||
f.mu.RLock()
|
|
||||||
defer f.mu.RUnlock()
|
|
||||||
|
|
||||||
return f.Flow.Consume(ctx, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *InitiatedFlow) Produce(ctx context.Context) (proxy.Packet, error) {
|
|
||||||
f.mu.RLock()
|
|
||||||
defer f.mu.RUnlock()
|
|
||||||
|
|
||||||
return f.Flow.Produce(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Flow) IsAlive() bool {
|
func (f *Flow) IsAlive() bool {
|
||||||
return f.isAlive
|
return f.isAlive
|
||||||
}
|
}
|
||||||
@ -265,7 +136,7 @@ func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) sendPacket(p Packet) error {
|
func (f *Flow) sendPacket(p proxy.Packet) error {
|
||||||
b := p.Marshal()
|
b := p.Marshal()
|
||||||
|
|
||||||
for _, g := range f.generators {
|
for _, g := range f.generators {
|
||||||
@ -302,13 +173,24 @@ func (f *Flow) earlyUpdateLoop(ctx context.Context, keepalive time.Duration) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error {
|
func (f *Flow) readPacket(ctx context.Context, c PacketConn) ([]byte, error) {
|
||||||
// TODO: Replace 6000 with MTU+header size
|
|
||||||
buf := make([]byte, 6000)
|
buf := make([]byte, 6000)
|
||||||
n, _, err := c.ReadFromUDP(buf)
|
|
||||||
if err != nil {
|
if d, ok := ctx.Deadline(); ok {
|
||||||
return err
|
if err := c.SetReadDeadline(d); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return f.queueDatagram(ctx, buf[:n])
|
n, _, err := c.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf[:n], nil
|
||||||
}
|
}
|
||||||
|
@ -107,7 +107,9 @@ func TestFlow_Produce(t *testing.T) {
|
|||||||
flowA.isAlive = true
|
flowA.isAlive = true
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := flowA.readQueuePacket(context.Background(), testConn.SideB())
|
p, err := flowA.readPacket(context.Background(), testConn.SideB())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = flowA.queueDatagram(context.Background(), p)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
}()
|
}()
|
||||||
p, err := flowA.Produce(context.Background())
|
p, err := flowA.Produce(context.Background())
|
||||||
@ -143,7 +145,9 @@ func TestFlow_Produce(t *testing.T) {
|
|||||||
flowA.isAlive = true
|
flowA.isAlive = true
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := flowA.readQueuePacket(context.Background(), testConn.SideB())
|
p, err := flowA.readPacket(context.Background(), testConn.SideB())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = flowA.queueDatagram(context.Background(), p)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
}()
|
}()
|
||||||
p, err := flowA.Produce(context.Background())
|
p, err := flowA.Produce(context.Background())
|
||||||
|
142
udp/inbound_flow.go
Normal file
142
udp/inbound_flow.go
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InboundFlow struct {
|
||||||
|
inboundDatagrams chan []byte
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
Flow
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInboundFlow(f Flow) (*InboundFlow, error) {
|
||||||
|
fi := InboundFlow{
|
||||||
|
inboundDatagrams: make(chan []byte),
|
||||||
|
Flow: f,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &fi, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InboundFlow) queueDatagram(ctx context.Context, p []byte) error {
|
||||||
|
select {
|
||||||
|
case f.inboundDatagrams <- p:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InboundFlow) processPackets(ctx context.Context) {
|
||||||
|
for {
|
||||||
|
f.mu.Lock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for once := true; err != nil || once; once = false {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = f.handleExchanges(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.mu.Unlock()
|
||||||
|
|
||||||
|
var p []byte
|
||||||
|
select {
|
||||||
|
case p = <-f.inboundDatagrams:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Check if p means redo exchanges
|
||||||
|
if false {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case f.Flow.inboundDatagrams <- p:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InboundFlow) handleExchanges(ctx context.Context) error {
|
||||||
|
var exchanges []proxy.Exchange
|
||||||
|
|
||||||
|
if e, ok := f.congestion.(proxy.Exchange); ok {
|
||||||
|
exchanges = append(exchanges, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
var exchangeData [][]byte
|
||||||
|
|
||||||
|
for _, e := range exchanges {
|
||||||
|
for once := true; !e.Complete() || once; once = false {
|
||||||
|
if err := func() (err error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var recv []byte
|
||||||
|
select {
|
||||||
|
case recv = <-f.inboundDatagrams:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range f.verifiers {
|
||||||
|
v := f.verifiers[len(f.verifiers)-i-1]
|
||||||
|
|
||||||
|
recv, err = proxy.StripMac(recv, v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp, data []byte
|
||||||
|
if resp, data, err = e.Handle(ctx, recv); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if data != nil {
|
||||||
|
exchangeData = append(exchangeData, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != nil {
|
||||||
|
if err = f.sendPacket(proxy.SimplePacket(resp)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InboundFlow) Consume(ctx context.Context, p proxy.Packet) error {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
return f.Flow.Consume(ctx, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *InboundFlow) Produce(ctx context.Context) (proxy.Packet, error) {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
return f.Flow.Produce(ctx)
|
||||||
|
}
|
@ -16,9 +16,7 @@ type ComparableUdpAddress struct {
|
|||||||
|
|
||||||
func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
|
func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
|
||||||
var ip [16]byte
|
var ip [16]byte
|
||||||
for i, b := range []byte(address.IP) {
|
copy(ip[:], address.IP)
|
||||||
ip[i] = b
|
|
||||||
}
|
|
||||||
|
|
||||||
return ComparableUdpAddress{
|
return ComparableUdpAddress{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
@ -47,12 +45,7 @@ func NewListener(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pconn.SetWriteBuffer(0)
|
receivedConnections := make(map[ComparableUdpAddress]*InboundFlow)
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
receivedConnections := make(map[ComparableUdpAddress]*Flow)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
@ -70,10 +63,11 @@ func NewListener(
|
|||||||
}
|
}
|
||||||
|
|
||||||
raddr := fromUdpAddress(*addr)
|
raddr := fromUdpAddress(*addr)
|
||||||
if f, exists := receivedConnections[raddr]; exists {
|
if fi, exists := receivedConnections[raddr]; exists {
|
||||||
log.Println("existing flow. queuing...")
|
log.Println("existing flow. queuing...")
|
||||||
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
|
if err := fi.queueDatagram(ctx, buf[:n]); err != nil {
|
||||||
|
log.Println("error")
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
log.Println("queued")
|
log.Println("queued")
|
||||||
continue
|
continue
|
||||||
@ -94,21 +88,28 @@ func NewListener(
|
|||||||
f.raddr = addr
|
f.raddr = addr
|
||||||
f.isAlive = true
|
f.isAlive = true
|
||||||
|
|
||||||
|
fi, err := newInboundFlow(f)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
log.Printf("received new udp connection: %v\n", f)
|
log.Printf("received new udp connection: %v\n", f)
|
||||||
|
|
||||||
go f.earlyUpdateLoop(ctx, 0)
|
go fi.processPackets(ctx)
|
||||||
|
go fi.earlyUpdateLoop(ctx, 0)
|
||||||
|
|
||||||
receivedConnections[raddr] = &f
|
receivedConnections[raddr] = fi
|
||||||
|
|
||||||
if enableConsumers {
|
if enableConsumers {
|
||||||
p.AddConsumer(ctx, &f)
|
p.AddConsumer(ctx, fi)
|
||||||
}
|
}
|
||||||
if enableProducers {
|
if enableProducers {
|
||||||
p.AddProducer(ctx, &f)
|
p.AddProducer(ctx, fi)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("handling...")
|
log.Println("handling...")
|
||||||
if err := f.queueDatagram(ctx, buf[:n]); err != nil {
|
if err := fi.queueDatagram(ctx, buf[:n]); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Println("handled")
|
log.Println("handled")
|
||||||
|
181
udp/outbound_flow.go
Normal file
181
udp/outbound_flow.go
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"mpbl3p/proxy"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OutboundFlow struct {
|
||||||
|
Local func() string
|
||||||
|
Remote string
|
||||||
|
|
||||||
|
g proxy.MacGenerator
|
||||||
|
keepalive time.Duration
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
Flow
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitiateFlow(
|
||||||
|
local func() string,
|
||||||
|
remote string,
|
||||||
|
vs []proxy.MacVerifier,
|
||||||
|
gs []proxy.MacGenerator,
|
||||||
|
c Congestion,
|
||||||
|
keepalive time.Duration,
|
||||||
|
) (*OutboundFlow, error) {
|
||||||
|
f := OutboundFlow{
|
||||||
|
Local: local,
|
||||||
|
Remote: remote,
|
||||||
|
Flow: newFlow(c, vs, gs),
|
||||||
|
keepalive: keepalive,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *OutboundFlow) String() string {
|
||||||
|
return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *OutboundFlow) Reconnect(ctx context.Context) error {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
|
if f.isAlive {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr, err := net.ResolveUDPAddr("udp", f.Local())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f.writer = conn
|
||||||
|
|
||||||
|
// prod the connection once a second until we get an ack, then consider it alive
|
||||||
|
var exchanges []proxy.Exchange
|
||||||
|
|
||||||
|
if e, ok := f.congestion.(proxy.Exchange); ok {
|
||||||
|
exchanges = append(exchanges, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
var exchangeData [][]byte
|
||||||
|
|
||||||
|
for _, e := range exchanges {
|
||||||
|
i, err := e.Initial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = f.sendPacket(proxy.SimplePacket(i)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for once := true; !e.Complete() || once; once = false {
|
||||||
|
if err := func() error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var recv []byte
|
||||||
|
if recv, err = f.readPacket(ctx, conn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range f.verifiers {
|
||||||
|
v := f.verifiers[len(f.verifiers)-i-1]
|
||||||
|
|
||||||
|
recv, err = proxy.StripMac(recv, v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp, data []byte
|
||||||
|
if resp, data, err = e.Handle(ctx, recv); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if data != nil {
|
||||||
|
exchangeData = append(exchangeData, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != nil {
|
||||||
|
if err = f.sendPacket(proxy.SimplePacket(resp)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for _, d := range exchangeData {
|
||||||
|
if err := f.queueDatagram(ctx, d); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lockedAccept := func() {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
var p []byte
|
||||||
|
if p, err = f.readPacket(ctx, conn); err != nil {
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := f.queueDatagram(ctx, p); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
for f.isAlive {
|
||||||
|
log.Println("alive and listening for packets")
|
||||||
|
lockedAccept()
|
||||||
|
}
|
||||||
|
log.Println("no longer alive")
|
||||||
|
}()
|
||||||
|
|
||||||
|
f.isAlive = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *OutboundFlow) Consume(ctx context.Context, p proxy.Packet) error {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
return f.Flow.Consume(ctx, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *OutboundFlow) Produce(ctx context.Context) (proxy.Packet, error) {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
|
||||||
|
return f.Flow.Produce(ctx)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user