diff --git a/udp/flow.go b/udp/flow.go index ee4a357..f538b1d 100644 --- a/udp/flow.go +++ b/udp/flow.go @@ -29,7 +29,8 @@ type Flow struct { isAlive bool congestion Congestion - v proxy.MacVerifier + verifiers []proxy.MacVerifier + generators []proxy.MacGenerator inboundDatagrams chan []byte } @@ -38,11 +39,12 @@ func (f Flow) String() string { return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr()) } -func newFlow(c Congestion, v proxy.MacVerifier) Flow { +func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow { return Flow{ inboundDatagrams: make(chan []byte), congestion: c, - v: v, + verifiers: vs, + generators: gs, } } @@ -50,7 +52,7 @@ func (f *Flow) IsAlive() bool { return f.isAlive } -func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error { +func (f *Flow) Consume(ctx context.Context, pp proxy.Packet) error { if !f.isAlive { return shared.ErrDeadConnection } @@ -73,18 +75,18 @@ func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerato nack: f.congestion.NextNack(), } - return f.sendPacket(p, g) + return f.sendPacket(p) } -func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { +func (f *Flow) Produce(ctx context.Context) (proxy.Packet, error) { if !f.isAlive { return nil, shared.ErrDeadConnection } - return f.produceInternal(ctx, v, true) + return f.produceInternal(ctx, true) } -func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) { +func (f *Flow) produceInternal(ctx context.Context, mustReturn bool) (proxy.Packet, error) { for once := true; mustReturn || once; once = false { log.Println(f.congestion) @@ -95,12 +97,17 @@ func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustRet return nil, ctx.Err() } - b, err := proxy.StripMac(received, v) - if err != nil { - return nil, err + for i := range f.verifiers { + v := f.verifiers[len(f.verifiers)-i-1] + + var err error + received, err = proxy.StripMac(received, v) + if err != nil { + return nil, err + } } - p, err := UnmarshalPacket(b) + p, err := UnmarshalPacket(received) if err != nil { return nil, err } @@ -109,7 +116,7 @@ func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustRet f.congestion.ReceivedPacket(p.seq, p.nack, p.ack) // 12 bytes for header + the MAC + a timestamp - if len(b) == 12+f.v.CodeLength()+8 { + if len(p.Contents()) == 0 { log.Println("handled keepalive/ack only packet") continue } @@ -129,9 +136,12 @@ func (f *Flow) queueDatagram(ctx context.Context, p []byte) error { } } -func (f *Flow) sendPacket(p proxy.Packet, g proxy.MacGenerator) error { +func (f *Flow) sendPacket(p proxy.Packet) error { b := p.Marshal() - b = proxy.AppendMac(b, g) + + for _, g := range f.generators { + b = proxy.AppendMac(b, g) + } if f.raddr == nil { _, err := f.writer.Write(b) @@ -142,7 +152,7 @@ func (f *Flow) sendPacket(p proxy.Packet, g proxy.MacGenerator) error { } } -func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) { +func (f *Flow) earlyUpdateLoop(ctx context.Context, keepalive time.Duration) { for f.isAlive { seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive) if err != nil { @@ -156,7 +166,7 @@ func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepal nack: f.congestion.NextNack(), } - err = f.sendPacket(p, g) + err = f.sendPacket(p) if err != nil { fmt.Printf("error sending early update packet: `%v`\n", err) } diff --git a/udp/flow_test.go b/udp/flow_test.go index d5bc2b1..0c61220 100644 --- a/udp/flow_test.go +++ b/udp/flow_test.go @@ -107,7 +107,7 @@ func TestFlow_Produce(t *testing.T) { flowA.isAlive = true go func() { - err := flowA.readQueuePacket(context.Background(), testConn.SideB()) + _, err := flowA.readPacket(context.Background(), testConn.SideB()) assert.Nil(t, err) }() p, err := flowA.Produce(context.Background()) diff --git a/udp/inbound_flow.go b/udp/inbound_flow.go index b802252..768dc44 100644 --- a/udp/inbound_flow.go +++ b/udp/inbound_flow.go @@ -9,16 +9,14 @@ import ( ) type InboundFlow struct { - g proxy.MacGenerator inboundDatagrams chan []byte mu sync.RWMutex Flow } -func newInboundFlow(f Flow, g proxy.MacGenerator) (*InboundFlow, error) { +func newInboundFlow(f Flow) (*InboundFlow, error) { fi := InboundFlow{ - g: g, inboundDatagrams: make(chan []byte), Flow: f, } @@ -95,8 +93,14 @@ func (f *InboundFlow) handleExchanges(ctx context.Context) error { return ctx.Err() } - if recv, err = proxy.StripMac(recv, f.v); err != nil { - return err + + for i := range f.verifiers { + v := f.verifiers[len(f.verifiers)-i-1] + + recv, err = proxy.StripMac(recv, v) + if err != nil { + return err + } } var resp, data []byte @@ -109,7 +113,7 @@ func (f *InboundFlow) handleExchanges(ctx context.Context) error { } if resp != nil { - if err = f.sendPacket(proxy.SimplePacket(resp), f.g); err != nil { + if err = f.sendPacket(proxy.SimplePacket(resp)); err != nil { return err } } @@ -124,16 +128,16 @@ func (f *InboundFlow) handleExchanges(ctx context.Context) error { return nil } -func (f *InboundFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { +func (f *InboundFlow) Consume(ctx context.Context, p proxy.Packet) error { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Consume(ctx, p, g) + return f.Flow.Consume(ctx, p) } -func (f *InboundFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { +func (f *InboundFlow) Produce(ctx context.Context) (proxy.Packet, error) { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Produce(ctx, v) + return f.Flow.Produce(ctx) } diff --git a/udp/listener.go b/udp/listener.go index 6c344b6..6286e7a 100644 --- a/udp/listener.go +++ b/udp/listener.go @@ -88,7 +88,7 @@ func NewListener( f.raddr = addr f.isAlive = true - fi, err := newInboundFlow(f, g) + fi, err := newInboundFlow(f) if err != nil { log.Println(err) continue diff --git a/udp/outbound_flow.go b/udp/outbound_flow.go index ed68a8a..ae6751c 100644 --- a/udp/outbound_flow.go +++ b/udp/outbound_flow.go @@ -25,16 +25,15 @@ type OutboundFlow struct { func InitiateFlow( local func() string, remote string, - v proxy.MacVerifier, - g proxy.MacGenerator, + vs []proxy.MacVerifier, + gs []proxy.MacGenerator, c Congestion, keepalive time.Duration, ) (*OutboundFlow, error) { f := OutboundFlow{ Local: local, Remote: remote, - Flow: newFlow(c, v), - g: g, + Flow: newFlow(c, vs, gs), keepalive: keepalive, } @@ -85,7 +84,7 @@ func (f *OutboundFlow) Reconnect(ctx context.Context) error { return err } - if err = f.sendPacket(proxy.SimplePacket(i), f.g); err != nil { + if err = f.sendPacket(proxy.SimplePacket(i)); err != nil { return err } @@ -99,8 +98,13 @@ func (f *OutboundFlow) Reconnect(ctx context.Context) error { return err } - if recv, err = proxy.StripMac(recv, f.v); err != nil { - return err + for i := range f.verifiers { + v := f.verifiers[len(f.verifiers)-i-1] + + recv, err = proxy.StripMac(recv, v) + if err != nil { + return err + } } var resp, data []byte @@ -113,7 +117,7 @@ func (f *OutboundFlow) Reconnect(ctx context.Context) error { } if resp != nil { - if err = f.sendPacket(proxy.SimplePacket(resp), f.g); err != nil { + if err = f.sendPacket(proxy.SimplePacket(resp)); err != nil { return err } } @@ -162,16 +166,16 @@ func (f *OutboundFlow) Reconnect(ctx context.Context) error { return nil } -func (f *OutboundFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { +func (f *OutboundFlow) Consume(ctx context.Context, p proxy.Packet) error { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Consume(ctx, p, g) + return f.Flow.Consume(ctx, p) } -func (f *OutboundFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { +func (f *OutboundFlow) Produce(ctx context.Context) (proxy.Packet, error) { f.mu.RLock() defer f.mu.RUnlock() - return f.Flow.Produce(ctx, v) + return f.Flow.Produce(ctx) }