diff --git a/config/builder.go b/config/builder.go index c5bea4f..76bcb0a 100644 --- a/config/builder.go +++ b/config/builder.go @@ -7,6 +7,7 @@ import ( "mpbl3p/crypto" "mpbl3p/crypto/sharedkey" "mpbl3p/proxy" + "mpbl3p/replay" "mpbl3p/tcp" "mpbl3p/udp" "mpbl3p/udp/congestion" @@ -16,13 +17,19 @@ import ( func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) { p := proxy.NewProxy(0) - var g func() proxy.MacGenerator - var v func() proxy.MacVerifier + var gs []func() proxy.MacGenerator + var vs []func() proxy.MacVerifier + + if c.Host.ReplayProtection { + rp := replay.NewAntiReplay() + gs = append(gs, func() proxy.MacGenerator { return rp }) + vs = append(vs, func() proxy.MacVerifier { return rp }) + } switch c.Host.Crypto { case "None": - g = func() proxy.MacGenerator { return crypto.None{} } - v = func() proxy.MacVerifier { return crypto.None{} } + gs = append(gs, func() proxy.MacGenerator { return crypto.None{} }) + vs = append(vs, func() proxy.MacVerifier { return crypto.None{} }) case "Blake2s": key, err := base64.StdEncoding.DecodeString(c.Host.SharedKey) if err != nil { @@ -31,14 +38,14 @@ func (c Configuration) Build(ctx context.Context, source proxy.Source, sink prox if _, err := sharedkey.NewBlake2s(key); err != nil { return nil, err } - g = func() proxy.MacGenerator { + gs = append(gs, func() proxy.MacGenerator { g, _ := sharedkey.NewBlake2s(key) return g - } - v = func() proxy.MacVerifier { + }) + vs = append(vs, func() proxy.MacVerifier { v, _ := sharedkey.NewBlake2s(key) return v - } + }) } p.Source = source @@ -47,11 +54,11 @@ func (c Configuration) Build(ctx context.Context, source proxy.Source, sink prox for _, peer := range c.Peers { switch peer.Method { case "TCP": - if err := buildTcp(ctx, p, peer, g, v); err != nil { + if err := buildTcp(ctx, p, peer, gs, vs); err != nil { return nil, err } case "UDP": - if err := buildUdp(ctx, p, peer, g, v); err != nil { + if err := buildUdp(ctx, p, peer, gs, vs); err != nil { return nil, err } } @@ -60,7 +67,13 @@ func (c Configuration) Build(ctx context.Context, source proxy.Source, sink prox return p, nil } -func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { +func buildTcp( + ctx context.Context, + p *proxy.Proxy, + peer Peer, + gs []func() proxy.MacGenerator, + vs []func() proxy.MacVerifier, +) error { var laddr func() string if peer.LocalPort == 0 { laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) } @@ -69,23 +82,23 @@ func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac } if peer.RemoteHost != "" { - f, err := tcp.InitiateFlow(laddr, fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort)) + f, err := tcp.InitiateFlow(laddr, fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort), initialiseVerifiers(vs), initialiseGenerators(gs)) if err != nil { return err } if !peer.DisableConsumer { - p.AddConsumer(ctx, f, g()) + p.AddConsumer(ctx, f) } if !peer.DisableProducer { - p.AddProducer(ctx, f, v()) + p.AddProducer(ctx, f) } return nil } - err := tcp.NewListener(ctx, p, laddr(), v, g, !peer.DisableConsumer, !peer.DisableProducer) + err := tcp.NewListener(ctx, p, laddr(), vs, gs, !peer.DisableConsumer, !peer.DisableProducer) if err != nil { return err } @@ -93,7 +106,13 @@ func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac return nil } -func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { +func buildUdp( + ctx context.Context, + p *proxy.Proxy, + peer Peer, + gs []func() proxy.MacGenerator, + vs []func() proxy.MacVerifier, +) error { var laddr func() string if peer.LocalPort == 0 { laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) } @@ -115,8 +134,8 @@ func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac f, err := udp.InitiateFlow( laddr, fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort), - v(), - g(), + initialiseVerifiers(vs), + initialiseGenerators(gs), c(), time.Duration(peer.KeepAlive)*time.Second, ) @@ -126,19 +145,35 @@ func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac } if !peer.DisableConsumer { - p.AddConsumer(ctx, f, g()) + p.AddConsumer(ctx, f) } if !peer.DisableProducer { - p.AddProducer(ctx, f, v()) + p.AddProducer(ctx, f) } return nil } - err := udp.NewListener(ctx, p, laddr(), v, g, c, !peer.DisableConsumer, !peer.DisableProducer) + err := udp.NewListener(ctx, p, laddr(), vs, gs, c, !peer.DisableConsumer, !peer.DisableProducer) if err != nil { return err } return nil } + +func initialiseVerifiers(vs []func() proxy.MacVerifier) (out []proxy.MacVerifier) { + out = make([]proxy.MacVerifier, len(vs)) + for i, v := range vs { + out[i] = v() + } + return +} + +func initialiseGenerators(gs []func() proxy.MacGenerator) (out []proxy.MacGenerator) { + out = make([]proxy.MacGenerator, len(gs)) + for i, g := range gs { + out[i] = g() + } + return +} diff --git a/config/config.go b/config/config.go index ed7bea5..9038e76 100644 --- a/config/config.go +++ b/config/config.go @@ -38,9 +38,10 @@ type Configuration struct { } type Host struct { - Crypto string `validate:"required,oneof=None Blake2s"` - SharedKey string `validate:"required_if=Crypto Blake2s"` - MTU uint `validate:"required,min=576"` + Crypto string `validate:"required,oneof=None Blake2s"` + SharedKey string `validate:"required_if=Crypto Blake2s"` + MTU uint `validate:"required,min=576"` + ReplayProtection bool } type Peer struct { diff --git a/mocks/mac.go b/mocks/mac.go index 5d7ab7a..caf99b3 100644 --- a/mocks/mac.go +++ b/mocks/mac.go @@ -4,19 +4,22 @@ import ( "mpbl3p/shared" ) -type AlmostUselessMac struct{} +type AlmostUselessMac string -func (AlmostUselessMac) CodeLength() int { - return 4 +func (a AlmostUselessMac) CodeLength() int { + return len(a) } -func (AlmostUselessMac) Generate([]byte) []byte { - return []byte{'a', 'b', 'c', 'd'} +func (a AlmostUselessMac) Generate([]byte) []byte { + return []byte(a) } -func (u AlmostUselessMac) Verify(_, sum []byte) error { - if !(sum[0] == 'a' && sum[1] == 'b' && sum[2] == 'c' && sum[3] == 'd') { - return shared.ErrBadChecksum +func (a AlmostUselessMac) Verify(_, sum []byte) error { + for i, c := range sum { + if a[i] != c { + return shared.ErrBadChecksum + } } + return nil } diff --git a/proxy/packet.go b/proxy/packet.go index 9e0f592..d3b8799 100644 --- a/proxy/packet.go +++ b/proxy/packet.go @@ -1,9 +1,6 @@ package proxy -import ( - "encoding/binary" - "time" -) +import "mpbl3p/shared" type Packet interface { Marshal() []byte @@ -22,17 +19,15 @@ func (p SimplePacket) Contents() []byte { } func AppendMac(b []byte, g MacGenerator) []byte { - footer := make([]byte, 8) - unixTime := uint64(time.Now().Unix()) - binary.LittleEndian.PutUint64(footer, unixTime) - - b = append(b, footer...) - mac := g.Generate(b) return append(b, mac...) } func StripMac(b []byte, v MacVerifier) ([]byte, error) { + if len(b) < v.CodeLength() { + return nil, shared.ErrNotEnoughBytes + } + data := b[:len(b)-v.CodeLength()] sum := b[len(b)-v.CodeLength():] @@ -40,7 +35,5 @@ func StripMac(b []byte, v MacVerifier) ([]byte, error) { return nil, err } - // TODO: Verify timestamp - - return data[:len(data)-8], nil + return data, nil } diff --git a/proxy/packet_test.go b/proxy/packet_test.go index af2d626..0bedc35 100644 --- a/proxy/packet_test.go +++ b/proxy/packet_test.go @@ -10,18 +10,18 @@ import ( func TestAppendMac(t *testing.T) { testContent := []byte("A test string is the content of this packet.") - testMac := mocks.AlmostUselessMac{} + testMac := mocks.AlmostUselessMac("abcd") testPacket := SimplePacket(testContent) testMarshalled := testPacket.Marshal() appended := AppendMac(testMarshalled, testMac) t.Run("Length", func(t *testing.T) { - assert.Len(t, appended, len(testMarshalled)+8+4) + assert.Len(t, appended, len(testMarshalled)+4) }) t.Run("Mac", func(t *testing.T) { - assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled)+8:]) + assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled):]) }) t.Run("Original", func(t *testing.T) { @@ -31,7 +31,7 @@ func TestAppendMac(t *testing.T) { func TestStripMac(t *testing.T) { testContent := []byte("A test string is the content of this packet.") - testMac := mocks.AlmostUselessMac{} + testMac := mocks.AlmostUselessMac("abcd") testPacket := SimplePacket(testContent) testMarshalled := testPacket.Marshal() diff --git a/proxy/proxy.go b/proxy/proxy.go index 593d65b..1bc535c 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -9,12 +9,12 @@ import ( type Producer interface { IsAlive() bool - Produce(context.Context, MacVerifier) (Packet, error) + Produce(context.Context) (Packet, error) } type Consumer interface { IsAlive() bool - Consume(context.Context, Packet, MacGenerator) error + Consume(context.Context, Packet) error } type Reconnectable interface { @@ -67,7 +67,7 @@ func (p Proxy) Start() { }() } -func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) { +func (p Proxy) AddConsumer(ctx context.Context, c Consumer) { go func() { _, reconnectable := c.(Reconnectable) @@ -94,7 +94,7 @@ func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) { log.Printf("closed consumer `%v` (context)\n", c) return case packet := <-p.proxyChan: - if err := c.Consume(ctx, packet, g); err != nil { + if err := c.Consume(ctx, packet); err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { log.Printf("closed consumer `%v` (context)\n", c) return @@ -110,7 +110,7 @@ func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) { }() } -func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) { +func (p Proxy) AddProducer(ctx context.Context, pr Producer) { go func() { _, reconnectable := pr.(Reconnectable) @@ -136,7 +136,7 @@ func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) { } for pr.IsAlive() { - if packet, err := pr.Produce(ctx, v); err != nil { + if packet, err := pr.Produce(ctx); err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { log.Printf("closed producer `%v` (context)\n", pr) return diff --git a/tcp/flow.go b/tcp/flow.go index bf54675..c0cdcba 100644 --- a/tcp/flow.go +++ b/tcp/flow.go @@ -42,10 +42,16 @@ type Flow struct { toConsume, produced chan []byte consumeErrors, produceErrors chan error + + generators []proxy.MacGenerator + verifiers []proxy.MacVerifier } -func NewFlow() Flow { +func NewFlow(vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow { return Flow{ + verifiers: vs, + generators: gs, + toConsume: make(chan []byte), produced: make(chan []byte), consumeErrors: make(chan error), @@ -53,11 +59,14 @@ func NewFlow() Flow { } } -func NewFlowConn(ctx context.Context, conn Conn) Flow { +func NewFlowConn(ctx context.Context, conn Conn, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow { f := Flow{ conn: conn, isAlive: true, + generators: gs, + verifiers: vs, + toConsume: make(chan []byte), produced: make(chan []byte), consumeErrors: make(chan error), @@ -78,12 +87,12 @@ func (f *Flow) IsAlive() bool { return f.isAlive } -func InitiateFlow(local func() string, remote string) (*InitiatedFlow, error) { +func InitiateFlow(local func() string, remote string, vs []proxy.MacVerifier, gs []proxy.MacGenerator) (*InitiatedFlow, error) { f := InitiatedFlow{ Local: local, Remote: remote, - Flow: NewFlow(), + Flow: NewFlow(vs, gs), } return &f, nil @@ -125,21 +134,21 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error { return nil } -func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { +func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet) error { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Consume(ctx, p, g) + return f.Flow.Consume(ctx, p) } -func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { +func (f *InitiatedFlow) Produce(ctx context.Context) (proxy.Packet, error) { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Produce(ctx, v) + return f.Flow.Produce(ctx) } -func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { +func (f *Flow) Consume(ctx context.Context, p proxy.Packet) error { if !f.isAlive { return shared.ErrDeadConnection } @@ -151,8 +160,10 @@ func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator default: } - marshalled := p.Marshal() - data := proxy.AppendMac(marshalled, g) + data := p.Marshal() + for _, g := range f.generators { + data = proxy.AppendMac(data, g) + } prefixedData := make([]byte, len(data)+4) binary.LittleEndian.PutUint32(prefixedData, uint32(len(data))) @@ -167,7 +178,7 @@ func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator return nil } -func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { +func (f *Flow) Produce(ctx context.Context) (proxy.Packet, error) { if !f.isAlive { return nil, shared.ErrDeadConnection } @@ -183,12 +194,17 @@ func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, return nil, err } - b, err := proxy.StripMac(data, v) - if err != nil { - return nil, err + for i := range f.verifiers { + v := f.verifiers[len(f.verifiers)-i-1] + + var err error + data, err = proxy.StripMac(data, v) + if err != nil { + return nil, err + } } - return proxy.SimplePacket(b), nil + return proxy.SimplePacket(data), nil } func (f *Flow) consumeMarshalled(ctx context.Context) { diff --git a/tcp/flow_test.go b/tcp/flow_test.go index eba5a0a..559fa81 100644 --- a/tcp/flow_test.go +++ b/tcp/flow_test.go @@ -13,41 +13,75 @@ import ( func TestFlow_Consume(t *testing.T) { testContent := []byte("A test string is the content of this packet.") testPacket := proxy.SimplePacket(testContent) - testMac := mocks.AlmostUselessMac{} + testMac := mocks.AlmostUselessMac("abcd") + testMac2 := mocks.AlmostUselessMac("efgh") t.Run("Length", func(t *testing.T) { testConn := mocks.NewMockPerfectBiStreamConn(100) - flowA := NewFlowConn(context.Background(), testConn.SideA()) + flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac}) - err := flowA.Consume(context.Background(), testPacket, testMac) + err := flowA.Consume(context.Background(), testPacket) require.Nil(t, err) buf := make([]byte, 100) n, err := testConn.SideB().Read(buf) require.Nil(t, err) - assert.Equal(t, len(testContent)+8+4+4, n) - assert.Equal(t, uint32(len(testContent)+8+4), binary.LittleEndian.Uint32(buf[:len(buf)-4])) + assert.Equal(t, len(testContent)+4+4, n) + assert.Equal(t, uint32(len(testContent)+4), binary.LittleEndian.Uint32(buf[:len(buf)-4])) + }) + + t.Run("MultipleGeneratorsLength", func(t *testing.T) { + testConn := mocks.NewMockPerfectBiStreamConn(100) + + flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2}) + + err := flowA.Consume(context.Background(), testPacket) + require.Nil(t, err) + + buf := make([]byte, 100) + n, err := testConn.SideB().Read(buf) + require.Nil(t, err) + + assert.Equal(t, len(testContent)+4+4+4, n) + assert.Equal(t, uint32(len(testContent)+4+4), binary.LittleEndian.Uint32(buf[:len(buf)-4])) + }) + + t.Run("MultipleGeneratorsOrder", func(t *testing.T) { + testConn := mocks.NewMockPerfectBiStreamConn(100) + + flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2}) + + err := flowA.Consume(context.Background(), testPacket) + require.Nil(t, err) + + buf := make([]byte, 100) + n, err := testConn.SideB().Read(buf) + require.Nil(t, err) + + assert.Equal(t, len(testContent)+4+4+4, n) + assert.Equal(t, "abcdefgh", string(buf[n-8:n])) }) } func TestFlow_Produce(t *testing.T) { testContent := "A test string is the content of this packet." - testMarshalled := []byte("0000" + testContent + "00000000abcd") + testMarshalled := []byte("0000" + testContent + "abcd") binary.LittleEndian.PutUint32(testMarshalled, uint32(len(testMarshalled)-4)) - testMac := mocks.AlmostUselessMac{} + testMac := mocks.AlmostUselessMac("abcd") + testMac2 := mocks.AlmostUselessMac("efgh") t.Run("Length", func(t *testing.T) { testConn := mocks.NewMockPerfectBiStreamConn(100) - flowA := NewFlowConn(context.Background(), testConn.SideA()) + flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac}) _, err := testConn.SideB().Write(testMarshalled) require.Nil(t, err) - p, err := flowA.Produce(context.Background(), testMac) + p, err := flowA.Produce(context.Background()) require.Nil(t, err) assert.Equal(t, len(testContent), len(p.Contents())) }) @@ -55,12 +89,29 @@ func TestFlow_Produce(t *testing.T) { t.Run("Value", func(t *testing.T) { testConn := mocks.NewMockPerfectBiStreamConn(100) - flowA := NewFlowConn(context.Background(), testConn.SideA()) + flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac}) _, err := testConn.SideB().Write(testMarshalled) require.Nil(t, err) - p, err := flowA.Produce(context.Background(), testMac) + p, err := flowA.Produce(context.Background()) + require.Nil(t, err) + assert.Equal(t, testContent, string(p.Contents())) + }) + + t.Run("MultipleVerifiersStrip", func(t *testing.T) { + testContent := "A test string is the content of this packet." + testMarshalled := []byte("0000" + testContent + "abcdefgh") + binary.LittleEndian.PutUint32(testMarshalled, uint32(len(testMarshalled)-4)) + + testConn := mocks.NewMockPerfectBiStreamConn(100) + + flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2}) + + _, err := testConn.SideB().Write(testMarshalled) + require.Nil(t, err) + + p, err := flowA.Produce(context.Background()) require.Nil(t, err) assert.Equal(t, testContent, string(p.Contents())) }) diff --git a/tcp/listener.go b/tcp/listener.go index cf87400..758cd5f 100644 --- a/tcp/listener.go +++ b/tcp/listener.go @@ -7,7 +7,15 @@ import ( "net" ) -func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, enableConsumers bool, enableProducers bool) error { +func NewListener( + ctx context.Context, + p *proxy.Proxy, + local string, + vs []func() proxy.MacVerifier, + gs []func() proxy.MacGenerator, + enableConsumers bool, + enableProducers bool, +) error { laddr, err := net.ResolveTCPAddr("tcp", local) if err != nil { return err @@ -29,15 +37,24 @@ func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() pro panic(err) } - f := NewFlowConn(ctx, conn) + var verifiers = make([]proxy.MacVerifier, len(vs)) + for i, v := range vs { + verifiers[i] = v() + } + var generators = make([]proxy.MacGenerator, len(gs)) + for i, g := range gs { + generators[i] = g() + } + + f := NewFlowConn(ctx, conn, verifiers, generators) log.Printf("received new tcp connection: %v\n", f) if enableConsumers { - p.AddConsumer(ctx, &f, g()) + p.AddConsumer(ctx, &f) } if enableProducers { - p.AddProducer(ctx, &f, v()) + p.AddProducer(ctx, &f) } } }() diff --git a/udp/flow.go b/udp/flow.go index d552d93..6ef2666 100644 --- a/udp/flow.go +++ b/udp/flow.go @@ -26,7 +26,6 @@ type InitiatedFlow struct { Local func() string Remote string - g proxy.MacGenerator keepalive time.Duration mu sync.RWMutex @@ -45,7 +44,8 @@ type Flow struct { startup bool congestion Congestion - v proxy.MacVerifier + verifiers []proxy.MacVerifier + generators []proxy.MacGenerator inboundDatagrams chan []byte } @@ -57,27 +57,27 @@ func (f Flow) String() string { func InitiateFlow( local func() string, remote string, - v proxy.MacVerifier, - g proxy.MacGenerator, + vs []proxy.MacVerifier, + gs []proxy.MacGenerator, c Congestion, keepalive time.Duration, ) (*InitiatedFlow, error) { f := InitiatedFlow{ Local: local, Remote: remote, - Flow: newFlow(c, v), - g: g, + Flow: newFlow(c, vs, gs), keepalive: keepalive, } return &f, nil } -func newFlow(c Congestion, v proxy.MacVerifier) Flow { +func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow { return Flow{ inboundDatagrams: make(chan []byte), congestion: c, - v: v, + verifiers: vs, + generators: gs, } } @@ -126,15 +126,15 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error { data: proxy.SimplePacket(nil), } - _ = f.sendPacket(p, f.g) + _ = f.sendPacket(p) time.Sleep(1 * time.Second) } }() go func() { - _, _ = f.produceInternal(ctx, f.v, false) + _, _ = f.produceInternal(ctx, false) }() - go f.earlyUpdateLoop(ctx, f.g, f.keepalive) + go f.earlyUpdateLoop(ctx, f.keepalive) if err := f.readQueuePacket(ctx, conn); err != nil { return err @@ -163,25 +163,25 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error { return nil } -func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { +func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet) error { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Consume(ctx, p, g) + return f.Flow.Consume(ctx, p) } -func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { +func (f *InitiatedFlow) Produce(ctx context.Context) (proxy.Packet, error) { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Produce(ctx, v) + return f.Flow.Produce(ctx) } func (f *Flow) IsAlive() bool { return f.isAlive } -func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error { +func (f *Flow) Consume(ctx context.Context, pp proxy.Packet) error { if !f.isAlive { return shared.ErrDeadConnection } @@ -204,18 +204,18 @@ func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerato nack: f.congestion.NextNack(), } - return f.sendPacket(p, g) + return f.sendPacket(p) } -func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { +func (f *Flow) Produce(ctx context.Context) (proxy.Packet, error) { if !f.isAlive { return nil, shared.ErrDeadConnection } - return f.produceInternal(ctx, v, true) + return f.produceInternal(ctx, true) } -func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) { +func (f *Flow) produceInternal(ctx context.Context, mustReturn bool) (proxy.Packet, error) { for once := true; mustReturn || once; once = false { log.Println(f.congestion) @@ -226,12 +226,17 @@ func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustRet return nil, ctx.Err() } - b, err := proxy.StripMac(received, v) - if err != nil { - return nil, err + for i := range f.verifiers { + v := f.verifiers[len(f.verifiers)-i-1] + + var err error + received, err = proxy.StripMac(received, v) + if err != nil { + return nil, err + } } - p, err := UnmarshalPacket(b) + p, err := UnmarshalPacket(received) if err != nil { return nil, err } @@ -240,7 +245,7 @@ func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustRet f.congestion.ReceivedPacket(p.seq, p.nack, p.ack) // 12 bytes for header + the MAC + a timestamp - if len(b) == 12+f.v.CodeLength()+8 { + if len(p.Contents()) == 0 { log.Println("handled keepalive/ack only packet") continue } @@ -260,9 +265,12 @@ func (f *Flow) queueDatagram(ctx context.Context, p []byte) error { } } -func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { +func (f *Flow) sendPacket(p Packet) error { b := p.Marshal() - b = proxy.AppendMac(b, g) + + for _, g := range f.generators { + b = proxy.AppendMac(b, g) + } if f.raddr == nil { _, err := f.writer.Write(b) @@ -273,7 +281,7 @@ func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { } } -func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) { +func (f *Flow) earlyUpdateLoop(ctx context.Context, keepalive time.Duration) { for f.isAlive { seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive) if err != nil { @@ -287,7 +295,7 @@ func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepal nack: f.congestion.NextNack(), } - err = f.sendPacket(p, g) + err = f.sendPacket(p) if err != nil { fmt.Printf("error sending early update packet: `%v`\n", err) } diff --git a/udp/flow_test.go b/udp/flow_test.go index d044477..2aa9b41 100644 --- a/udp/flow_test.go +++ b/udp/flow_test.go @@ -15,17 +15,17 @@ import ( func TestFlow_Consume(t *testing.T) { testContent := []byte("A test string is the content of this packet.") testPacket := proxy.SimplePacket(testContent) - testMac := mocks.AlmostUselessMac{} + testMac := mocks.AlmostUselessMac("abcd") - t.Run("Length", func(t *testing.T) { + t.Run("SingleGeneratorLength", func(t *testing.T) { testConn := mocks.NewMockPerfectBiPacketConn(10) - flowA := newFlow(congestion.NewNone(), testMac) + flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac}) flowA.writer = testConn.SideB() flowA.isAlive = true - err := flowA.Consume(context.Background(), testPacket, testMac) + err := flowA.Consume(context.Background(), testPacket) require.Nil(t, err) buf := make([]byte, 100) @@ -33,7 +33,50 @@ func TestFlow_Consume(t *testing.T) { require.Nil(t, err) // 12 header, 8 timestamp, 4 MAC - assert.Equal(t, len(testContent)+12+8+4, n) + assert.Equal(t, len(testContent)+12+4, n) + }) + + t.Run("MultipleGeneratorsLength", func(t *testing.T) { + testMac2 := mocks.AlmostUselessMac("efgh") + testConn := mocks.NewMockPerfectBiPacketConn(10) + + flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2}) + + flowA.writer = testConn.SideB() + flowA.isAlive = true + + err := flowA.Consume(context.Background(), testPacket) + require.Nil(t, err) + + buf := make([]byte, 100) + n, _, err := testConn.SideA().ReadFromUDP(buf) + require.Nil(t, err) + + // 12 header, 8 timestamp, 4 MAC + assert.Equal(t, len(testContent)+12+4+4, n) + }) + + t.Run("MultipleGeneratorsOrder", func(t *testing.T) { + testMac2 := mocks.AlmostUselessMac("efgh") + testConn := mocks.NewMockPerfectBiPacketConn(10) + + flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2}) + + flowA.writer = testConn.SideB() + flowA.isAlive = true + + err := flowA.Consume(context.Background(), testPacket) + require.Nil(t, err) + + buf := make([]byte, 100) + n, _, err := testConn.SideA().ReadFromUDP(buf) + require.Nil(t, err) + + // 12 header, 8 timestamp, 4 MAC + require.Equal(t, len(testContent)+12+4+4, n) + + macs := string(buf[n-8 : n]) + assert.Equal(t, "abcdefgh", macs) }) } @@ -45,7 +88,7 @@ func TestFlow_Produce(t *testing.T) { seq: 128, data: proxy.SimplePacket(testContent), } - testMac := mocks.AlmostUselessMac{} + testMac := mocks.AlmostUselessMac("abcd") testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac) @@ -58,7 +101,7 @@ func TestFlow_Produce(t *testing.T) { _, err := testConn.SideA().Write(testMarshalled) require.Nil(t, err) - flowA := newFlow(congestion.NewNone(), testMac) + flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac}) flowA.writer = testConn.SideB() flowA.isAlive = true @@ -67,7 +110,43 @@ func TestFlow_Produce(t *testing.T) { err := flowA.readQueuePacket(context.Background(), testConn.SideB()) assert.Nil(t, err) }() - p, err := flowA.Produce(context.Background(), testMac) + p, err := flowA.Produce(context.Background()) + + require.Nil(t, err) + assert.Len(t, p.Contents(), len(testContent)) + + done <- struct{}{} + }() + + timer := time.NewTimer(500 * time.Millisecond) + select { + case <-done: + case <-timer.C: + fmt.Println("timed out") + t.FailNow() + } + }) + + t.Run("MultipleVerifiersStrip", func(t *testing.T) { + done := make(chan struct{}) + + go func() { + testMac2 := mocks.AlmostUselessMac("efgh") + testConn := mocks.NewMockPerfectBiPacketConn(10) + + _, err := testConn.SideA().Write(proxy.AppendMac(testMarshalled, testMac2)) + require.Nil(t, err) + + flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2}) + + flowA.writer = testConn.SideB() + flowA.isAlive = true + + go func() { + err := flowA.readQueuePacket(context.Background(), testConn.SideB()) + assert.Nil(t, err) + }() + p, err := flowA.Produce(context.Background()) require.Nil(t, err) assert.Len(t, p.Contents(), len(testContent)) diff --git a/udp/listener.go b/udp/listener.go index 7c91679..b6d8001 100644 --- a/udp/listener.go +++ b/udp/listener.go @@ -27,7 +27,16 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress { } } -func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion, enableConsumers bool, enableProducers bool) error { +func NewListener( + ctx context.Context, + p *proxy.Proxy, + local string, + vs []func() proxy.MacVerifier, + gs []func() proxy.MacGenerator, + c func() Congestion, + enableConsumers bool, + enableProducers bool, +) error { laddr, err := net.ResolveUDPAddr("udp", local) if err != nil { return err @@ -70,10 +79,16 @@ func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() pro continue } - v := v() - g := g() + var verifiers = make([]proxy.MacVerifier, len(vs)) + for i, v := range vs { + verifiers[i] = v() + } + var generators = make([]proxy.MacGenerator, len(gs)) + for i, g := range gs { + generators[i] = g() + } - f := newFlow(c(), v) + f := newFlow(c(), verifiers, generators) f.writer = pconn f.raddr = addr @@ -81,15 +96,15 @@ func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() pro log.Printf("received new udp connection: %v\n", f) - go f.earlyUpdateLoop(ctx, g, 0) + go f.earlyUpdateLoop(ctx, 0) receivedConnections[raddr] = &f if enableConsumers { - p.AddConsumer(ctx, &f, g) + p.AddConsumer(ctx, &f) } if enableProducers { - p.AddProducer(ctx, &f, v) + p.AddProducer(ctx, &f) } log.Println("handling...")