diff --git a/config/builder.go b/config/builder.go index b83995c..ba84eb7 100644 --- a/config/builder.go +++ b/config/builder.go @@ -10,6 +10,7 @@ import ( "mpbl3p/tcp" "mpbl3p/udp" "mpbl3p/udp/congestion" + "mpbl3p/udp/congestion/newreno" "time" ) @@ -104,7 +105,7 @@ func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac default: fallthrough case "NewReno": - c = func() udp.Congestion { return congestion.NewNewReno() } + c = func() udp.Congestion { return newreno.NewNewReno() } } if peer.RemoteHost != "" { diff --git a/proxy/exchange.go b/proxy/exchange.go index 35c6460..dae3bb6 100644 --- a/proxy/exchange.go +++ b/proxy/exchange.go @@ -1,7 +1,9 @@ package proxy +import "context" + type Exchange interface { + Initial(ctx context.Context) (out []byte, err error) + Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error) Complete() bool - Initial() (out []byte, err error) - Handle(in []byte) (out []byte, data []byte, err error) } diff --git a/shared/errors.go b/shared/errors.go index 0db0c92..ec5b537 100644 --- a/shared/errors.go +++ b/shared/errors.go @@ -5,3 +5,4 @@ import "errors" var ErrBadChecksum = errors.New("the packet had a bad checksum") var ErrDeadConnection = errors.New("the connection is dead") var ErrNotEnoughBytes = errors.New("not enough bytes") +var ErrBadExchange = errors.New("bad exchange") diff --git a/udp/congestion/newreno/exchange.go b/udp/congestion/newreno/exchange.go new file mode 100644 index 0000000..409bb0a --- /dev/null +++ b/udp/congestion/newreno/exchange.go @@ -0,0 +1,119 @@ +package newreno + +import ( + "context" + "encoding/binary" + "math/rand" + "mpbl3p/shared" + "time" +) + +func (c *NewReno) Initial(ctx context.Context) (out []byte, err error) { + c.alive = false + c.wasInitial = true + c.startSequenceLoop(ctx) + + var s uint32 + select { + case s = <-c.sequence: + case <-ctx.Done(): + return nil, ctx.Err() + } + + b := make([]byte, 12) + binary.LittleEndian.PutUint32(b[8:12], s) + + c.inFlight = []flightInfo{{time.Now(), s}} + + return b, nil +} + +func (c *NewReno) Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error) { + if c.alive || c.stopSequence == nil { + // reset + c.alive = false + c.startSequenceLoop(ctx) + } + + // receive + if len(in) != 12 { + return nil, nil, shared.ErrBadExchange + } + + rcvAck := binary.LittleEndian.Uint32(in[0:4]) + rcvNack := binary.LittleEndian.Uint32(in[4:8]) + rcvSeq := binary.LittleEndian.Uint32(in[8:12]) + + // verify + var ack, seq uint32 + + if rcvNack != 0 { + return nil, nil, shared.ErrBadExchange + } + + if c.wasInitial { + if rcvAck == c.inFlight[0].sequence { + ack = rcvSeq + c.alive = true + } else { + return nil, nil, shared.ErrBadExchange + } + } else { // if !c.wasInitial + if rcvAck == 0 { + // theirs is a syn packet + ack = rcvSeq + select { + case seq = <-c.sequence: + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + + c.inFlight = []flightInfo{{time.Now(), seq}} + } else if len(c.inFlight) == 1 && rcvAck == c.inFlight[0].sequence { + ack = rcvSeq + c.alive = true + } else { + return nil, nil, shared.ErrBadExchange + } + } + + // respond + b := make([]byte, 12) + binary.LittleEndian.PutUint32(b[0:4], ack) + binary.LittleEndian.PutUint32(b[8:12], seq) + + return b, nil, nil +} + +func (c *NewReno) Complete() bool { + return c.alive +} + +func (c *NewReno) startSequenceLoop(ctx context.Context) { + if c.stopSequence != nil { + c.stopSequence() + } + + var s uint32 + for s == 0 { + s = rand.Uint32() + } + + ctx, c.stopSequence = context.WithCancel(ctx) + go func() { + s := s + for { + if s == 0 { + s++ + continue + } + + select { + case c.sequence <- s: + case <-ctx.Done(): + return + } + s++ + } + }() +} diff --git a/udp/congestion/newreno.go b/udp/congestion/newreno/newreno.go similarity index 97% rename from udp/congestion/newreno.go rename to udp/congestion/newreno/newreno.go index d796e11..9ee1a8a 100644 --- a/udp/congestion/newreno.go +++ b/udp/congestion/newreno/newreno.go @@ -1,4 +1,4 @@ -package congestion +package newreno import ( "context" @@ -14,7 +14,9 @@ const RttExponentialFactor = 0.1 const RttLossDelay = 1.5 type NewReno struct { - sequence chan uint32 + sequence chan uint32 + stopSequence context.CancelFunc + wasInitial, alive bool inFlight []flightInfo lastSent time.Time @@ -64,19 +66,6 @@ func NewNewReno() *NewReno { slowStart: true, } - go func() { - var s uint32 - for { - if s == 0 { - s++ - continue - } - - c.sequence <- s - s++ - } - }() - return &c } diff --git a/udp/congestion/newreno_test.go b/udp/congestion/newreno/newreno_test.go similarity index 99% rename from udp/congestion/newreno_test.go rename to udp/congestion/newreno/newreno_test.go index f97dfb9..813e400 100644 --- a/udp/congestion/newreno_test.go +++ b/udp/congestion/newreno/newreno_test.go @@ -1,4 +1,4 @@ -package congestion +package newreno import ( "context" diff --git a/udp/flow.go b/udp/flow.go index d552d93..2ab941f 100644 --- a/udp/flow.go +++ b/udp/flow.go @@ -2,6 +2,7 @@ package udp import ( "context" + "errors" "fmt" "log" "mpbl3p/proxy" @@ -19,6 +20,7 @@ type PacketWriter interface { type PacketConn interface { PacketWriter + SetReadDeadline(t time.Time) error ReadFromUDP(b []byte) (int, *net.UDPAddr, error) } @@ -42,7 +44,6 @@ type Flow struct { raddr *net.UDPAddr isAlive bool - startup bool congestion Congestion v proxy.MacVerifier @@ -105,52 +106,82 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error { } f.writer = conn - f.startup = true // prod the connection once a second until we get an ack, then consider it alive - go func() { - seq, err := f.congestion.Sequence(ctx) - if err != nil { - return - } + var exchanges []proxy.Exchange - for !f.isAlive { - if ctx.Err() != nil { - return - } - - p := Packet{ - ack: 0, - nack: 0, - seq: seq, - data: proxy.SimplePacket(nil), - } - - _ = f.sendPacket(p, f.g) - time.Sleep(1 * time.Second) - } - }() - - go func() { - _, _ = f.produceInternal(ctx, f.v, false) - }() - go f.earlyUpdateLoop(ctx, f.g, f.keepalive) - - if err := f.readQueuePacket(ctx, conn); err != nil { - return err + if e, ok := f.congestion.(proxy.Exchange); ok { + exchanges = append(exchanges, e) } - f.isAlive = true - f.startup = false + var exchangeData [][]byte + + for _, e := range exchanges { + i, err := e.Initial(ctx) + if err != nil { + return err + } + + if err = f.sendPacket(proxy.SimplePacket(i), f.g); err != nil { + return err + } + + for once := true; once || !e.Complete(); once = false { + if err := func() error { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + var recv []byte + if recv, err = f.readPacket(ctx, conn); err != nil { + return err + } + + var resp, data []byte + if resp, data, err = e.Handle(ctx, recv); err != nil { + return err + } + + if data != nil { + exchangeData = append(exchangeData, data) + } + + if resp != nil { + if err = f.sendPacket(proxy.SimplePacket(resp), f.g); err != nil { + return err + } + } + + return nil + }(); err != nil { + return err + } + } + } go func() { + for _, d := range exchangeData { + if err := f.queueDatagram(ctx, d); err != nil { + return + } + } + lockedAccept := func() { f.mu.RLock() defer f.mu.RUnlock() - if err := f.readQueuePacket(ctx, conn); err != nil { + var p []byte + if p, err = f.readPacket(ctx, conn); err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return + } log.Println(err) + return } + + if err := f.queueDatagram(ctx, p); err != nil { + return + } + } for f.isAlive { @@ -160,6 +191,7 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error { log.Println("no longer alive") }() + f.isAlive = true return nil } @@ -260,7 +292,7 @@ func (f *Flow) queueDatagram(ctx context.Context, p []byte) error { } } -func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { +func (f *Flow) sendPacket(p proxy.Packet, g proxy.MacGenerator) error { b := p.Marshal() b = proxy.AppendMac(b, g) @@ -294,13 +326,24 @@ func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepal } } -func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error { - // TODO: Replace 6000 with MTU+header size +func (f *Flow) readPacket(ctx context.Context, c PacketConn) ([]byte, error) { buf := make([]byte, 6000) - n, _, err := c.ReadFromUDP(buf) - if err != nil { - return err + + if d, ok := ctx.Deadline(); ok { + if err := c.SetReadDeadline(d); err != nil { + return nil, err + } } - return f.queueDatagram(ctx, buf[:n]) + n, _, err := c.ReadFromUDP(buf) + if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + if ctx.Err() != nil { + return nil, ctx.Err() + } + } + return nil, err + } + + return buf[:n], nil }