diff --git a/config/builder.go b/config/builder.go index 76bcb0a..971cb99 100644 --- a/config/builder.go +++ b/config/builder.go @@ -11,6 +11,7 @@ import ( "mpbl3p/tcp" "mpbl3p/udp" "mpbl3p/udp/congestion" + "mpbl3p/udp/congestion/newreno" "time" ) @@ -127,7 +128,7 @@ func buildUdp( default: fallthrough case "NewReno": - c = func() udp.Congestion { return congestion.NewNewReno() } + c = func() udp.Congestion { return newreno.NewNewReno() } } if peer.RemoteHost != "" { diff --git a/flags/flags.go b/flags/flags.go index 236d72a..4f6606d 100644 --- a/flags/flags.go +++ b/flags/flags.go @@ -1,14 +1,12 @@ package flags import ( - "errors" "fmt" goflags "github.com/jessevdk/go-flags" "os" ) var PrintedHelpErr = goflags.ErrHelp -var NotEnoughArgs = errors.New("not enough arguments") type Options struct { Foreground bool `short:"f" long:"foreground" description:"Run in the foreground"` diff --git a/flags/locs_darwin.go b/flags/locs_darwin.go new file mode 100644 index 0000000..13bbda0 --- /dev/null +++ b/flags/locs_darwin.go @@ -0,0 +1,4 @@ +package flags + +const DefaultConfigFile = "/usr/local/etc/netcombiner/%s" +const DefaultPidFile = "/var/run/netcombiner/%s.pid" diff --git a/mocks/packetconn.go b/mocks/packetconn.go index 9360db4..9f79e10 100644 --- a/mocks/packetconn.go +++ b/mocks/packetconn.go @@ -1,6 +1,9 @@ package mocks -import "net" +import ( + "net" + "time" +) type MockPerfectBiPacketConn struct { directionA chan []byte @@ -44,6 +47,10 @@ func (c MockPerfectPacketConn) LocalAddr() net.Addr { } } +func (c MockPerfectPacketConn) SetReadDeadline(time.Time) error { + return nil +} + func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { p := <-c.inbound return copy(b, p), &net.UDPAddr{ diff --git a/proxy/exchange.go b/proxy/exchange.go new file mode 100644 index 0000000..dae3bb6 --- /dev/null +++ b/proxy/exchange.go @@ -0,0 +1,9 @@ +package proxy + +import "context" + +type Exchange interface { + Initial(ctx context.Context) (out []byte, err error) + Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error) + Complete() bool +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 1bc535c..339f676 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -75,12 +75,13 @@ func (p Proxy) AddConsumer(ctx context.Context, c Consumer) { if reconnectable { var err error for once := true; err != nil || once; once = false { - log.Printf("attempting to connect consumer `%v`\n", c) - err = c.(Reconnectable).Reconnect(ctx) - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + if err := ctx.Err(); err != nil { log.Printf("closed consumer `%v` (context)\n", c) return } + + log.Printf("attempting to connect consumer `%v`\n", c) + err = c.(Reconnectable).Reconnect(ctx) if !once { time.Sleep(time.Second) } @@ -118,12 +119,13 @@ func (p Proxy) AddProducer(ctx context.Context, pr Producer) { if reconnectable { var err error for once := true; err != nil || once; once = false { - log.Printf("attempting to connect producer `%v`\n", pr) - err = pr.(Reconnectable).Reconnect(ctx) - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + if err := ctx.Err(); err != nil { log.Printf("closed producer `%v` (context)\n", pr) return } + + log.Printf("attempting to connect producer `%v`\n", pr) + err = pr.(Reconnectable).Reconnect(ctx) if !once { time.Sleep(time.Second) } diff --git a/shared/errors.go b/shared/errors.go index 0db0c92..ec5b537 100644 --- a/shared/errors.go +++ b/shared/errors.go @@ -5,3 +5,4 @@ import "errors" var ErrBadChecksum = errors.New("the packet had a bad checksum") var ErrDeadConnection = errors.New("the connection is dead") var ErrNotEnoughBytes = errors.New("not enough bytes") +var ErrBadExchange = errors.New("bad exchange") diff --git a/udp/congestion/newreno/exchange.go b/udp/congestion/newreno/exchange.go new file mode 100644 index 0000000..29ca16a --- /dev/null +++ b/udp/congestion/newreno/exchange.go @@ -0,0 +1,119 @@ +package newreno + +import ( + "context" + "encoding/binary" + "math/rand" + "mpbl3p/shared" + "time" +) + +func (c *NewReno) Initial(ctx context.Context) (out []byte, err error) { + c.alive = false + c.wasInitial = true + c.startSequenceLoop(ctx) + + var s uint32 + select { + case s = <-c.sequence: + case <-ctx.Done(): + return nil, ctx.Err() + } + + b := make([]byte, 12) + binary.LittleEndian.PutUint32(b[8:12], s) + + c.inFlight = []flightInfo{{time.Now(), s}} + + return b, nil +} + +func (c *NewReno) Handle(ctx context.Context, in []byte) (out []byte, data []byte, err error) { + if c.alive || c.stopSequence == nil { + // reset + c.alive = false + c.startSequenceLoop(ctx) + } + + // receive + if len(in) != 12 { + return nil, nil, shared.ErrBadExchange + } + + rcvAck := binary.LittleEndian.Uint32(in[0:4]) + rcvNack := binary.LittleEndian.Uint32(in[4:8]) + rcvSeq := binary.LittleEndian.Uint32(in[8:12]) + + // verify + if rcvNack != 0 { + return nil, nil, shared.ErrBadExchange + } + + var seq uint32 + + if c.wasInitial { + if rcvAck == c.inFlight[0].sequence { + c.ack, c.lastAck = rcvSeq, rcvSeq + c.alive, c.inFlight = true, nil + } else { + return nil, nil, shared.ErrBadExchange + } + } else { // if !c.wasInitial + if rcvAck == 0 { + // theirs is a syn packet + c.ack, c.lastAck = rcvSeq, rcvSeq + + select { + case seq = <-c.sequence: + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + + c.inFlight = []flightInfo{{time.Now(), seq}} + } else if len(c.inFlight) == 1 && rcvAck == c.inFlight[0].sequence { + c.alive, c.inFlight = true, nil + } else { + return nil, nil, shared.ErrBadExchange + } + } + + // respond + b := make([]byte, 12) + binary.LittleEndian.PutUint32(b[0:4], c.ack) + binary.LittleEndian.PutUint32(b[8:12], seq) + + return b, nil, nil +} + +func (c *NewReno) Complete() bool { + return c.alive +} + +func (c *NewReno) startSequenceLoop(ctx context.Context) { + if c.stopSequence != nil { + c.stopSequence() + } + + var s uint32 + for s == 0 { + s = rand.Uint32() + } + + ctx, c.stopSequence = context.WithCancel(ctx) + go func() { + s := s + for { + if s == 0 { + s++ + continue + } + + select { + case c.sequence <- s: + case <-ctx.Done(): + return + } + s++ + } + }() +} diff --git a/udp/congestion/newreno/exchange_test.go b/udp/congestion/newreno/exchange_test.go new file mode 100644 index 0000000..73b2a9f --- /dev/null +++ b/udp/congestion/newreno/exchange_test.go @@ -0,0 +1,57 @@ +package newreno + +import ( + "context" + "encoding/binary" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNewReno_InitialHandle(t *testing.T) { + t.Run("InitialAckNackAreZero", func(t *testing.T) { + // ASSIGN + a := NewNewReno() + + ctx := context.Background() + + // ACT + initial, err := a.Initial(ctx) + + ack := binary.LittleEndian.Uint32(initial[0:4]) + nack := binary.LittleEndian.Uint32(initial[4:8]) + + // ASSERT + require.Nil(t, err) + + assert.Zero(t, ack) + assert.Zero(t, nack) + }) + + t.Run("InitialHandledWithAck", func(t *testing.T) { + // ASSIGN + a := NewNewReno() + b := NewNewReno() + + ctx := context.Background() + + // ACT + initial, err := a.Initial(ctx) + require.Nil(t, err) + + initialSeq := binary.LittleEndian.Uint32(initial[8:12]) + + response, data, err := b.Handle(ctx, initial) + + ack := binary.LittleEndian.Uint32(response[0:4]) + nack := binary.LittleEndian.Uint32(response[4:8]) + + // ASSERT + require.Nil(t, err) + + assert.Equal(t, initialSeq, ack) + assert.Zero(t, nack) + + assert.Nil(t, data) + }) +} diff --git a/udp/congestion/newreno.go b/udp/congestion/newreno/newreno.go similarity index 96% rename from udp/congestion/newreno.go rename to udp/congestion/newreno/newreno.go index d796e11..a3c0e89 100644 --- a/udp/congestion/newreno.go +++ b/udp/congestion/newreno/newreno.go @@ -1,4 +1,4 @@ -package congestion +package newreno import ( "context" @@ -11,10 +11,12 @@ import ( ) const RttExponentialFactor = 0.1 -const RttLossDelay = 1.5 +const RttLossDelay = 0.5 type NewReno struct { - sequence chan uint32 + sequence chan uint32 + stopSequence context.CancelFunc + wasInitial, alive bool inFlight []flightInfo lastSent time.Time @@ -64,19 +66,6 @@ func NewNewReno() *NewReno { slowStart: true, } - go func() { - var s uint32 - for { - if s == 0 { - s++ - continue - } - - c.sequence <- s - s++ - } - }() - return &c } diff --git a/udp/congestion/newreno_test.go b/udp/congestion/newreno/newreno_test.go similarity index 78% rename from udp/congestion/newreno_test.go rename to udp/congestion/newreno/newreno_test.go index f97dfb9..5ebe91a 100644 --- a/udp/congestion/newreno_test.go +++ b/udp/congestion/newreno/newreno_test.go @@ -1,4 +1,4 @@ -package congestion +package newreno import ( "context" @@ -21,8 +21,8 @@ type newRenoTest struct { halfRtt time.Duration } -func newNewRenoTest(rtt time.Duration) *newRenoTest { - return &newRenoTest{ +func newNewRenoTest(ctx context.Context, rtt time.Duration) *newRenoTest { + nr := &newRenoTest{ sideA: NewNewReno(), sideB: NewNewReno(), @@ -34,6 +34,15 @@ func newNewRenoTest(rtt time.Duration) *newRenoTest { halfRtt: rtt / 2, } + + p, _ := nr.sideA.Initial(ctx) + p, _, _ = nr.sideB.Handle(ctx, p) + p, _, _ = nr.sideA.Handle(ctx, p) + + nr.sideB.ReceivedPacket(0, nr.sideA.NextAck(), nr.sideA.NextNack()) + nr.sideA.ReceivedPacket(0, nr.sideB.NextAck(), nr.sideB.NextNack()) + + return nr } func (n *newRenoTest) Start(ctx context.Context) { @@ -151,11 +160,14 @@ func TestNewReno_Congestion(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - c := newNewRenoTest(rtt) + c := newNewRenoTest(ctx, rtt) c.Start(ctx) c.RunSideA(ctx) c.RunSideB(ctx) + sideAinitialAck := c.sideA.ack + sideBinitialAck := c.sideB.ack + // ACT for i := 0; i < numPackets; i++ { // sleep to simulate preparing packet @@ -175,10 +187,10 @@ func TestNewReno_Congestion(t *testing.T) { // ASSERT assert.Equal(t, uint32(0), c.sideA.nack) - assert.Equal(t, uint32(0), c.sideA.ack) + assert.Equal(t, sideAinitialAck, c.sideA.ack) assert.Equal(t, uint32(0), c.sideB.nack) - assert.Equal(t, uint32(numPackets), c.sideB.ack) + assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack) }) t.Run("SequenceLoss", func(t *testing.T) { @@ -189,18 +201,21 @@ func TestNewReno_Congestion(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - c := newNewRenoTest(rtt) + c := newNewRenoTest(ctx, rtt) c.Start(ctx) c.RunSideA(ctx) c.RunSideB(ctx) + sideAinitialAck := c.sideA.ack + sideBinitialAck := c.sideB.ack + // ACT - for i := 0; i < numPackets; i++ { + for i := 1; i <= numPackets; i++ { // sleep to simulate preparing packet time.Sleep(1 * time.Millisecond) seq, _ := c.sideA.Sequence(ctx) - if seq == 20 { + if i == 20 { // Simulate packet loss of sequence 20 continue } @@ -217,10 +232,10 @@ func TestNewReno_Congestion(t *testing.T) { // ASSERT assert.Equal(t, uint32(0), c.sideA.nack) - assert.Equal(t, uint32(0), c.sideA.ack) + assert.Equal(t, sideAinitialAck, c.sideA.ack) - assert.Equal(t, uint32(20), c.sideB.nack) - assert.Equal(t, uint32(numPackets), c.sideB.ack) + assert.Equal(t, sideBinitialAck+uint32(20), c.sideB.nack) + assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack) }) }) @@ -233,16 +248,19 @@ func TestNewReno_Congestion(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - c := newNewRenoTest(rtt) + c := newNewRenoTest(ctx, rtt) c.Start(ctx) c.RunSideA(ctx) c.RunSideB(ctx) + sideAinitialAck := c.sideA.ack + sideBinitialAck := c.sideB.ack + // ACT done := make(chan struct{}) go func() { - for i := 0; i < numPackets; i++ { + for i := 1; i <= numPackets; i++ { time.Sleep(1 * time.Millisecond) seq, _ := c.sideA.Sequence(ctx) @@ -257,7 +275,7 @@ func TestNewReno_Congestion(t *testing.T) { }() go func() { - for i := 0; i < numPackets; i++ { + for i := 1; i <= numPackets; i++ { time.Sleep(1 * time.Millisecond) seq, _ := c.sideB.Sequence(ctx) @@ -279,10 +297,10 @@ func TestNewReno_Congestion(t *testing.T) { // ASSERT assert.Equal(t, uint32(0), c.sideA.nack) - assert.Equal(t, uint32(numPackets), c.sideA.ack) + assert.Equal(t, sideAinitialAck+uint32(numPackets), c.sideA.ack) assert.Equal(t, uint32(0), c.sideB.nack) - assert.Equal(t, uint32(numPackets), c.sideB.ack) + assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack) }) t.Run("SequenceLoss", func(t *testing.T) { @@ -293,20 +311,23 @@ func TestNewReno_Congestion(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - c := newNewRenoTest(rtt) + c := newNewRenoTest(ctx, rtt) c.Start(ctx) c.RunSideA(ctx) c.RunSideB(ctx) + sideAinitialAck := c.sideA.ack + sideBinitialAck := c.sideB.ack + // ACT done := make(chan struct{}) go func() { - for i := 0; i < numPackets; i++ { + for i := 1; i <= numPackets; i++ { time.Sleep(1 * time.Millisecond) seq, _ := c.sideA.Sequence(ctx) - if seq == 9 { + if i == 9 { // Simulate packet loss of sequence 9 continue } @@ -322,11 +343,11 @@ func TestNewReno_Congestion(t *testing.T) { }() go func() { - for i := 0; i < numPackets; i++ { + for i := 1; i <= numPackets; i++ { time.Sleep(1 * time.Millisecond) seq, _ := c.sideB.Sequence(ctx) - if seq == 13 { + if i == 13 { // Simulate packet loss of sequence 13 continue } @@ -348,11 +369,11 @@ func TestNewReno_Congestion(t *testing.T) { // ASSERT - assert.Equal(t, uint32(13), c.sideA.nack) - assert.Equal(t, uint32(numPackets), c.sideA.ack) + assert.Equal(t, sideAinitialAck+uint32(13), c.sideA.nack) + assert.Equal(t, sideAinitialAck+uint32(numPackets), c.sideA.ack) - assert.Equal(t, uint32(9), c.sideB.nack) - assert.Equal(t, uint32(numPackets), c.sideB.ack) + assert.Equal(t, sideBinitialAck+uint32(9), c.sideB.nack) + assert.Equal(t, sideBinitialAck+uint32(numPackets), c.sideB.ack) }) }) } diff --git a/udp/flow.go b/udp/flow.go index 6ef2666..f538b1d 100644 --- a/udp/flow.go +++ b/udp/flow.go @@ -7,7 +7,6 @@ import ( "mpbl3p/proxy" "mpbl3p/shared" "net" - "sync" "time" ) @@ -19,29 +18,15 @@ type PacketWriter interface { type PacketConn interface { PacketWriter + SetReadDeadline(t time.Time) error ReadFromUDP(b []byte) (int, *net.UDPAddr, error) } -type InitiatedFlow struct { - Local func() string - Remote string - - keepalive time.Duration - - mu sync.RWMutex - Flow -} - -func (f *InitiatedFlow) String() string { - return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote) -} - type Flow struct { writer PacketWriter raddr *net.UDPAddr isAlive bool - startup bool congestion Congestion verifiers []proxy.MacVerifier @@ -54,24 +39,6 @@ func (f Flow) String() string { return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr()) } -func InitiateFlow( - local func() string, - remote string, - vs []proxy.MacVerifier, - gs []proxy.MacGenerator, - c Congestion, - keepalive time.Duration, -) (*InitiatedFlow, error) { - f := InitiatedFlow{ - Local: local, - Remote: remote, - Flow: newFlow(c, vs, gs), - keepalive: keepalive, - } - - return &f, nil -} - func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow { return Flow{ inboundDatagrams: make(chan []byte), @@ -81,102 +48,6 @@ func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow } } -func (f *InitiatedFlow) Reconnect(ctx context.Context) error { - f.mu.Lock() - defer f.mu.Unlock() - - if f.isAlive { - return nil - } - - localAddr, err := net.ResolveUDPAddr("udp", f.Local()) - if err != nil { - return err - } - - remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote) - if err != nil { - return err - } - - conn, err := net.DialUDP("udp", localAddr, remoteAddr) - if err != nil { - return err - } - - f.writer = conn - f.startup = true - - // prod the connection once a second until we get an ack, then consider it alive - go func() { - seq, err := f.congestion.Sequence(ctx) - if err != nil { - return - } - - for !f.isAlive { - if ctx.Err() != nil { - return - } - - p := Packet{ - ack: 0, - nack: 0, - seq: seq, - data: proxy.SimplePacket(nil), - } - - _ = f.sendPacket(p) - time.Sleep(1 * time.Second) - } - }() - - go func() { - _, _ = f.produceInternal(ctx, false) - }() - go f.earlyUpdateLoop(ctx, f.keepalive) - - if err := f.readQueuePacket(ctx, conn); err != nil { - return err - } - - f.isAlive = true - f.startup = false - - go func() { - lockedAccept := func() { - f.mu.RLock() - defer f.mu.RUnlock() - - if err := f.readQueuePacket(ctx, conn); err != nil { - log.Println(err) - } - } - - for f.isAlive { - log.Println("alive and listening for packets") - lockedAccept() - } - log.Println("no longer alive") - }() - - return nil -} - -func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet) error { - f.mu.RLock() - defer f.mu.RUnlock() - - return f.Flow.Consume(ctx, p) -} - -func (f *InitiatedFlow) Produce(ctx context.Context) (proxy.Packet, error) { - f.mu.RLock() - defer f.mu.RUnlock() - - return f.Flow.Produce(ctx) -} - func (f *Flow) IsAlive() bool { return f.isAlive } @@ -265,7 +136,7 @@ func (f *Flow) queueDatagram(ctx context.Context, p []byte) error { } } -func (f *Flow) sendPacket(p Packet) error { +func (f *Flow) sendPacket(p proxy.Packet) error { b := p.Marshal() for _, g := range f.generators { @@ -302,13 +173,24 @@ func (f *Flow) earlyUpdateLoop(ctx context.Context, keepalive time.Duration) { } } -func (f *Flow) readQueuePacket(ctx context.Context, c PacketConn) error { - // TODO: Replace 6000 with MTU+header size +func (f *Flow) readPacket(ctx context.Context, c PacketConn) ([]byte, error) { buf := make([]byte, 6000) - n, _, err := c.ReadFromUDP(buf) - if err != nil { - return err + + if d, ok := ctx.Deadline(); ok { + if err := c.SetReadDeadline(d); err != nil { + return nil, err + } } - return f.queueDatagram(ctx, buf[:n]) + n, _, err := c.ReadFromUDP(buf) + if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + if ctx.Err() != nil { + return nil, ctx.Err() + } + } + return nil, err + } + + return buf[:n], nil } diff --git a/udp/flow_test.go b/udp/flow_test.go index 2aa9b41..ebcc94b 100644 --- a/udp/flow_test.go +++ b/udp/flow_test.go @@ -107,7 +107,9 @@ func TestFlow_Produce(t *testing.T) { flowA.isAlive = true go func() { - err := flowA.readQueuePacket(context.Background(), testConn.SideB()) + p, err := flowA.readPacket(context.Background(), testConn.SideB()) + assert.Nil(t, err) + err = flowA.queueDatagram(context.Background(), p) assert.Nil(t, err) }() p, err := flowA.Produce(context.Background()) @@ -143,7 +145,9 @@ func TestFlow_Produce(t *testing.T) { flowA.isAlive = true go func() { - err := flowA.readQueuePacket(context.Background(), testConn.SideB()) + p, err := flowA.readPacket(context.Background(), testConn.SideB()) + assert.Nil(t, err) + err = flowA.queueDatagram(context.Background(), p) assert.Nil(t, err) }() p, err := flowA.Produce(context.Background()) diff --git a/udp/inbound_flow.go b/udp/inbound_flow.go new file mode 100644 index 0000000..896fdc3 --- /dev/null +++ b/udp/inbound_flow.go @@ -0,0 +1,142 @@ +package udp + +import ( + "context" + "log" + "mpbl3p/proxy" + "sync" + "time" +) + +type InboundFlow struct { + inboundDatagrams chan []byte + + mu sync.RWMutex + Flow +} + +func newInboundFlow(f Flow) (*InboundFlow, error) { + fi := InboundFlow{ + inboundDatagrams: make(chan []byte), + Flow: f, + } + + return &fi, nil +} + +func (f *InboundFlow) queueDatagram(ctx context.Context, p []byte) error { + select { + case f.inboundDatagrams <- p: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (f *InboundFlow) processPackets(ctx context.Context) { + for { + f.mu.Lock() + + var err error + for once := true; err != nil || once; once = false { + if ctx.Err() != nil { + return + } + + err = f.handleExchanges(ctx) + if err != nil { + log.Println(err) + } + } + + f.mu.Unlock() + + var p []byte + select { + case p = <-f.inboundDatagrams: + case <-ctx.Done(): + return + } + + // TODO: Check if p means redo exchanges + if false { + continue + } + + select { + case f.Flow.inboundDatagrams <- p: + case <-ctx.Done(): + return + } + } +} + +func (f *InboundFlow) handleExchanges(ctx context.Context) error { + var exchanges []proxy.Exchange + + if e, ok := f.congestion.(proxy.Exchange); ok { + exchanges = append(exchanges, e) + } + + var exchangeData [][]byte + + for _, e := range exchanges { + for once := true; !e.Complete() || once; once = false { + if err := func() (err error) { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + var recv []byte + select { + case recv = <-f.inboundDatagrams: + case <-ctx.Done(): + return ctx.Err() + } + + for i := range f.verifiers { + v := f.verifiers[len(f.verifiers)-i-1] + + recv, err = proxy.StripMac(recv, v) + if err != nil { + return err + } + } + + var resp, data []byte + if resp, data, err = e.Handle(ctx, recv); err != nil { + return err + } + + if data != nil { + exchangeData = append(exchangeData, data) + } + + if resp != nil { + if err = f.sendPacket(proxy.SimplePacket(resp)); err != nil { + return err + } + } + + return nil + }(); err != nil { + return err + } + } + } + + return nil +} + +func (f *InboundFlow) Consume(ctx context.Context, p proxy.Packet) error { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Consume(ctx, p) +} + +func (f *InboundFlow) Produce(ctx context.Context) (proxy.Packet, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Produce(ctx) +} diff --git a/udp/listener.go b/udp/listener.go index b6d8001..6286e7a 100644 --- a/udp/listener.go +++ b/udp/listener.go @@ -16,9 +16,7 @@ type ComparableUdpAddress struct { func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress { var ip [16]byte - for i, b := range []byte(address.IP) { - ip[i] = b - } + copy(ip[:], address.IP) return ComparableUdpAddress{ IP: ip, @@ -47,12 +45,7 @@ func NewListener( return err } - err = pconn.SetWriteBuffer(0) - if err != nil { - panic(err) - } - - receivedConnections := make(map[ComparableUdpAddress]*Flow) + receivedConnections := make(map[ComparableUdpAddress]*InboundFlow) go func() { for ctx.Err() == nil { @@ -70,10 +63,11 @@ func NewListener( } raddr := fromUdpAddress(*addr) - if f, exists := receivedConnections[raddr]; exists { + if fi, exists := receivedConnections[raddr]; exists { log.Println("existing flow. queuing...") - if err := f.queueDatagram(ctx, buf[:n]); err != nil { - + if err := fi.queueDatagram(ctx, buf[:n]); err != nil { + log.Println("error") + continue } log.Println("queued") continue @@ -94,21 +88,28 @@ func NewListener( f.raddr = addr f.isAlive = true + fi, err := newInboundFlow(f) + if err != nil { + log.Println(err) + continue + } + log.Printf("received new udp connection: %v\n", f) - go f.earlyUpdateLoop(ctx, 0) + go fi.processPackets(ctx) + go fi.earlyUpdateLoop(ctx, 0) - receivedConnections[raddr] = &f + receivedConnections[raddr] = fi if enableConsumers { - p.AddConsumer(ctx, &f) + p.AddConsumer(ctx, fi) } if enableProducers { - p.AddProducer(ctx, &f) + p.AddProducer(ctx, fi) } log.Println("handling...") - if err := f.queueDatagram(ctx, buf[:n]); err != nil { + if err := fi.queueDatagram(ctx, buf[:n]); err != nil { return } log.Println("handled") diff --git a/udp/outbound_flow.go b/udp/outbound_flow.go new file mode 100644 index 0000000..ae6751c --- /dev/null +++ b/udp/outbound_flow.go @@ -0,0 +1,181 @@ +package udp + +import ( + "context" + "errors" + "fmt" + "log" + "mpbl3p/proxy" + "net" + "sync" + "time" +) + +type OutboundFlow struct { + Local func() string + Remote string + + g proxy.MacGenerator + keepalive time.Duration + + mu sync.RWMutex + Flow +} + +func InitiateFlow( + local func() string, + remote string, + vs []proxy.MacVerifier, + gs []proxy.MacGenerator, + c Congestion, + keepalive time.Duration, +) (*OutboundFlow, error) { + f := OutboundFlow{ + Local: local, + Remote: remote, + Flow: newFlow(c, vs, gs), + keepalive: keepalive, + } + + return &f, nil +} + +func (f *OutboundFlow) String() string { + return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local(), f.Remote) +} + +func (f *OutboundFlow) Reconnect(ctx context.Context) error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.isAlive { + return nil + } + + localAddr, err := net.ResolveUDPAddr("udp", f.Local()) + if err != nil { + return err + } + + remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote) + if err != nil { + return err + } + + conn, err := net.DialUDP("udp", localAddr, remoteAddr) + if err != nil { + return err + } + + f.writer = conn + + // prod the connection once a second until we get an ack, then consider it alive + var exchanges []proxy.Exchange + + if e, ok := f.congestion.(proxy.Exchange); ok { + exchanges = append(exchanges, e) + } + + var exchangeData [][]byte + + for _, e := range exchanges { + i, err := e.Initial(ctx) + if err != nil { + return err + } + + if err = f.sendPacket(proxy.SimplePacket(i)); err != nil { + return err + } + + for once := true; !e.Complete() || once; once = false { + if err := func() error { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + var recv []byte + if recv, err = f.readPacket(ctx, conn); err != nil { + return err + } + + for i := range f.verifiers { + v := f.verifiers[len(f.verifiers)-i-1] + + recv, err = proxy.StripMac(recv, v) + if err != nil { + return err + } + } + + var resp, data []byte + if resp, data, err = e.Handle(ctx, recv); err != nil { + return err + } + + if data != nil { + exchangeData = append(exchangeData, data) + } + + if resp != nil { + if err = f.sendPacket(proxy.SimplePacket(resp)); err != nil { + return err + } + } + + return nil + }(); err != nil { + return err + } + } + } + + go func() { + for _, d := range exchangeData { + if err := f.queueDatagram(ctx, d); err != nil { + return + } + } + + lockedAccept := func() { + f.mu.RLock() + defer f.mu.RUnlock() + + var p []byte + if p, err = f.readPacket(ctx, conn); err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return + } + log.Println(err) + return + } + + if err := f.queueDatagram(ctx, p); err != nil { + return + } + + } + + for f.isAlive { + log.Println("alive and listening for packets") + lockedAccept() + } + log.Println("no longer alive") + }() + + f.isAlive = true + return nil +} + +func (f *OutboundFlow) Consume(ctx context.Context, p proxy.Packet) error { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Consume(ctx, p) +} + +func (f *OutboundFlow) Produce(ctx context.Context) (proxy.Packet, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Produce(ctx) +}