Merge pull request 'multiples' (#23) from multiples into develop
All checks were successful
continuous-integration/drone/push Build is passing

Reviewed-on: #23
This commit is contained in:
JakeHillion 2021-05-13 21:30:44 +00:00
commit 842a3b578a
12 changed files with 348 additions and 130 deletions

View File

@ -7,6 +7,7 @@ import (
"mpbl3p/crypto" "mpbl3p/crypto"
"mpbl3p/crypto/sharedkey" "mpbl3p/crypto/sharedkey"
"mpbl3p/proxy" "mpbl3p/proxy"
"mpbl3p/replay"
"mpbl3p/tcp" "mpbl3p/tcp"
"mpbl3p/udp" "mpbl3p/udp"
"mpbl3p/udp/congestion" "mpbl3p/udp/congestion"
@ -16,13 +17,19 @@ import (
func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) { func (c Configuration) Build(ctx context.Context, source proxy.Source, sink proxy.Sink) (*proxy.Proxy, error) {
p := proxy.NewProxy(0) p := proxy.NewProxy(0)
var g func() proxy.MacGenerator var gs []func() proxy.MacGenerator
var v func() proxy.MacVerifier var vs []func() proxy.MacVerifier
if c.Host.ReplayProtection {
rp := replay.NewAntiReplay()
gs = append(gs, func() proxy.MacGenerator { return rp })
vs = append(vs, func() proxy.MacVerifier { return rp })
}
switch c.Host.Crypto { switch c.Host.Crypto {
case "None": case "None":
g = func() proxy.MacGenerator { return crypto.None{} } gs = append(gs, func() proxy.MacGenerator { return crypto.None{} })
v = func() proxy.MacVerifier { return crypto.None{} } vs = append(vs, func() proxy.MacVerifier { return crypto.None{} })
case "Blake2s": case "Blake2s":
key, err := base64.StdEncoding.DecodeString(c.Host.SharedKey) key, err := base64.StdEncoding.DecodeString(c.Host.SharedKey)
if err != nil { if err != nil {
@ -31,14 +38,14 @@ func (c Configuration) Build(ctx context.Context, source proxy.Source, sink prox
if _, err := sharedkey.NewBlake2s(key); err != nil { if _, err := sharedkey.NewBlake2s(key); err != nil {
return nil, err return nil, err
} }
g = func() proxy.MacGenerator { gs = append(gs, func() proxy.MacGenerator {
g, _ := sharedkey.NewBlake2s(key) g, _ := sharedkey.NewBlake2s(key)
return g return g
} })
v = func() proxy.MacVerifier { vs = append(vs, func() proxy.MacVerifier {
v, _ := sharedkey.NewBlake2s(key) v, _ := sharedkey.NewBlake2s(key)
return v return v
} })
} }
p.Source = source p.Source = source
@ -47,11 +54,11 @@ func (c Configuration) Build(ctx context.Context, source proxy.Source, sink prox
for _, peer := range c.Peers { for _, peer := range c.Peers {
switch peer.Method { switch peer.Method {
case "TCP": case "TCP":
if err := buildTcp(ctx, p, peer, g, v); err != nil { if err := buildTcp(ctx, p, peer, gs, vs); err != nil {
return nil, err return nil, err
} }
case "UDP": case "UDP":
if err := buildUdp(ctx, p, peer, g, v); err != nil { if err := buildUdp(ctx, p, peer, gs, vs); err != nil {
return nil, err return nil, err
} }
} }
@ -60,7 +67,13 @@ func (c Configuration) Build(ctx context.Context, source proxy.Source, sink prox
return p, nil return p, nil
} }
func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { func buildTcp(
ctx context.Context,
p *proxy.Proxy,
peer Peer,
gs []func() proxy.MacGenerator,
vs []func() proxy.MacVerifier,
) error {
var laddr func() string var laddr func() string
if peer.LocalPort == 0 { if peer.LocalPort == 0 {
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) } laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
@ -69,23 +82,23 @@ func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac
} }
if peer.RemoteHost != "" { if peer.RemoteHost != "" {
f, err := tcp.InitiateFlow(laddr, fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort)) f, err := tcp.InitiateFlow(laddr, fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort), initialiseVerifiers(vs), initialiseGenerators(gs))
if err != nil { if err != nil {
return err return err
} }
if !peer.DisableConsumer { if !peer.DisableConsumer {
p.AddConsumer(ctx, f, g()) p.AddConsumer(ctx, f)
} }
if !peer.DisableProducer { if !peer.DisableProducer {
p.AddProducer(ctx, f, v()) p.AddProducer(ctx, f)
} }
return nil return nil
} }
err := tcp.NewListener(ctx, p, laddr(), v, g, !peer.DisableConsumer, !peer.DisableProducer) err := tcp.NewListener(ctx, p, laddr(), vs, gs, !peer.DisableConsumer, !peer.DisableProducer)
if err != nil { if err != nil {
return err return err
} }
@ -93,7 +106,13 @@ func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac
return nil return nil
} }
func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.MacGenerator, v func() proxy.MacVerifier) error { func buildUdp(
ctx context.Context,
p *proxy.Proxy,
peer Peer,
gs []func() proxy.MacGenerator,
vs []func() proxy.MacVerifier,
) error {
var laddr func() string var laddr func() string
if peer.LocalPort == 0 { if peer.LocalPort == 0 {
laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) } laddr = func() string { return fmt.Sprintf("%s:", peer.GetLocalHost()) }
@ -115,8 +134,8 @@ func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac
f, err := udp.InitiateFlow( f, err := udp.InitiateFlow(
laddr, laddr,
fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort), fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort),
v(), initialiseVerifiers(vs),
g(), initialiseGenerators(gs),
c(), c(),
time.Duration(peer.KeepAlive)*time.Second, time.Duration(peer.KeepAlive)*time.Second,
) )
@ -126,19 +145,35 @@ func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac
} }
if !peer.DisableConsumer { if !peer.DisableConsumer {
p.AddConsumer(ctx, f, g()) p.AddConsumer(ctx, f)
} }
if !peer.DisableProducer { if !peer.DisableProducer {
p.AddProducer(ctx, f, v()) p.AddProducer(ctx, f)
} }
return nil return nil
} }
err := udp.NewListener(ctx, p, laddr(), v, g, c, !peer.DisableConsumer, !peer.DisableProducer) err := udp.NewListener(ctx, p, laddr(), vs, gs, c, !peer.DisableConsumer, !peer.DisableProducer)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func initialiseVerifiers(vs []func() proxy.MacVerifier) (out []proxy.MacVerifier) {
out = make([]proxy.MacVerifier, len(vs))
for i, v := range vs {
out[i] = v()
}
return
}
func initialiseGenerators(gs []func() proxy.MacGenerator) (out []proxy.MacGenerator) {
out = make([]proxy.MacGenerator, len(gs))
for i, g := range gs {
out[i] = g()
}
return
}

View File

@ -41,6 +41,7 @@ type Host struct {
Crypto string `validate:"required,oneof=None Blake2s"` Crypto string `validate:"required,oneof=None Blake2s"`
SharedKey string `validate:"required_if=Crypto Blake2s"` SharedKey string `validate:"required_if=Crypto Blake2s"`
MTU uint `validate:"required,min=576"` MTU uint `validate:"required,min=576"`
ReplayProtection bool
} }
type Peer struct { type Peer struct {

View File

@ -4,19 +4,22 @@ import (
"mpbl3p/shared" "mpbl3p/shared"
) )
type AlmostUselessMac struct{} type AlmostUselessMac string
func (AlmostUselessMac) CodeLength() int { func (a AlmostUselessMac) CodeLength() int {
return 4 return len(a)
} }
func (AlmostUselessMac) Generate([]byte) []byte { func (a AlmostUselessMac) Generate([]byte) []byte {
return []byte{'a', 'b', 'c', 'd'} return []byte(a)
} }
func (u AlmostUselessMac) Verify(_, sum []byte) error { func (a AlmostUselessMac) Verify(_, sum []byte) error {
if !(sum[0] == 'a' && sum[1] == 'b' && sum[2] == 'c' && sum[3] == 'd') { for i, c := range sum {
if a[i] != c {
return shared.ErrBadChecksum return shared.ErrBadChecksum
} }
}
return nil return nil
} }

View File

@ -1,9 +1,6 @@
package proxy package proxy
import ( import "mpbl3p/shared"
"encoding/binary"
"time"
)
type Packet interface { type Packet interface {
Marshal() []byte Marshal() []byte
@ -22,17 +19,15 @@ func (p SimplePacket) Contents() []byte {
} }
func AppendMac(b []byte, g MacGenerator) []byte { func AppendMac(b []byte, g MacGenerator) []byte {
footer := make([]byte, 8)
unixTime := uint64(time.Now().Unix())
binary.LittleEndian.PutUint64(footer, unixTime)
b = append(b, footer...)
mac := g.Generate(b) mac := g.Generate(b)
return append(b, mac...) return append(b, mac...)
} }
func StripMac(b []byte, v MacVerifier) ([]byte, error) { func StripMac(b []byte, v MacVerifier) ([]byte, error) {
if len(b) < v.CodeLength() {
return nil, shared.ErrNotEnoughBytes
}
data := b[:len(b)-v.CodeLength()] data := b[:len(b)-v.CodeLength()]
sum := b[len(b)-v.CodeLength():] sum := b[len(b)-v.CodeLength():]
@ -40,7 +35,5 @@ func StripMac(b []byte, v MacVerifier) ([]byte, error) {
return nil, err return nil, err
} }
// TODO: Verify timestamp return data, nil
return data[:len(data)-8], nil
} }

View File

@ -10,18 +10,18 @@ import (
func TestAppendMac(t *testing.T) { func TestAppendMac(t *testing.T) {
testContent := []byte("A test string is the content of this packet.") testContent := []byte("A test string is the content of this packet.")
testMac := mocks.AlmostUselessMac{} testMac := mocks.AlmostUselessMac("abcd")
testPacket := SimplePacket(testContent) testPacket := SimplePacket(testContent)
testMarshalled := testPacket.Marshal() testMarshalled := testPacket.Marshal()
appended := AppendMac(testMarshalled, testMac) appended := AppendMac(testMarshalled, testMac)
t.Run("Length", func(t *testing.T) { t.Run("Length", func(t *testing.T) {
assert.Len(t, appended, len(testMarshalled)+8+4) assert.Len(t, appended, len(testMarshalled)+4)
}) })
t.Run("Mac", func(t *testing.T) { t.Run("Mac", func(t *testing.T) {
assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled)+8:]) assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled):])
}) })
t.Run("Original", func(t *testing.T) { t.Run("Original", func(t *testing.T) {
@ -31,7 +31,7 @@ func TestAppendMac(t *testing.T) {
func TestStripMac(t *testing.T) { func TestStripMac(t *testing.T) {
testContent := []byte("A test string is the content of this packet.") testContent := []byte("A test string is the content of this packet.")
testMac := mocks.AlmostUselessMac{} testMac := mocks.AlmostUselessMac("abcd")
testPacket := SimplePacket(testContent) testPacket := SimplePacket(testContent)
testMarshalled := testPacket.Marshal() testMarshalled := testPacket.Marshal()

View File

@ -9,12 +9,12 @@ import (
type Producer interface { type Producer interface {
IsAlive() bool IsAlive() bool
Produce(context.Context, MacVerifier) (Packet, error) Produce(context.Context) (Packet, error)
} }
type Consumer interface { type Consumer interface {
IsAlive() bool IsAlive() bool
Consume(context.Context, Packet, MacGenerator) error Consume(context.Context, Packet) error
} }
type Reconnectable interface { type Reconnectable interface {
@ -67,7 +67,7 @@ func (p Proxy) Start() {
}() }()
} }
func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) { func (p Proxy) AddConsumer(ctx context.Context, c Consumer) {
go func() { go func() {
_, reconnectable := c.(Reconnectable) _, reconnectable := c.(Reconnectable)
@ -94,7 +94,7 @@ func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) {
log.Printf("closed consumer `%v` (context)\n", c) log.Printf("closed consumer `%v` (context)\n", c)
return return
case packet := <-p.proxyChan: case packet := <-p.proxyChan:
if err := c.Consume(ctx, packet, g); err != nil { if err := c.Consume(ctx, packet); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed consumer `%v` (context)\n", c) log.Printf("closed consumer `%v` (context)\n", c)
return return
@ -110,7 +110,7 @@ func (p Proxy) AddConsumer(ctx context.Context, c Consumer, g MacGenerator) {
}() }()
} }
func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) { func (p Proxy) AddProducer(ctx context.Context, pr Producer) {
go func() { go func() {
_, reconnectable := pr.(Reconnectable) _, reconnectable := pr.(Reconnectable)
@ -136,7 +136,7 @@ func (p Proxy) AddProducer(ctx context.Context, pr Producer, v MacVerifier) {
} }
for pr.IsAlive() { for pr.IsAlive() {
if packet, err := pr.Produce(ctx, v); err != nil { if packet, err := pr.Produce(ctx); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("closed producer `%v` (context)\n", pr) log.Printf("closed producer `%v` (context)\n", pr)
return return

View File

@ -42,10 +42,16 @@ type Flow struct {
toConsume, produced chan []byte toConsume, produced chan []byte
consumeErrors, produceErrors chan error consumeErrors, produceErrors chan error
generators []proxy.MacGenerator
verifiers []proxy.MacVerifier
} }
func NewFlow() Flow { func NewFlow(vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow {
return Flow{ return Flow{
verifiers: vs,
generators: gs,
toConsume: make(chan []byte), toConsume: make(chan []byte),
produced: make(chan []byte), produced: make(chan []byte),
consumeErrors: make(chan error), consumeErrors: make(chan error),
@ -53,11 +59,14 @@ func NewFlow() Flow {
} }
} }
func NewFlowConn(ctx context.Context, conn Conn) Flow { func NewFlowConn(ctx context.Context, conn Conn, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow {
f := Flow{ f := Flow{
conn: conn, conn: conn,
isAlive: true, isAlive: true,
generators: gs,
verifiers: vs,
toConsume: make(chan []byte), toConsume: make(chan []byte),
produced: make(chan []byte), produced: make(chan []byte),
consumeErrors: make(chan error), consumeErrors: make(chan error),
@ -78,12 +87,12 @@ func (f *Flow) IsAlive() bool {
return f.isAlive return f.isAlive
} }
func InitiateFlow(local func() string, remote string) (*InitiatedFlow, error) { func InitiateFlow(local func() string, remote string, vs []proxy.MacVerifier, gs []proxy.MacGenerator) (*InitiatedFlow, error) {
f := InitiatedFlow{ f := InitiatedFlow{
Local: local, Local: local,
Remote: remote, Remote: remote,
Flow: NewFlow(), Flow: NewFlow(vs, gs),
} }
return &f, nil return &f, nil
@ -125,21 +134,21 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
return nil return nil
} }
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet) error {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
return f.Flow.Consume(ctx, p, g) return f.Flow.Consume(ctx, p)
} }
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { func (f *InitiatedFlow) Produce(ctx context.Context) (proxy.Packet, error) {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
return f.Flow.Produce(ctx, v) return f.Flow.Produce(ctx)
} }
func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { func (f *Flow) Consume(ctx context.Context, p proxy.Packet) error {
if !f.isAlive { if !f.isAlive {
return shared.ErrDeadConnection return shared.ErrDeadConnection
} }
@ -151,8 +160,10 @@ func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator
default: default:
} }
marshalled := p.Marshal() data := p.Marshal()
data := proxy.AppendMac(marshalled, g) for _, g := range f.generators {
data = proxy.AppendMac(data, g)
}
prefixedData := make([]byte, len(data)+4) prefixedData := make([]byte, len(data)+4)
binary.LittleEndian.PutUint32(prefixedData, uint32(len(data))) binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
@ -167,7 +178,7 @@ func (f *Flow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator
return nil return nil
} }
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { func (f *Flow) Produce(ctx context.Context) (proxy.Packet, error) {
if !f.isAlive { if !f.isAlive {
return nil, shared.ErrDeadConnection return nil, shared.ErrDeadConnection
} }
@ -183,12 +194,17 @@ func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet,
return nil, err return nil, err
} }
b, err := proxy.StripMac(data, v) for i := range f.verifiers {
v := f.verifiers[len(f.verifiers)-i-1]
var err error
data, err = proxy.StripMac(data, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
return proxy.SimplePacket(b), nil return proxy.SimplePacket(data), nil
} }
func (f *Flow) consumeMarshalled(ctx context.Context) { func (f *Flow) consumeMarshalled(ctx context.Context) {

View File

@ -13,41 +13,75 @@ import (
func TestFlow_Consume(t *testing.T) { func TestFlow_Consume(t *testing.T) {
testContent := []byte("A test string is the content of this packet.") testContent := []byte("A test string is the content of this packet.")
testPacket := proxy.SimplePacket(testContent) testPacket := proxy.SimplePacket(testContent)
testMac := mocks.AlmostUselessMac{} testMac := mocks.AlmostUselessMac("abcd")
testMac2 := mocks.AlmostUselessMac("efgh")
t.Run("Length", func(t *testing.T) { t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiStreamConn(100) testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA()) flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac})
err := flowA.Consume(context.Background(), testPacket, testMac) err := flowA.Consume(context.Background(), testPacket)
require.Nil(t, err) require.Nil(t, err)
buf := make([]byte, 100) buf := make([]byte, 100)
n, err := testConn.SideB().Read(buf) n, err := testConn.SideB().Read(buf)
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, len(testContent)+8+4+4, n) assert.Equal(t, len(testContent)+4+4, n)
assert.Equal(t, uint32(len(testContent)+8+4), binary.LittleEndian.Uint32(buf[:len(buf)-4])) assert.Equal(t, uint32(len(testContent)+4), binary.LittleEndian.Uint32(buf[:len(buf)-4]))
})
t.Run("MultipleGeneratorsLength", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
err := flowA.Consume(context.Background(), testPacket)
require.Nil(t, err)
buf := make([]byte, 100)
n, err := testConn.SideB().Read(buf)
require.Nil(t, err)
assert.Equal(t, len(testContent)+4+4+4, n)
assert.Equal(t, uint32(len(testContent)+4+4), binary.LittleEndian.Uint32(buf[:len(buf)-4]))
})
t.Run("MultipleGeneratorsOrder", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
err := flowA.Consume(context.Background(), testPacket)
require.Nil(t, err)
buf := make([]byte, 100)
n, err := testConn.SideB().Read(buf)
require.Nil(t, err)
assert.Equal(t, len(testContent)+4+4+4, n)
assert.Equal(t, "abcdefgh", string(buf[n-8:n]))
}) })
} }
func TestFlow_Produce(t *testing.T) { func TestFlow_Produce(t *testing.T) {
testContent := "A test string is the content of this packet." testContent := "A test string is the content of this packet."
testMarshalled := []byte("0000" + testContent + "00000000abcd") testMarshalled := []byte("0000" + testContent + "abcd")
binary.LittleEndian.PutUint32(testMarshalled, uint32(len(testMarshalled)-4)) binary.LittleEndian.PutUint32(testMarshalled, uint32(len(testMarshalled)-4))
testMac := mocks.AlmostUselessMac{} testMac := mocks.AlmostUselessMac("abcd")
testMac2 := mocks.AlmostUselessMac("efgh")
t.Run("Length", func(t *testing.T) { t.Run("Length", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiStreamConn(100) testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA()) flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac})
_, err := testConn.SideB().Write(testMarshalled) _, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err) require.Nil(t, err)
p, err := flowA.Produce(context.Background(), testMac) p, err := flowA.Produce(context.Background())
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, len(testContent), len(p.Contents())) assert.Equal(t, len(testContent), len(p.Contents()))
}) })
@ -55,12 +89,29 @@ func TestFlow_Produce(t *testing.T) {
t.Run("Value", func(t *testing.T) { t.Run("Value", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiStreamConn(100) testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA()) flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac})
_, err := testConn.SideB().Write(testMarshalled) _, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err) require.Nil(t, err)
p, err := flowA.Produce(context.Background(), testMac) p, err := flowA.Produce(context.Background())
require.Nil(t, err)
assert.Equal(t, testContent, string(p.Contents()))
})
t.Run("MultipleVerifiersStrip", func(t *testing.T) {
testContent := "A test string is the content of this packet."
testMarshalled := []byte("0000" + testContent + "abcdefgh")
binary.LittleEndian.PutUint32(testMarshalled, uint32(len(testMarshalled)-4))
testConn := mocks.NewMockPerfectBiStreamConn(100)
flowA := NewFlowConn(context.Background(), testConn.SideA(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
_, err := testConn.SideB().Write(testMarshalled)
require.Nil(t, err)
p, err := flowA.Produce(context.Background())
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, testContent, string(p.Contents())) assert.Equal(t, testContent, string(p.Contents()))
}) })

View File

@ -7,7 +7,15 @@ import (
"net" "net"
) )
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, enableConsumers bool, enableProducers bool) error { func NewListener(
ctx context.Context,
p *proxy.Proxy,
local string,
vs []func() proxy.MacVerifier,
gs []func() proxy.MacGenerator,
enableConsumers bool,
enableProducers bool,
) error {
laddr, err := net.ResolveTCPAddr("tcp", local) laddr, err := net.ResolveTCPAddr("tcp", local)
if err != nil { if err != nil {
return err return err
@ -29,15 +37,24 @@ func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() pro
panic(err) panic(err)
} }
f := NewFlowConn(ctx, conn) var verifiers = make([]proxy.MacVerifier, len(vs))
for i, v := range vs {
verifiers[i] = v()
}
var generators = make([]proxy.MacGenerator, len(gs))
for i, g := range gs {
generators[i] = g()
}
f := NewFlowConn(ctx, conn, verifiers, generators)
log.Printf("received new tcp connection: %v\n", f) log.Printf("received new tcp connection: %v\n", f)
if enableConsumers { if enableConsumers {
p.AddConsumer(ctx, &f, g()) p.AddConsumer(ctx, &f)
} }
if enableProducers { if enableProducers {
p.AddProducer(ctx, &f, v()) p.AddProducer(ctx, &f)
} }
} }
}() }()

View File

@ -26,7 +26,6 @@ type InitiatedFlow struct {
Local func() string Local func() string
Remote string Remote string
g proxy.MacGenerator
keepalive time.Duration keepalive time.Duration
mu sync.RWMutex mu sync.RWMutex
@ -45,7 +44,8 @@ type Flow struct {
startup bool startup bool
congestion Congestion congestion Congestion
v proxy.MacVerifier verifiers []proxy.MacVerifier
generators []proxy.MacGenerator
inboundDatagrams chan []byte inboundDatagrams chan []byte
} }
@ -57,27 +57,27 @@ func (f Flow) String() string {
func InitiateFlow( func InitiateFlow(
local func() string, local func() string,
remote string, remote string,
v proxy.MacVerifier, vs []proxy.MacVerifier,
g proxy.MacGenerator, gs []proxy.MacGenerator,
c Congestion, c Congestion,
keepalive time.Duration, keepalive time.Duration,
) (*InitiatedFlow, error) { ) (*InitiatedFlow, error) {
f := InitiatedFlow{ f := InitiatedFlow{
Local: local, Local: local,
Remote: remote, Remote: remote,
Flow: newFlow(c, v), Flow: newFlow(c, vs, gs),
g: g,
keepalive: keepalive, keepalive: keepalive,
} }
return &f, nil return &f, nil
} }
func newFlow(c Congestion, v proxy.MacVerifier) Flow { func newFlow(c Congestion, vs []proxy.MacVerifier, gs []proxy.MacGenerator) Flow {
return Flow{ return Flow{
inboundDatagrams: make(chan []byte), inboundDatagrams: make(chan []byte),
congestion: c, congestion: c,
v: v, verifiers: vs,
generators: gs,
} }
} }
@ -126,15 +126,15 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
data: proxy.SimplePacket(nil), data: proxy.SimplePacket(nil),
} }
_ = f.sendPacket(p, f.g) _ = f.sendPacket(p)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
}() }()
go func() { go func() {
_, _ = f.produceInternal(ctx, f.v, false) _, _ = f.produceInternal(ctx, false)
}() }()
go f.earlyUpdateLoop(ctx, f.g, f.keepalive) go f.earlyUpdateLoop(ctx, f.keepalive)
if err := f.readQueuePacket(ctx, conn); err != nil { if err := f.readQueuePacket(ctx, conn); err != nil {
return err return err
@ -163,25 +163,25 @@ func (f *InitiatedFlow) Reconnect(ctx context.Context) error {
return nil return nil
} }
func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet, g proxy.MacGenerator) error { func (f *InitiatedFlow) Consume(ctx context.Context, p proxy.Packet) error {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
return f.Flow.Consume(ctx, p, g) return f.Flow.Consume(ctx, p)
} }
func (f *InitiatedFlow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { func (f *InitiatedFlow) Produce(ctx context.Context) (proxy.Packet, error) {
f.mu.RLock() f.mu.RLock()
defer f.mu.RUnlock() defer f.mu.RUnlock()
return f.Flow.Produce(ctx, v) return f.Flow.Produce(ctx)
} }
func (f *Flow) IsAlive() bool { func (f *Flow) IsAlive() bool {
return f.isAlive return f.isAlive
} }
func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerator) error { func (f *Flow) Consume(ctx context.Context, pp proxy.Packet) error {
if !f.isAlive { if !f.isAlive {
return shared.ErrDeadConnection return shared.ErrDeadConnection
} }
@ -204,18 +204,18 @@ func (f *Flow) Consume(ctx context.Context, pp proxy.Packet, g proxy.MacGenerato
nack: f.congestion.NextNack(), nack: f.congestion.NextNack(),
} }
return f.sendPacket(p, g) return f.sendPacket(p)
} }
func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, error) { func (f *Flow) Produce(ctx context.Context) (proxy.Packet, error) {
if !f.isAlive { if !f.isAlive {
return nil, shared.ErrDeadConnection return nil, shared.ErrDeadConnection
} }
return f.produceInternal(ctx, v, true) return f.produceInternal(ctx, true)
} }
func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) { func (f *Flow) produceInternal(ctx context.Context, mustReturn bool) (proxy.Packet, error) {
for once := true; mustReturn || once; once = false { for once := true; mustReturn || once; once = false {
log.Println(f.congestion) log.Println(f.congestion)
@ -226,12 +226,17 @@ func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustRet
return nil, ctx.Err() return nil, ctx.Err()
} }
b, err := proxy.StripMac(received, v) for i := range f.verifiers {
v := f.verifiers[len(f.verifiers)-i-1]
var err error
received, err = proxy.StripMac(received, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
p, err := UnmarshalPacket(b) p, err := UnmarshalPacket(received)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -240,7 +245,7 @@ func (f *Flow) produceInternal(ctx context.Context, v proxy.MacVerifier, mustRet
f.congestion.ReceivedPacket(p.seq, p.nack, p.ack) f.congestion.ReceivedPacket(p.seq, p.nack, p.ack)
// 12 bytes for header + the MAC + a timestamp // 12 bytes for header + the MAC + a timestamp
if len(b) == 12+f.v.CodeLength()+8 { if len(p.Contents()) == 0 {
log.Println("handled keepalive/ack only packet") log.Println("handled keepalive/ack only packet")
continue continue
} }
@ -260,9 +265,12 @@ func (f *Flow) queueDatagram(ctx context.Context, p []byte) error {
} }
} }
func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { func (f *Flow) sendPacket(p Packet) error {
b := p.Marshal() b := p.Marshal()
for _, g := range f.generators {
b = proxy.AppendMac(b, g) b = proxy.AppendMac(b, g)
}
if f.raddr == nil { if f.raddr == nil {
_, err := f.writer.Write(b) _, err := f.writer.Write(b)
@ -273,7 +281,7 @@ func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error {
} }
} }
func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepalive time.Duration) { func (f *Flow) earlyUpdateLoop(ctx context.Context, keepalive time.Duration) {
for f.isAlive { for f.isAlive {
seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive) seq, err := f.congestion.AwaitEarlyUpdate(ctx, keepalive)
if err != nil { if err != nil {
@ -287,7 +295,7 @@ func (f *Flow) earlyUpdateLoop(ctx context.Context, g proxy.MacGenerator, keepal
nack: f.congestion.NextNack(), nack: f.congestion.NextNack(),
} }
err = f.sendPacket(p, g) err = f.sendPacket(p)
if err != nil { if err != nil {
fmt.Printf("error sending early update packet: `%v`\n", err) fmt.Printf("error sending early update packet: `%v`\n", err)
} }

View File

@ -15,17 +15,17 @@ import (
func TestFlow_Consume(t *testing.T) { func TestFlow_Consume(t *testing.T) {
testContent := []byte("A test string is the content of this packet.") testContent := []byte("A test string is the content of this packet.")
testPacket := proxy.SimplePacket(testContent) testPacket := proxy.SimplePacket(testContent)
testMac := mocks.AlmostUselessMac{} testMac := mocks.AlmostUselessMac("abcd")
t.Run("Length", func(t *testing.T) { t.Run("SingleGeneratorLength", func(t *testing.T) {
testConn := mocks.NewMockPerfectBiPacketConn(10) testConn := mocks.NewMockPerfectBiPacketConn(10)
flowA := newFlow(congestion.NewNone(), testMac) flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac})
flowA.writer = testConn.SideB() flowA.writer = testConn.SideB()
flowA.isAlive = true flowA.isAlive = true
err := flowA.Consume(context.Background(), testPacket, testMac) err := flowA.Consume(context.Background(), testPacket)
require.Nil(t, err) require.Nil(t, err)
buf := make([]byte, 100) buf := make([]byte, 100)
@ -33,7 +33,50 @@ func TestFlow_Consume(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
// 12 header, 8 timestamp, 4 MAC // 12 header, 8 timestamp, 4 MAC
assert.Equal(t, len(testContent)+12+8+4, n) assert.Equal(t, len(testContent)+12+4, n)
})
t.Run("MultipleGeneratorsLength", func(t *testing.T) {
testMac2 := mocks.AlmostUselessMac("efgh")
testConn := mocks.NewMockPerfectBiPacketConn(10)
flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
flowA.writer = testConn.SideB()
flowA.isAlive = true
err := flowA.Consume(context.Background(), testPacket)
require.Nil(t, err)
buf := make([]byte, 100)
n, _, err := testConn.SideA().ReadFromUDP(buf)
require.Nil(t, err)
// 12 header, 8 timestamp, 4 MAC
assert.Equal(t, len(testContent)+12+4+4, n)
})
t.Run("MultipleGeneratorsOrder", func(t *testing.T) {
testMac2 := mocks.AlmostUselessMac("efgh")
testConn := mocks.NewMockPerfectBiPacketConn(10)
flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
flowA.writer = testConn.SideB()
flowA.isAlive = true
err := flowA.Consume(context.Background(), testPacket)
require.Nil(t, err)
buf := make([]byte, 100)
n, _, err := testConn.SideA().ReadFromUDP(buf)
require.Nil(t, err)
// 12 header, 8 timestamp, 4 MAC
require.Equal(t, len(testContent)+12+4+4, n)
macs := string(buf[n-8 : n])
assert.Equal(t, "abcdefgh", macs)
}) })
} }
@ -45,7 +88,7 @@ func TestFlow_Produce(t *testing.T) {
seq: 128, seq: 128,
data: proxy.SimplePacket(testContent), data: proxy.SimplePacket(testContent),
} }
testMac := mocks.AlmostUselessMac{} testMac := mocks.AlmostUselessMac("abcd")
testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac) testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac)
@ -58,7 +101,7 @@ func TestFlow_Produce(t *testing.T) {
_, err := testConn.SideA().Write(testMarshalled) _, err := testConn.SideA().Write(testMarshalled)
require.Nil(t, err) require.Nil(t, err)
flowA := newFlow(congestion.NewNone(), testMac) flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac}, []proxy.MacGenerator{testMac})
flowA.writer = testConn.SideB() flowA.writer = testConn.SideB()
flowA.isAlive = true flowA.isAlive = true
@ -67,7 +110,43 @@ func TestFlow_Produce(t *testing.T) {
err := flowA.readQueuePacket(context.Background(), testConn.SideB()) err := flowA.readQueuePacket(context.Background(), testConn.SideB())
assert.Nil(t, err) assert.Nil(t, err)
}() }()
p, err := flowA.Produce(context.Background(), testMac) p, err := flowA.Produce(context.Background())
require.Nil(t, err)
assert.Len(t, p.Contents(), len(testContent))
done <- struct{}{}
}()
timer := time.NewTimer(500 * time.Millisecond)
select {
case <-done:
case <-timer.C:
fmt.Println("timed out")
t.FailNow()
}
})
t.Run("MultipleVerifiersStrip", func(t *testing.T) {
done := make(chan struct{})
go func() {
testMac2 := mocks.AlmostUselessMac("efgh")
testConn := mocks.NewMockPerfectBiPacketConn(10)
_, err := testConn.SideA().Write(proxy.AppendMac(testMarshalled, testMac2))
require.Nil(t, err)
flowA := newFlow(congestion.NewNone(), []proxy.MacVerifier{testMac, testMac2}, []proxy.MacGenerator{testMac, testMac2})
flowA.writer = testConn.SideB()
flowA.isAlive = true
go func() {
err := flowA.readQueuePacket(context.Background(), testConn.SideB())
assert.Nil(t, err)
}()
p, err := flowA.Produce(context.Background())
require.Nil(t, err) require.Nil(t, err)
assert.Len(t, p.Contents(), len(testContent)) assert.Len(t, p.Contents(), len(testContent))

View File

@ -27,7 +27,16 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress {
} }
} }
func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion, enableConsumers bool, enableProducers bool) error { func NewListener(
ctx context.Context,
p *proxy.Proxy,
local string,
vs []func() proxy.MacVerifier,
gs []func() proxy.MacGenerator,
c func() Congestion,
enableConsumers bool,
enableProducers bool,
) error {
laddr, err := net.ResolveUDPAddr("udp", local) laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil { if err != nil {
return err return err
@ -70,10 +79,16 @@ func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() pro
continue continue
} }
v := v() var verifiers = make([]proxy.MacVerifier, len(vs))
g := g() for i, v := range vs {
verifiers[i] = v()
}
var generators = make([]proxy.MacGenerator, len(gs))
for i, g := range gs {
generators[i] = g()
}
f := newFlow(c(), v) f := newFlow(c(), verifiers, generators)
f.writer = pconn f.writer = pconn
f.raddr = addr f.raddr = addr
@ -81,15 +96,15 @@ func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() pro
log.Printf("received new udp connection: %v\n", f) log.Printf("received new udp connection: %v\n", f)
go f.earlyUpdateLoop(ctx, g, 0) go f.earlyUpdateLoop(ctx, 0)
receivedConnections[raddr] = &f receivedConnections[raddr] = &f
if enableConsumers { if enableConsumers {
p.AddConsumer(ctx, &f, g) p.AddConsumer(ctx, &f)
} }
if enableProducers { if enableProducers {
p.AddProducer(ctx, &f, v) p.AddProducer(ctx, &f)
} }
log.Println("handling...") log.Println("handling...")