new udp exchange
Some checks failed
continuous-integration/drone/push Build is failing

This commit is contained in:
Jake Hillion 2021-04-14 17:07:59 +01:00
parent c9c32349f2
commit c9596909f2
7 changed files with 215 additions and 60 deletions

View File

@ -10,6 +10,7 @@ import (
"mpbl3p/tcp" "mpbl3p/tcp"
"mpbl3p/udp" "mpbl3p/udp"
"mpbl3p/udp/congestion" "mpbl3p/udp/congestion"
"mpbl3p/udp/congestion/newreno"
"time" "time"
) )
@ -104,7 +105,7 @@ func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac
default: default:
fallthrough fallthrough
case "NewReno": case "NewReno":
c = func() udp.Congestion { return congestion.NewNewReno() } c = func() udp.Congestion { return newreno.NewNewReno() }
} }
if peer.RemoteHost != "" { if peer.RemoteHost != "" {

View File

@ -1,7 +1,9 @@
package proxy package proxy
import "context"
type Exchange interface { type Exchange interface {
Initial(ctx context.Context) (out []byte, err error)
Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error)
Complete() bool Complete() bool
Initial() (out []byte, err error)
Handle(in []byte) (out []byte, data []byte, err error)
} }

View File

@ -5,3 +5,4 @@ import "errors"
var ErrBadChecksum = errors.New("the packet had a bad checksum") var ErrBadChecksum = errors.New("the packet had a bad checksum")
var ErrDeadConnection = errors.New("the connection is dead") var ErrDeadConnection = errors.New("the connection is dead")
var ErrNotEnoughBytes = errors.New("not enough bytes") var ErrNotEnoughBytes = errors.New("not enough bytes")
var ErrBadExchange = errors.New("bad exchange")

View File

@ -0,0 +1,119 @@
package newreno
import (
"context"
"encoding/binary"
"math/rand"
"mpbl3p/shared"
"time"
)
func (c *NewReno) Initial(ctx context.Context) (out []byte, err error) {
c.alive = false
c.wasInitial = true
c.startSequenceLoop(ctx)
var s uint32
select {
case s = <-c.sequence:
case <-ctx.Done():
return nil, ctx.Err()
}
b := make([]byte, 12)
binary.LittleEndian.PutUint32(b[8:12], s)
c.inFlight = []flightInfo{{time.Now(), s}}
return b, nil
}
func (c *NewReno) Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error) {
if c.alive || c.stopSequence == nil {
// reset
c.alive = false
c.startSequenceLoop(ctx)
}
// receive
if len(in) != 12 {
return nil, nil, shared.ErrBadExchange
}
rcvAck := binary.LittleEndian.Uint32(in[0:4])
rcvNack := binary.LittleEndian.Uint32(in[4:8])
rcvSeq := binary.LittleEndian.Uint32(in[8:12])
// verify
var ack, seq uint32
if rcvNack != 0 {
return nil, nil, shared.ErrBadExchange
}
if c.wasInitial {
if rcvAck == c.inFlight[0].sequence {
ack = rcvSeq
c.alive = true
} else {
return nil, nil, shared.ErrBadExchange
}
} else { // if !c.wasInitial
if rcvAck == 0 {
// theirs is a syn packet
ack = rcvSeq
select {
case seq = <-c.sequence:
case <-ctx.Done():
return nil, nil, ctx.Err()
}
c.inFlight = []flightInfo{{time.Now(), seq}}
} else if len(c.inFlight) == 1 && rcvAck == c.inFlight[0].sequence {
ack = rcvSeq
c.alive = true
} else {
return nil, nil, shared.ErrBadExchange
}
}
// respond
b := make([]byte, 12)
binary.LittleEndian.PutUint32(b[0:4], ack)
binary.LittleEndian.PutUint32(b[8:12], seq)
return b, nil, nil
}
func (c *NewReno) Complete() bool {
return c.alive
}
func (c *NewReno) startSequenceLoop(ctx context.Context) {
if c.stopSequence != nil {
c.stopSequence()
}
var s uint32
for s == 0 {
s = rand.Uint32()
}
ctx, c.stopSequence = context.WithCancel(ctx)
go func() {
s := s
for {
if s == 0 {
s++
continue
}
select {
case c.sequence <- s:
case <-ctx.Done():
return
}
s++
}
}()
}

View File

@ -1,4 +1,4 @@
package congestion package newreno
import ( import (
"context" "context"
@ -14,7 +14,9 @@ const RttExponentialFactor = 0.1
const RttLossDelay = 1.5 const RttLossDelay = 1.5
type NewReno struct { type NewReno struct {
sequence chan uint32 sequence chan uint32
stopSequence context.CancelFunc
wasInitial, alive bool
inFlight []flightInfo inFlight []flightInfo
lastSent time.Time lastSent time.Time
@ -64,19 +66,6 @@ func NewNewReno() *NewReno {
slowStart: true, slowStart: true,
} }
go func() {
var s uint32
for {
if s == 0 {
s++
continue
}
c.sequence <- s
s++
}
}()
return &c return &c
} }

View File

@ -1,4 +1,4 @@
package congestion package newreno
import ( import (
"context" "context"

View File

@ -2,6 +2,7 @@ package udp
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
"mpbl3p/proxy" "mpbl3p/proxy"
@ -19,6 +20,7 @@ type PacketWriter interface {
type PacketConn interface { type PacketConn interface {
PacketWriter PacketWriter
SetReadDeadline(t time.Time) error
ReadFromUDP(b []byte) (int, *net.UDPAddr, error) ReadFromUDP(b []byte) (int, *net.UDPAddr, error)
} }
@ -42,7 +44,6 @@ type Flow struct {
raddr *net.UDPAddr raddr *net.UDPAddr
isAlive bool isAlive bool
startup bool
congestion Congestion congestion Congestion
v proxy.MacVerifier v proxy.MacVerifier
@ -105,52 +106,82 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
} }
f.writer = conn f.writer = conn
f.startup = true
// prod the connection once a second until we get an ack, then consider it alive // prod the connection once a second until we get an ack, then consider it alive
go func() { var exchanges []proxy.Exchange
seq, err := f.congestion.Sequence(ctx)
if err != nil {
return
}
for !f.isAlive { if e, ok := f.congestion.(proxy.Exchange); ok {
if ctx.Err() != nil { exchanges = append(exchanges, e)
return
}
p := Packet{
ack: 0,
nack: 0,
seq: seq,
data: proxy.SimplePacket(nil),
}
_ = f.sendPacket(p, f.g)
time.Sleep(1 * time.Second)
}
}()
go func() {
_, _ = f.produceInternal(ctx, f.v, false)
}()
go f.earlyUpdateLoop(ctx, f.g, f.keepalive)
if err := f.readQueuePacket(ctx, conn); err != nil {
return err
} }
f.isAlive = true var exchangeData [][]byte
f.startup = false
for _, e := range exchanges {
i, err := e.Initial(ctx)
if err != nil {
return err
}
if err = f.sendPacket(proxy.SimplePacket(i), f.g); err != nil {
return err
}
for once := true; once || !e.Complete(); once = false {
if err := func() error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
var recv []byte
if recv, err = f.readPacket(ctx, conn); err != nil {
return err
}
var resp, data []byte
if resp, data, err = e.Handle(ctx, recv); err != nil {
return err
}
if data != nil {
exchangeData = append(exchangeData, data)
}
if resp != nil {
if err = f.sendPacket(proxy.SimplePacket(resp), f.g); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
}
go func() { go func() {
for _, d := range exchangeData {
if err := f.queueDatagram(ctx, d); err != nil {
return
}
}
lockedAccept := func() { lockedAccept := func() {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
if err := f.readQueuePacket(ctx, conn); err != nil { var p []byte
if p, err = f.readPacket(ctx, conn); err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
return
}
log.Println(err) log.Println(err)
return
} }
if err := f.queueDatagram(ctx, p); err != nil {
return
}
} }
for f.isAlive { for f.isAlive {
@ -160,6 +191,7 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
log.Println("no longer alive") log.Println("no longer alive")
}() }()
f.isAlive = true
return nil return nil
} }
@ -260,7 +292,7 @@ func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
} }
} }
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { func (f *Flow) sendPacket(p proxy.Packet, g proxy.MacGenerator) error {
b := p.Marshal() b := p.Marshal()
b = proxy.AppendMac(b, g) b = proxy.AppendMac(b, g)
@ -294,13 +326,24 @@ func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepal
} }
} }
func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error { func (f *Flow) readPacket(ctx context.Context, c PacketConn) ([]byte, error) {
// TODO: Replace 6000 with MTU+header size
buf := make([]byte, 6000) buf := make([]byte, 6000)
n, _, err := c.ReadFromUDP(buf)
if err != nil { if d, ok := ctx.Deadline(); ok {
return err if err := c.SetReadDeadline(d); err != nil {
return nil, err
}
} }
return f.queueDatagram(ctx, buf[:n]) n, _, err := c.ReadFromUDP(buf)
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
if ctx.Err() != nil {
return nil, ctx.Err()
}
}
return nil, err
}
return buf[:n], nil
} }