From 5e0dc6969d027ddccf6b578cfc83f61495205d92 Mon Sep 17 00:00:00 2001 From: Jake Hillion Date: Fri, 27 Nov 2020 17:31:32 +0000 Subject: [PATCH] udp expansions --- udp/congestion.go | 4 +- udp/congestion/newreno.go | 56 ++++++++++-- udp/congestion/none.go | 6 +- udp/flow.go | 175 ++++++++++++++++++++++-------------- udp/listener.go | 29 +++--- udp/wireshark_dissector.lua | 39 ++++++++ 6 files changed, 221 insertions(+), 88 deletions(-) create mode 100644 udp/wireshark_dissector.lua diff --git a/udp/congestion.go b/udp/congestion.go index b3c593b..826910e 100644 --- a/udp/congestion.go +++ b/udp/congestion.go @@ -12,5 +12,7 @@ type Congestion interface { ReceivedNack(uint32) NextNack() uint32 - AwaitEarlyUpdate(keepalive time.Duration) + AwaitEarlyUpdate(keepalive time.Duration) uint32 + AwaitAck(timeout time.Duration) bool + Reset() } diff --git a/udp/congestion/newreno.go b/udp/congestion/newreno.go index ab69477..cc12eeb 100644 --- a/udp/congestion/newreno.go +++ b/udp/congestion/newreno.go @@ -1,8 +1,11 @@ package congestion import ( + "fmt" + "log" "math" "mpbl3p/utils" + "sync" "sync/atomic" "time" ) @@ -28,8 +31,15 @@ type NewReno struct { ackNotifier chan struct{} lastSent time.Time + hasAcked bool acksToSend utils.Uint32Heap + + mu sync.Mutex +} + +func (c *NewReno) String() string { + return fmt.Sprintf("{NewReno %t %d %d %d %d}", c.slowStart, c.windowSize, c.inFlight, c.lastAck, c.lastNack) } func NewNewReno() *NewReno { @@ -40,7 +50,7 @@ func NewNewReno() *NewReno { outboundTimes: make(map[uint32]time.Time), inboundTimes: make(map[uint32]time.Time), - windowSize: 1, + windowSize: 8, rtt: (1 * time.Millisecond).Seconds(), slowStart: true, } @@ -61,8 +71,20 @@ func NewNewReno() *NewReno { return &c } +func (c *NewReno) Reset() { + c.outboundTimes = make(map[uint32]time.Time) + c.inboundTimes = make(map[uint32]time.Time) + c.windowSize = 8 + c.rtt = (1 * time.Millisecond).Seconds() + c.slowStart = true + c.hasAcked = false +} + // It is assumed that ReceivedAck will only be called by one thread func (c *NewReno) ReceivedAck(ack uint32) { + log.Printf("ack received for %d", ack) + c.hasAcked = true + // RTT // Update using an exponential average rtt := time.Now().Sub(c.outboundTimes[ack]).Seconds() @@ -92,6 +114,8 @@ func (c *NewReno) ReceivedAck(ack uint32) { // It is assumed that ReceivedNack will only be called by one thread func (c *NewReno) ReceivedNack(nack uint32) { + log.Printf("nack received for %d", nack) + // End slow start c.slowStart = false if s := c.windowSize; s > 1 { @@ -100,8 +124,20 @@ func (c *NewReno) ReceivedNack(nack uint32) { } func (c *NewReno) ReceivedPacket(seq uint32) { + log.Printf("seq received for %d", seq) + c.inboundTimes[seq] = time.Now() + + c.mu.Lock() c.acksToSend.Insert(seq) + c.mu.Unlock() + + c.updateAckNack() +} + +func (c *NewReno) updateAckNack() { + c.mu.Lock() + defer c.mu.Unlock() findAck := func(start uint32) uint32 { ack := start @@ -159,19 +195,29 @@ func (c *NewReno) NextNack() uint32 { return n } -func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) { +func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 { for { rtt := time.Duration(math.Round(c.rtt * float64(time.Second))) time.Sleep(rtt) + c.updateAckNack() + // CASE 1: > 5 waiting ACKs or any waiting NACKs and no message sent in the last RTT - if (c.lastAck-c.ack) > 5 || (c.lastNack != c.nack) && time.Now().After(c.lastSent.Add(rtt)) { - return + if ((c.lastAck!=c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt)) { + return 0 } // CASE 3: No message sent within the keepalive time if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) { - return + return c.Sequence() } } } + +func (c *NewReno) AwaitAck(timeout time.Duration) bool { + if c.hasAcked { + return true + } + time.Sleep(timeout) + return c.hasAcked +} diff --git a/udp/congestion/none.go b/udp/congestion/none.go index 5cf6cd4..1089dc2 100644 --- a/udp/congestion/none.go +++ b/udp/congestion/none.go @@ -31,11 +31,13 @@ func (c *None) Sequence() uint32 { return <-c.sequence } -func (c *None) ReceivedPacket(seq uint32) {} +func (c *None) ReceivedPacket(uint32) {} func (c *None) ReceivedAck(uint32) {} func (c *None) NextAck() uint32 { return 0 } func (c *None) ReceivedNack(uint32) {} func (c *None) NextNack() uint32 { return 0 } -func (c *None) AwaitEarlyUpdate(keepalive time.Duration) { +func (c *None) AwaitEarlyUpdate(time.Duration) uint32 { select {} } +func (c *None) AwaitAck(time.Duration) bool { return true } +func (c *None) Reset() {} diff --git a/udp/flow.go b/udp/flow.go index 342f1da..cebd8ad 100644 --- a/udp/flow.go +++ b/udp/flow.go @@ -42,6 +42,7 @@ type Flow struct { raddr *net.UDPAddr isAlive bool + startup bool congestion Congestion v proxy.MacVerifier @@ -81,7 +82,6 @@ func newFlow(c Congestion, v proxy.MacVerifier) Flow { func (f *InitiatedFlow) Reconnect() error { f.mu.Lock() - defer f.mu.Unlock() if f.isAlive { return nil @@ -103,28 +103,55 @@ func (f *InitiatedFlow) Reconnect() error { } f.writer = conn - f.isAlive = true + f.startup = true + // prod the connection once a second until we get an ack, then consider it alive go func() { - for { - buf := make([]byte, 6000) - n, _, err := conn.ReadFromUDP(buf) - if err != nil { - panic(err) + seq := f.congestion.Sequence() + + defer f.mu.Unlock() + for !f.isAlive { + p := Packet{ + ack: 0, + nack: 0, + seq: seq, + data: proxy.NewSimplePacket(nil), } - f.inboundDatagrams <- buf[:n] + _ = f.sendPacket(p, f.g) + + if f.congestion.AwaitAck(1 * time.Second) { + f.isAlive = true + f.startup = false + } } }() go func() { - var err error - for !errors.Is(err, shared.ErrDeadConnection) { - f.congestion.AwaitEarlyUpdate(f.keepalive) - err = f.Consume(proxy.NewSimplePacket(nil), f.g) + for f.startup || f.isAlive { + func() { + if f.isAlive { + f.mu.RLock() + defer f.mu.RUnlock() + } + + buf := make([]byte, 6000) + n, _, err := conn.ReadFromUDP(buf) + if err != nil { + log.Println(err) + time.Sleep(1) + } else { + f.handleDatagram(buf[:n]) + } + }() } }() + go func() { + _, _ = f.produceInternal(f.v, false) + }() + go f.earlyUpdateLoop(f.g, f.keepalive) + return nil } @@ -151,15 +178,75 @@ func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error { return shared.ErrDeadConnection } + log.Println(f.congestion) + // Sequence is the congestion controllers opportunity to block + log.Println("awaiting sequence") p := Packet{ seq: f.congestion.Sequence(), data: pp, } + log.Println("received sequence") + // Choose up to date ACK/NACK even after blocking p.ack = f.congestion.NextAck() p.nack = f.congestion.NextNack() + return f.sendPacket(p, g) +} + +func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { + if !f.isAlive { + return nil, shared.ErrDeadConnection + } + + return f.produceInternal(v, true) +} + +func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) { + for once := true; mustReturn || once; once = false { + log.Println(f.congestion) + + b, err := proxy.StripMac(<-f.inboundDatagrams, v) + if err != nil { + return nil, err + } + + p, err := UnmarshalPacket(b) + if err != nil { + return nil, err + } + + // schedule an ack for this sequence number + if p.seq != 0 { + f.congestion.ReceivedPacket(p.seq) + } + // adjust our sending congestion control based on their acks + if p.ack != 0 { + f.congestion.ReceivedAck(p.ack) + } + // adjust our sending congestion control based on their nacks + if p.nack != 0 { + f.congestion.ReceivedNack(p.nack) + } + + // 12 bytes for header + the MAC + a timestamp + if len(b) == 12+f.v.CodeLength()+8 { + log.Println("handled keepalive/ack only packet") + continue + } + + return p, nil + } + + return nil, nil +} + +func (f *Flow) handleDatagram(p []byte) { + f.inboundDatagrams <- p +} + +func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { b := p.Marshal() b = proxy.AppendMac(b, g) @@ -172,61 +259,17 @@ func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error { } } -func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { - if !f.isAlive { - return nil, shared.ErrDeadConnection - } - - b, err := proxy.StripMac(<-f.inboundDatagrams, v) - if err != nil { - return nil, err - } - - p, err := UnmarshalPacket(b) - if err != nil { - return nil, err - } - - // schedule an ack for this sequence number - f.congestion.ReceivedPacket(p.seq) - - // adjust our sending congestion control based on their acks - if p.ack != 0 { - f.congestion.ReceivedAck(p.ack) - } - // adjust our sending congestion control based on their nacks - if p.nack != 0 { - f.congestion.ReceivedNack(p.nack) - } - - return p, nil -} - -func (f *Flow) handleDatagram(p []byte) { - // TODO: Fix with security - // 12 bytes for header + the MAC + a timestamp - if len(p) == 12+f.v.CodeLength()+8 { - b, err := proxy.StripMac(<-f.inboundDatagrams, f.v) - if err != nil { - log.Println(err) - return +func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) { + var err error + for !errors.Is(err, shared.ErrDeadConnection) { + seq := f.congestion.AwaitEarlyUpdate(keepalive) + p := Packet{ + ack: f.congestion.NextAck(), + nack: f.congestion.NextNack(), + seq: seq, + data: proxy.NewSimplePacket(nil), } - p, err := UnmarshalPacket(b) - if err != nil { - log.Println(err) - return - } - - // TODO: Decide whether to use this line. It means an ACK loop will start, but also is a packet loss. - f.congestion.ReceivedPacket(p.seq) - if p.ack != 0 { - f.congestion.ReceivedAck(p.ack) - } - if p.nack != 0 { - f.congestion.ReceivedNack(p.nack) - } - } else { - f.inboundDatagrams <- p + _ = f.sendPacket(p, g) } } diff --git a/udp/listener.go b/udp/listener.go index 3abc810..b93aea7 100644 --- a/udp/listener.go +++ b/udp/listener.go @@ -1,10 +1,8 @@ package udp import ( - "errors" "log" "mpbl3p/proxy" - "mpbl3p/shared" "net" ) @@ -47,16 +45,21 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacG go func() { for { - buf := make([]byte, 1500) + buf := make([]byte, 6000) - _, addr, err := pconn.ReadFromUDP(buf) + log.Println("listening...") + n, addr, err := pconn.ReadFromUDP(buf) if err != nil { panic(err) } + log.Println("listened") raddr := fromUdpAddress(*addr) if f, exists := receivedConnections[raddr]; exists { - f.handleDatagram(buf) + log.Println("existing flow") + log.Println("handling...") + f.handleDatagram(buf[:n]) + log.Println("handled") continue } @@ -66,20 +69,18 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacG f.raddr = addr f.isAlive = true - go func() { - var err error - for !errors.Is(err, shared.ErrDeadConnection) { - f.congestion.AwaitEarlyUpdate(0) - err = f.Consume(proxy.NewSimplePacket(nil), g) - } - }() + log.Printf("received new udp connection: %v\n", f) + + go f.earlyUpdateLoop(g, 0) receivedConnections[raddr] = &f - log.Printf("received new udp connection: %v\n", f) - p.AddConsumer(&f) p.AddProducer(&f, v) + + log.Println("handling...") + f.handleDatagram(buf[:n]) + log.Println("handled") } }() diff --git a/udp/wireshark_dissector.lua b/udp/wireshark_dissector.lua new file mode 100644 index 0000000..f889e03 --- /dev/null +++ b/udp/wireshark_dissector.lua @@ -0,0 +1,39 @@ +local ip_dissector = Dissector.get("ip") + +mpbl3p_udp = Proto("mpbl3p_udp", "Multi Path Proxy Custom UDP") + +ack_F = ProtoField.uint32("mpbl3p_udp.ack", "Acknowledgement") +nack_F = ProtoField.uint32("mpbl3p_udp.nack", "Negative Acknowledgement") +seq_F = ProtoField.uint32("mpbl3p_udp.seq", "Sequence Number") +time_F = ProtoField.absolute_time("mpbl3p_udp.time", "Timestamp") +proxied_F = ProtoField.bytes("mpbl3p_udp.data", "Proxied Data") + +mpbl3p_udp.fields = { ack_F, nack_F, seq_F, time_F, proxied_F } + +function mpbl3p_udp.dissector(buffer, pinfo, tree) + if buffer:len() < 20 then + return + end + + pinfo.cols.protocol = "MPBL3P_UDP" + + local ack = buffer(0, 4):le_uint() + local nack = buffer(4, 4):le_uint() + local seq = buffer(8, 4):le_uint() + + local unix_time = buffer(buffer:len() - 8, 8):le_uint64() + + local subtree = tree:add(mpbl3p_udp, buffer(), "Multi Path Proxy Header, SEQ: " .. seq .. " ACK: " .. ack .. " NACK: " .. nack) + + subtree:add(ack_F, ack) + subtree:add(nack_F, nack) + subtree:add(seq_F, seq) + subtree:add(time_F, NSTime.new(unix_time:tonumber())) + if buffer:len() > 20 then + subtree:add(proxied_F, buffer(12, buffer:len() - 12 - 8)) + end + + --Dissector.call(buffer(12, buffer:len()-12-8), pinfo, tree) +end + +DissectorTable.get("udp.port"):add(1234, mpbl3p_udp)