diff --git a/main.go b/main.go index d7f2633..493adad 100644 --- a/main.go +++ b/main.go @@ -13,7 +13,14 @@ func main() { log.Println("loading config...") - c, err := config.LoadConfig("config.ini") + var configLoc string + if v, ok := os.LookupEnv("CONFIG_LOC"); ok { + configLoc = v + } else { + configLoc = "config.ini" + } + + c, err := config.LoadConfig(configLoc) if err != nil { panic(err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index e98a836..2bea389 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -75,13 +75,13 @@ func (p Proxy) AddConsumer(c Consumer) { if reconnectable { var err error for once := true; err != nil || once; once = false { - log.Printf("attempting to connect `%v`\n", c) + log.Printf("attempting to connect consumer `%v`\n", c) err = c.(Reconnectable).Reconnect() if !once { time.Sleep(time.Second) } } - log.Printf("connected `%v`\n", c) + log.Printf("connected consumer `%v`\n", c) } for c.IsAlive() { @@ -92,7 +92,7 @@ func (p Proxy) AddConsumer(c Consumer) { } } - log.Printf("closed connection `%v`\n", c) + log.Printf("closed consumer `%v`\n", c) }() } @@ -104,13 +104,13 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) { if reconnectable { var err error for once := true; err != nil || once; once = false { - log.Printf("attempting to connect `%v`\n", pr) + log.Printf("attempting to connect producer `%v`\n", pr) err = pr.(Reconnectable).Reconnect() if !once { time.Sleep(time.Second) } } - log.Printf("connected `%v`\n", pr) + log.Printf("connected producer `%v`\n", pr) } for pr.IsAlive() { @@ -123,6 +123,6 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) { } } - log.Printf("closed connection `%v`\n", pr) + log.Printf("closed producer `%v`\n", pr) }() } diff --git a/tcp/flow.go b/tcp/flow.go index bf399bd..c626e49 100644 --- a/tcp/flow.go +++ b/tcp/flow.go @@ -88,10 +88,6 @@ func (f *InitiatedFlow) Reconnect() error { return nil } -func (f *Flow) IsAlive() bool { - return f.isAlive -} - func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { f.mu.RLock() defer f.mu.RUnlock() @@ -99,6 +95,17 @@ func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { return f.Flow.Consume(p, g) } +func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Produce(v) +} + +func (f *Flow) IsAlive() bool { + return f.isAlive +} + func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) (err error) { if !f.isAlive { return shared.ErrDeadConnection @@ -127,13 +134,6 @@ func (f *Flow) consumeMarshalled(data []byte) error { return err } -func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { - f.mu.RLock() - defer f.mu.RUnlock() - - return f.Flow.Produce(v) -} - func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { if !f.isAlive { return nil, shared.ErrDeadConnection diff --git a/tcp/listener.go b/tcp/listener.go index 5da9761..52892e0 100644 --- a/tcp/listener.go +++ b/tcp/listener.go @@ -31,7 +31,7 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error { f := Flow{conn: conn, isAlive: true} - log.Printf("received new connection: %v\n", f) + log.Printf("received new tcp connection: %v\n", f) p.AddConsumer(&f) p.AddProducer(&f, v) diff --git a/udp/congestion.go b/udp/congestion.go index 1d4d821..b3c593b 100644 --- a/udp/congestion.go +++ b/udp/congestion.go @@ -1,5 +1,7 @@ package udp +import "time" + type Congestion interface { Sequence() uint32 ReceivedPacket(seq uint32) @@ -9,4 +11,6 @@ type Congestion interface { ReceivedNack(uint32) NextNack() uint32 + + AwaitEarlyUpdate(keepalive time.Duration) } diff --git a/udp/congestion/newreno.go b/udp/congestion/newreno.go index 3676b39..1414077 100644 --- a/udp/congestion/newreno.go +++ b/udp/congestion/newreno.go @@ -1,20 +1,33 @@ package congestion import ( + "math" "mpbl3p/utils" + "sync/atomic" "time" ) +const RttExponentialFactor = 0.1 + type NewReno struct { - sequence chan uint32 - packetTimes map[uint32]time.Time + sequence chan uint32 + keepalive chan bool - nextAck uint32 - nextNack uint32 + outboundTimes map[uint32]time.Time + inboundTimes map[uint32]time.Time - fastStart bool - windowSize uint - rtt time.Duration + ack, lastAck uint32 + nack, lastNack uint32 + + slowStart bool + rtt float64 + windowSize int32 + windowCount int32 + inFlight int32 + + ackNotifier chan struct{} + + lastSent time.Time acksToSend utils.Uint32Heap } @@ -22,8 +35,14 @@ type NewReno struct { func NewNewReno() *NewReno { c := NewReno{ sequence: make(chan uint32), - packetTimes: make(map[uint32]time.Time), - windowSize: 1, + ackNotifier: make(chan struct{}), + + outboundTimes: make(map[uint32]time.Time), + inboundTimes: make(map[uint32]time.Time), + + windowSize: 1, + rtt: (1 * time.Millisecond).Seconds(), + slowStart: true, } go func() { @@ -45,41 +64,113 @@ func NewNewReno() *NewReno { func (c *NewReno) ReceivedAck(ack uint32) { // RTT // Update using an exponential average + rtt := time.Now().Sub(c.outboundTimes[ack]).Seconds() + delete(c.outboundTimes, ack) + c.rtt = c.rtt*(1-RttExponentialFactor) + rtt*RttExponentialFactor + // Free Window + atomic.AddInt32(&c.inFlight, -1) + select { + case c.ackNotifier <- struct{}{}: + default: + } // GROW // CASE: exponential. increase window size by one per ack // CASE: standard. increase window size by one per window of acks + if c.slowStart { + atomic.AddInt32(&c.windowSize, 1) + } else { + c.windowCount++ + if c.windowCount == c.windowSize { + c.windowCount = 0 + atomic.AddInt32(&c.windowSize, 1) + } + } } // It is assumed that ReceivedNack will only be called by one thread func (c *NewReno) ReceivedNack(nack uint32) { - // Back off + // End slow start + c.slowStart = false + if s := c.windowSize; s > 1 { + atomic.StoreInt32(&c.windowSize, s/2) + } } func (c *NewReno) ReceivedPacket(seq uint32) { + c.inboundTimes[seq] = time.Now() c.acksToSend.Insert(seq) - ack, err := c.acksToSend.Extract() - if err != nil { - panic(err) + findAck := func(start uint32) uint32 { + ack := start + for len(c.acksToSend) > 0 { + if a, _ := c.acksToSend.Peek(); a == ack+1 { + ack, _ = c.acksToSend.Extract() + } else { + break + } + } + return ack } - for a, _ := c.acksToSend.Peek(); a == ack+1; { - ack, _ = c.acksToSend.Extract() + ack := findAck(c.ack) + if ack == c.ack { + // check if there is a nack to send + // decide this based on whether there have been 3RTTs between the offset packet + if len(c.acksToSend) > 0 { + nextAck, _ := c.acksToSend.Peek() + if time.Now().Sub(c.inboundTimes[nextAck]).Seconds() > c.rtt*3 { + atomic.StoreUint32(&c.nack, nextAck-1) + ack, _ = c.acksToSend.Extract() + ack = findAck(ack) + } + } } + + atomic.StoreUint32(&c.ack, ack) } func (c *NewReno) Sequence() uint32 { + for c.inFlight >= c.windowSize { + <-c.ackNotifier + } + atomic.AddInt32(&c.inFlight, 1) + s := <-c.sequence - c.packetTimes[s] = time.Now() + + n := time.Now() + c.lastSent = n + c.outboundTimes[s] = n + return s } func (c *NewReno) NextAck() uint32 { - return c.nextAck + a := c.ack + c.lastAck = a + return a } func (c *NewReno) NextNack() uint32 { - return c.nextNack + n := c.nack + c.lastNack = n + return n +} + +func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) { + for { + rtt := time.Duration(math.Round(c.rtt * float64(time.Second))) + time.Sleep(rtt) + + // CASE 1: > 5 waiting ACKs or any waiting NACKs and no message sent in the last RTT + if (c.lastAck-c.ack) > 5 || (c.lastNack != c.nack) && time.Now().After(c.lastSent.Add(rtt)) { + return + } + + // CASE 3: No message sent within the keepalive time + if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) { + return + } + } } diff --git a/udp/flow.go b/udp/flow.go index 359a46c..4009876 100644 --- a/udp/flow.go +++ b/udp/flow.go @@ -1,13 +1,20 @@ package udp import ( + "errors" + "fmt" + "log" + "mpbl3p/config" "mpbl3p/proxy" + "mpbl3p/shared" "net" "sync" + "time" ) type PacketWriter interface { WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) + LocalAddr() net.Addr } type PacketConn interface { @@ -16,13 +23,20 @@ type PacketConn interface { } type InitiatedFlow struct { - Local string + Local string Remote string + g proxy.MacGenerator + keepalive time.Duration + mu sync.RWMutex Flow } +func (f *InitiatedFlow) String() string { + return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local, f.Remote) +} + type Flow struct { writer PacketWriter raddr net.UDPAddr @@ -33,34 +47,107 @@ type Flow struct { inboundDatagrams chan []byte } -func newOutboundFlow(c Congestion) *Flow { - return &Flow{ - congestion: c, - inboundDatagrams: make(chan []byte), - } +func (f Flow) String() string { + return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr()) } -func newInboundFlow(c Congestion) *Flow { - return &Flow{ - congestion: c, - inboundDatagrams: make(chan []byte), - } -} - -func InitiateFlow(local, remote string) (*InitiatedFlow, error) { +func InitiateFlow( + local, remote string, + g proxy.MacGenerator, + c Congestion, + keepalive time.Duration, +) (*InitiatedFlow, error) { f := InitiatedFlow{ - Local: local, - Remote: remote, + Local: local, + Remote: remote, + Flow: newFlow(c), + g: g, + keepalive: keepalive, } return &f, nil } +func newFlow(c Congestion) Flow { + return Flow{ + inboundDatagrams: make(chan []byte), + congestion: c, + } +} + +func (f *InitiatedFlow) Reconnect() error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.isAlive { + return nil + } + + localAddr, err := net.ResolveUDPAddr("udp", f.Local) + if err != nil { + return err + } + + remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote) + if err != nil { + return err + } + + conn, err := net.DialUDP("udp", localAddr, remoteAddr) + if err != nil { + return err + } + + f.writer = conn + f.isAlive = true + + go func() { + for { + buf := make([]byte, 6000) + n, _, err := conn.ReadFromUDP(buf) + if err != nil { + panic(err) + } + + f.inboundDatagrams <- buf[:n] + } + }() + + go func() { + var err error + for !errors.Is(err, shared.ErrDeadConnection) { + f.congestion.AwaitEarlyUpdate(f.keepalive) + err = f.Consume(proxy.NewSimplePacket(nil), f.g) + } + }() + + return nil +} + +func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Consume(p, g) +} + +func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Produce(v) +} + func (f *Flow) IsAlive() bool { return f.isAlive } func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error { + if !f.isAlive { + return shared.ErrDeadConnection + } + + // Sequence is the congestion controllers opportunity to block p := Packet{ seq: f.congestion.Sequence(), data: pp, @@ -77,6 +164,10 @@ func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error { } func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { + if !f.isAlive { + return nil, shared.ErrDeadConnection + } + b, err := proxy.StripMac(<-f.inboundDatagrams, v) if err != nil { return nil, err @@ -103,5 +194,30 @@ func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { } func (f *Flow) handleDatagram(p []byte) { - f.inboundDatagrams <- p + // TODO: Fix with security + // 12 bytes for header + the MAC + a timestamp + if len(p) == 12+(config.UselessMac{}).CodeLength()+8 { + b, err := proxy.StripMac(<-f.inboundDatagrams, config.UselessMac{}) + if err != nil { + log.Println(err) + return + } + + p, err := UnmarshalPacket(b) + if err != nil { + log.Println(err) + return + } + + // TODO: Decide whether to use this line. It means an ACK loop will start, but also is a packet loss. + f.congestion.ReceivedPacket(p.seq) + if p.ack != 0 { + f.congestion.ReceivedAck(p.ack) + } + if p.nack != 0 { + f.congestion.ReceivedNack(p.nack) + } + } else { + f.inboundDatagrams <- p + } } diff --git a/udp/listener.go b/udp/listener.go index 2ad59ff..1b7119e 100644 --- a/udp/listener.go +++ b/udp/listener.go @@ -1,8 +1,10 @@ package udp import ( + "errors" "log" "mpbl3p/proxy" + "mpbl3p/shared" "net" ) @@ -25,7 +27,7 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress { } } -func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error { +func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacGenerator, c Congestion) error { laddr, err := net.ResolveUDPAddr("udp", local) if err != nil { return err @@ -58,15 +60,23 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error { continue } - f := Flow{ - writer: pconn, - raddr: *addr, - isAlive: true, - } + f := newFlow(c) + + f.writer = pconn + f.raddr = *addr + f.isAlive = true + + go func() { + var err error + for !errors.Is(err, shared.ErrDeadConnection) { + f.congestion.AwaitEarlyUpdate(0) + err = f.Consume(proxy.NewSimplePacket(nil), g) + } + }() receivedConnections[raddr] = &f - log.Printf("received new connection: %v\n", f) + log.Printf("received new udp connection: %v\n", f) p.AddConsumer(&f) p.AddProducer(&f, v) diff --git a/utils/heap.go b/utils/heap.go index ccadc0f..96f2c34 100644 --- a/utils/heap.go +++ b/utils/heap.go @@ -5,23 +5,21 @@ import "errors" var ErrorEmptyHeap = errors.New("attempted to extract from empty heap") // A MinHeap for Uint64 -type Uint32Heap struct { - data []uint32 -} +type Uint32Heap []uint32 func (h *Uint32Heap) swap(x, y int) { - h.data[x] = h.data[x] ^ h.data[y] - h.data[y] = h.data[y] ^ h.data[x] - h.data[x] = h.data[x] ^ h.data[y] + (*h)[x] = (*h)[x] ^ (*h)[y] + (*h)[y] = (*h)[y] ^ (*h)[x] + (*h)[x] = (*h)[x] ^ (*h)[y] } func (h *Uint32Heap) Insert(new uint32) uint32 { - h.data = append(h.data, new) + *h = append(*h, new) - child := len(h.data) - 1 + child := len(*h) - 1 for child != 0 { parent := (child - 1) / 2 - if h.data[parent] > h.data[child] { + if (*h)[parent] > (*h)[child] { h.swap(parent, child) } else { break @@ -29,24 +27,24 @@ func (h *Uint32Heap) Insert(new uint32) uint32 { child = parent } - return h.data[0] + return (*h)[0] } func (h *Uint32Heap) Extract() (uint32, error) { - if len(h.data) == 0 { + if len(*h) == 0 { return 0, ErrorEmptyHeap } - min := h.data[0] + min := (*h)[0] - h.data[0] = h.data[len(h.data)-1] - h.data = h.data[:len(h.data)-1] + (*h)[0] = (*h)[len(*h)-1] + *h = (*h)[:len(*h)-1] parent := 0 for { left, right := parent*2+1, parent*2+2 - if (left < len(h.data) && h.data[parent] > h.data[left]) || (right < len(h.data) && h.data[parent] > h.data[right]) { - if right < len(h.data) && h.data[left] > h.data[right] { + if (left < len(*h) && (*h)[parent] > (*h)[left]) || (right < len(*h) && (*h)[parent] > (*h)[right]) { + if right < len(*h) && (*h)[left] > (*h)[right] { h.swap(parent, right) parent = right } else { @@ -60,8 +58,8 @@ func (h *Uint32Heap) Extract() (uint32, error) { } func (h *Uint32Heap) Peek() (uint32, error) { - if len(h.data) == 0 { + if len(*h) == 0 { return 0, ErrorEmptyHeap } - return h.data[0], nil + return (*h)[0], nil } diff --git a/utils/heap_test.go b/utils/heap_test.go index d0393f0..61a1bbd 100644 --- a/utils/heap_test.go +++ b/utils/heap_test.go @@ -8,8 +8,8 @@ import ( "time" ) -func SlowHeapSort(in []uint64) []uint64 { - out := make([]uint64, len(in)) +func SlowHeapSort(in []uint32) []uint32 { + out := make([]uint32, len(in)) var heap Uint32Heap @@ -27,16 +27,16 @@ func SlowHeapSort(in []uint64) []uint64 { return out } -func TestUint64Heap(t *testing.T) { +func TestUint32Heap(t *testing.T) { t.Run("EquivalentToMerge", func(t *testing.T) { const ArrayLength = 50 - sortedArray := make([]uint64, ArrayLength) - array := make([]uint64, ArrayLength) + sortedArray := make([]uint32, ArrayLength) + array := make([]uint32, ArrayLength) for i := range array { - sortedArray[i] = uint64(i) - array[i] = uint64(i) + sortedArray[i] = uint32(i) + array[i] = uint32(i) } rand.Seed(time.Now().Unix())