diff --git a/Makefile b/Makefile index 6f4ceda..2c4d8dd 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ manual: docker run --rm -v /tmp:/tmp -v ${PWD}:/app -w /app golang:1.15-buster go build -o /tmp/mpbl3p - rsync -p /tmp/mpbl3p 10.21.10.3: - rsync -p /tmp/mpbl3p 10.21.10.4: + rsync -p /tmp/mpbl3p 10.21.12.101: + rsync -p /tmp/mpbl3p 10.21.12.102: manual-bsd: GOOS=freebsd go build -o /tmp/mpbl3p diff --git a/README.md b/README.md index 250b5d1..5dc8d52 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ component parts, or incorporated into the main application. # IPv4 Forwarding sysctl -w net.ipv4.ip_forward=1 + sysctl -w net.ipv4.conf.eth0.proxy_arp=1 # Tunnel addr/up ip addr add 172.19.152.2/31 dev nc0 @@ -84,4 +85,4 @@ component parts, or incorporated into the main application. #### Client -No configuration needed. Simply set the IP to that of the remote server/32 with no gateway. +No configuration needed. Simply set the IP to that of the remote server/32 with a gateway of 192.168.1.1. diff --git a/config/builder.go b/config/builder.go index 6f94d8b..c51e640 100644 --- a/config/builder.go +++ b/config/builder.go @@ -5,6 +5,9 @@ import ( "mpbl3p/proxy" "mpbl3p/tcp" "mpbl3p/tun" + "mpbl3p/udp" + "mpbl3p/udp/congestion" + "time" ) // TODO: Delete this code as soon as an alternative is available @@ -45,6 +48,11 @@ func (c Configuration) Build() (*proxy.Proxy, error) { if err != nil { return nil, err } + case "UDP": + err := buildUdp(p, peer) + if err != nil { + return nil, err + } } } @@ -58,10 +66,14 @@ func buildTcp(p *proxy.Proxy, peer Peer) error { fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort), ) + if err != nil { + return err + } + p.AddConsumer(f) p.AddProducer(f, UselessMac{}) - return err + return nil } err := tcp.NewListener(p, fmt.Sprintf("%s:%d", peer.LocalHost, peer.LocalPort), UselessMac{}) @@ -71,3 +83,48 @@ func buildTcp(p *proxy.Proxy, peer Peer) error { return nil } + +func buildUdp(p *proxy.Proxy, peer Peer) error { + var c func() udp.Congestion + switch peer.Congestion { + case "None": + c = func() udp.Congestion {return congestion.NewNone()} + default: + fallthrough + case "NewReno": + c = func() udp.Congestion {return congestion.NewNewReno()} + } + + if peer.RemoteHost != "" { + f, err := udp.InitiateFlow( + fmt.Sprintf("%s:", peer.LocalHost), + fmt.Sprintf("%s:%d", peer.RemoteHost, peer.RemotePort), + UselessMac{}, + UselessMac{}, + c(), + time.Duration(peer.KeepAlive)*time.Second, + ) + + if err != nil { + return err + } + + p.AddConsumer(f) + p.AddProducer(f, UselessMac{}) + + return nil + } + + err := udp.NewListener( + p, + fmt.Sprintf("%s:%d", peer.LocalHost, peer.LocalPort), + UselessMac{}, + UselessMac{}, + c, + ) + if err != nil { + return err + } + + return nil +} diff --git a/config/config.go b/config/config.go index 76de54c..73b086a 100644 --- a/config/config.go +++ b/config/config.go @@ -16,7 +16,7 @@ type Host struct { type Peer struct { PublicKey string `validate:"required"` - Method string `validate:"oneof=TCP"` + Method string `validate:"oneof=TCP UDP"` LocalHost string `validate:"omitempty,ip"` LocalPort uint `validate:"max=65535"` @@ -24,6 +24,8 @@ type Peer struct { RemoteHost string `validate:"required_with=RemotePort,omitempty,fqdn|ip"` RemotePort uint `validate:"required_with=RemoteHost,omitempty,max=65535"` + Congestion string `validate:"oneof=NewReno None"` + KeepAlive uint Timeout uint RetryWait uint diff --git a/go.mod b/go.mod index 22a3f7d..dd6edf3 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module mpbl3p go 1.15 require ( - github.com/go-playground/assert/v2 v2.0.1 github.com/go-playground/validator/v10 v10.4.1 github.com/pkg/taptun v0.0.0-20160424131934-bbbd335672ab github.com/smartystreets/goconvey v1.6.4 // indirect diff --git a/main.go b/main.go index d7f2633..493adad 100644 --- a/main.go +++ b/main.go @@ -13,7 +13,14 @@ func main() { log.Println("loading config...") - c, err := config.LoadConfig("config.ini") + var configLoc string + if v, ok := os.LookupEnv("CONFIG_LOC"); ok { + configLoc = v + } else { + configLoc = "config.ini" + } + + c, err := config.LoadConfig(configLoc) if err != nil { panic(err) } diff --git a/mocks/conn.go b/mocks/conn.go deleted file mode 100644 index 3ad384f..0000000 --- a/mocks/conn.go +++ /dev/null @@ -1,67 +0,0 @@ -package mocks - -import "time" - -type MockPerfectBiConn struct { - directionA chan byte - directionB chan byte -} - -func NewMockPerfectBiConn(bufSize int) MockPerfectBiConn { - return MockPerfectBiConn{ - directionA: make(chan byte, bufSize), - directionB: make(chan byte, bufSize), - } -} - -func (bc MockPerfectBiConn) SideA() MockPerfectConn { - return MockPerfectConn{inbound: bc.directionA, outbound: bc.directionB} -} - -func (bc MockPerfectBiConn) SideB() MockPerfectConn { - return MockPerfectConn{inbound: bc.directionB, outbound: bc.directionA} -} - -type MockPerfectConn struct { - inbound chan byte - outbound chan byte -} - -func (c MockPerfectConn) SetWriteDeadline(time.Time) error { - return nil -} - -func (c MockPerfectConn) Read(p []byte) (n int, err error) { - for i := range p { - if i == 0 { - p[i] = <-c.inbound - } else { - select { - case b := <-c.inbound: - p[i] = b - default: - return i, nil - } - } - } - return len(p), nil -} - -func (c MockPerfectConn) Write(p []byte) (n int, err error) { - for _, b := range p { - c.outbound <- b - } - return len(p), nil -} - -func (c MockPerfectConn) NonBlockingRead(p []byte) (n int, err error) { - for i := range p { - select { - case b := <-c.inbound: - p[i] = b - default: - return i, nil - } - } - return len(p), nil -} diff --git a/mocks/mac.go b/mocks/mac.go index 60d2167..5d7ab7a 100644 --- a/mocks/mac.go +++ b/mocks/mac.go @@ -1,6 +1,8 @@ package mocks -import "mpbl3p/shared" +import ( + "mpbl3p/shared" +) type AlmostUselessMac struct{} diff --git a/mocks/packetconn.go b/mocks/packetconn.go new file mode 100644 index 0000000..9360db4 --- /dev/null +++ b/mocks/packetconn.go @@ -0,0 +1,62 @@ +package mocks + +import "net" + +type MockPerfectBiPacketConn struct { + directionA chan []byte + directionB chan []byte +} + +func NewMockPerfectBiPacketConn(bufSize int) MockPerfectBiPacketConn { + return MockPerfectBiPacketConn{ + directionA: make(chan []byte, bufSize), + directionB: make(chan []byte, bufSize), + } +} + +func (bc MockPerfectBiPacketConn) SideA() MockPerfectPacketConn { + return MockPerfectPacketConn{inbound: bc.directionA, outbound: bc.directionB} +} + +func (bc MockPerfectBiPacketConn) SideB() MockPerfectPacketConn { + return MockPerfectPacketConn{inbound: bc.directionB, outbound: bc.directionA} +} + +type MockPerfectPacketConn struct { + inbound chan []byte + outbound chan []byte +} + +func (c MockPerfectPacketConn) Write(b []byte) (int, error) { + c.outbound <- b + return len(b), nil +} + +func (c MockPerfectPacketConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + c.outbound <- b + return len(b), nil +} + +func (c MockPerfectPacketConn) LocalAddr() net.Addr { + return &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1234, + } +} + +func (c MockPerfectPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { + p := <-c.inbound + return copy(b, p), &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1234, + }, nil +} + +func (c MockPerfectPacketConn) NonBlockingRead(p []byte) (n int, err error) { + select { + case b := <-c.inbound: + return copy(p, b), nil + default: + return 0, nil + } +} diff --git a/mocks/streamconn.go b/mocks/streamconn.go new file mode 100644 index 0000000..956d2d2 --- /dev/null +++ b/mocks/streamconn.go @@ -0,0 +1,95 @@ +package mocks + +import ( + "net" + "time" +) + +type MockPerfectBiStreamConn struct { + directionA chan byte + directionB chan byte +} + +func NewMockPerfectBiStreamConn(bufSize int) MockPerfectBiStreamConn { + return MockPerfectBiStreamConn{ + directionA: make(chan byte, bufSize), + directionB: make(chan byte, bufSize), + } +} + +func (bc MockPerfectBiStreamConn) SideA() MockPerfectStreamConn { + return MockPerfectStreamConn{inbound: bc.directionA, outbound: bc.directionB} +} + +func (bc MockPerfectBiStreamConn) SideB() MockPerfectStreamConn { + return MockPerfectStreamConn{inbound: bc.directionB, outbound: bc.directionA} +} + +type MockPerfectStreamConn struct { + inbound chan byte + outbound chan byte +} + +type Conn interface { + Read(b []byte) (n int, err error) + Write(b []byte) (n int, err error) + SetWriteDeadline(time.Time) error + + // For printing + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +func (c MockPerfectStreamConn) Read(p []byte) (n int, err error) { + for i := range p { + if i == 0 { + p[i] = <-c.inbound + } else { + select { + case b := <-c.inbound: + p[i] = b + default: + return i, nil + } + } + } + return len(p), nil +} + +func (c MockPerfectStreamConn) Write(p []byte) (n int, err error) { + for _, b := range p { + c.outbound <- b + } + return len(p), nil +} + +func (c MockPerfectStreamConn) SetWriteDeadline(time.Time) error { + return nil +} + +// Only used for printing flow information +func (c MockPerfectStreamConn) LocalAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 499, + } +} + +func (c MockPerfectStreamConn) RemoteAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 500, + } +} + +func (c MockPerfectStreamConn) NonBlockingRead(p []byte) (n int, err error) { + for i := range p { + select { + case b := <-c.inbound: + p[i] = b + default: + return i, nil + } + } + return len(p), nil +} diff --git a/proxy/packet.go b/proxy/packet.go index 1588278..260b9f9 100644 --- a/proxy/packet.go +++ b/proxy/packet.go @@ -5,30 +5,27 @@ import ( "time" ) -type Packet struct { +type Packet interface { + Marshal() []byte + Contents() []byte +} + +type SimplePacket struct { Data []byte timestamp time.Time } // create a packet from the raw data of an IP packet -func NewPacket(data []byte) Packet { - return Packet{ +func NewSimplePacket(data []byte) Packet { + return SimplePacket{ Data: data, timestamp: time.Now(), } } // rebuild a packet from the wrapped format -func UnmarshalPacket(raw []byte, verifier MacVerifier) (Packet, error) { - // the MAC is the last N bytes - data := raw[:len(raw)-verifier.CodeLength()] - sum := raw[len(raw)-verifier.CodeLength():] - - if err := verifier.Verify(data, sum); err != nil { - return Packet{}, err - } - - p := Packet{ +func UnmarshalSimplePacket(data []byte) (SimplePacket, error) { + p := SimplePacket{ Data: data[:len(data)-8], } @@ -39,22 +36,32 @@ func UnmarshalPacket(raw []byte, verifier MacVerifier) (Packet, error) { } // get the raw data of the IP packet -func (p Packet) Raw() []byte { +func (p SimplePacket) Marshal() []byte { + footer := make([]byte, 8) + + unixTime := uint64(p.timestamp.Unix()) + binary.LittleEndian.PutUint64(footer, unixTime) + + return append(p.Data, footer...) +} + +func (p SimplePacket) Contents() []byte { return p.Data } -// produce the wrapped format of a packet -func (p Packet) Marshal(generator MacGenerator) []byte { - // length of data + length of timestamp (8 byte) + length of checksum - slice := make([]byte, len(p.Data)+8+generator.CodeLength()) - - copy(slice, p.Data) - - unixTime := uint64(p.timestamp.Unix()) - binary.LittleEndian.PutUint64(slice[len(p.Data):], unixTime) - - mac := generator.Generate(slice) - copy(slice[len(p.Data)+8:], mac) - - return slice +func AppendMac(b []byte, g MacGenerator) []byte { + mac := g.Generate(b) + b = append(b, mac...) + return b +} + +func StripMac(b []byte, v MacVerifier) ([]byte, error) { + data := b[:len(b)-v.CodeLength()] + sum := b[len(b)-v.CodeLength():] + + if err := v.Verify(data, sum); err != nil { + return nil, err + } + + return data, nil } diff --git a/proxy/packet_test.go b/proxy/packet_test.go index 5e14843..d97a79e 100644 --- a/proxy/packet_test.go +++ b/proxy/packet_test.go @@ -4,31 +4,90 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "mpbl3p/mocks" + "mpbl3p/shared" "testing" ) func TestPacket_Marshal(t *testing.T) { testContent := []byte("A test string is the content of this packet.") - testPacket := NewPacket(testContent) - testMac := mocks.AlmostUselessMac{} + testPacket := NewSimplePacket(testContent) t.Run("Length", func(t *testing.T) { - marshalled := testPacket.Marshal(testMac) + marshalled := testPacket.Marshal() - assert.Len(t, marshalled, len(testContent)+8+4) + assert.Len(t, marshalled, len(testContent)+8) }) } func TestUnmarshalPacket(t *testing.T) { testContent := []byte("A test string is the content of this packet.") - testPacket := NewPacket(testContent) - testMac := mocks.AlmostUselessMac{} - testMarshalled := testPacket.Marshal(testMac) + testPacket := NewSimplePacket(testContent) + testMarshalled := testPacket.Marshal() t.Run("Length", func(t *testing.T) { - p, err := UnmarshalPacket(testMarshalled, testMac) + p, err := UnmarshalSimplePacket(testMarshalled) require.Nil(t, err) - assert.Len(t, p.Raw(), len(testContent)) + assert.Len(t, p.Contents(), len(testContent)) + }) + + t.Run("Contents", func(t *testing.T) { + p, err := UnmarshalSimplePacket(testMarshalled) + + require.Nil(t, err) + assert.Equal(t, p.Contents(), testContent) + }) +} + +func TestAppendMac(t *testing.T) { + testContent := []byte("A test string is the content of this packet.") + testMac := mocks.AlmostUselessMac{} + testPacket := NewSimplePacket(testContent) + testMarshalled := testPacket.Marshal() + + appended := AppendMac(testMarshalled, testMac) + + t.Run("Length", func(t *testing.T) { + assert.Len(t, appended, len(testMarshalled)+4) + }) + + t.Run("Mac", func(t *testing.T) { + assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled):]) + }) + + t.Run("Original", func(t *testing.T) { + assert.Equal(t, testMarshalled, appended[:len(testMarshalled)]) + }) +} + +func TestStripMac(t *testing.T) { + testContent := []byte("A test string is the content of this packet.") + testMac := mocks.AlmostUselessMac{} + testPacket := NewSimplePacket(testContent) + testMarshalled := testPacket.Marshal() + + appended := AppendMac(testMarshalled, testMac) + + t.Run("Length", func(t *testing.T) { + cut, err := StripMac(appended, testMac) + + require.Nil(t, err) + assert.Len(t, cut, len(testMarshalled)) + }) + + t.Run("IncorrectMac", func(t *testing.T) { + badMac := make([]byte, len(testMarshalled)+4) + copy(badMac, testMarshalled) + copy(badMac[:len(testMarshalled)], "dcba") + _, err := StripMac(badMac, testMac) + + assert.Error(t, err, shared.ErrBadChecksum) + }) + + t.Run("Original", func(t *testing.T) { + cut, err := StripMac(appended, testMac) + + require.Nil(t, err) + assert.Equal(t, testMarshalled, cut) }) } diff --git a/proxy/proxy.go b/proxy/proxy.go index e98a836..2bea389 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -75,13 +75,13 @@ func (p Proxy) AddConsumer(c Consumer) { if reconnectable { var err error for once := true; err != nil || once; once = false { - log.Printf("attempting to connect `%v`\n", c) + log.Printf("attempting to connect consumer `%v`\n", c) err = c.(Reconnectable).Reconnect() if !once { time.Sleep(time.Second) } } - log.Printf("connected `%v`\n", c) + log.Printf("connected consumer `%v`\n", c) } for c.IsAlive() { @@ -92,7 +92,7 @@ func (p Proxy) AddConsumer(c Consumer) { } } - log.Printf("closed connection `%v`\n", c) + log.Printf("closed consumer `%v`\n", c) }() } @@ -104,13 +104,13 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) { if reconnectable { var err error for once := true; err != nil || once; once = false { - log.Printf("attempting to connect `%v`\n", pr) + log.Printf("attempting to connect producer `%v`\n", pr) err = pr.(Reconnectable).Reconnect() if !once { time.Sleep(time.Second) } } - log.Printf("connected `%v`\n", pr) + log.Printf("connected producer `%v`\n", pr) } for pr.IsAlive() { @@ -123,6 +123,6 @@ func (p Proxy) AddProducer(pr Producer, v MacVerifier) { } } - log.Printf("closed connection `%v`\n", pr) + log.Printf("closed producer `%v`\n", pr) }() } diff --git a/tcp/flow.go b/tcp/flow.go index 26607e8..c626e49 100644 --- a/tcp/flow.go +++ b/tcp/flow.go @@ -3,6 +3,7 @@ package tcp import ( "encoding/binary" "errors" + "fmt" "io" "mpbl3p/proxy" "mpbl3p/shared" @@ -17,6 +18,10 @@ type Conn interface { Read(b []byte) (n int, err error) Write(b []byte) (n int, err error) SetWriteDeadline(time.Time) error + + // For printing + LocalAddr() net.Addr + RemoteAddr() net.Addr } type InitiatedFlow struct { @@ -28,11 +33,19 @@ type InitiatedFlow struct { Flow } +func (f *InitiatedFlow) String() string { + return fmt.Sprintf("TcpOutbound{%v -> %v}", f.Local, f.Remote) +} + type Flow struct { conn Conn isAlive bool } +func (f Flow) String() string { + return fmt.Sprintf("TcpInbound{%v -> %v}", f.conn.RemoteAddr(), f.conn.LocalAddr()) +} + func InitiateFlow(local, remote string) (*InitiatedFlow, error) { f := InitiatedFlow{ Local: local, @@ -75,10 +88,6 @@ func (f *InitiatedFlow) Reconnect() error { return nil } -func (f *Flow) IsAlive() bool { - return f.isAlive -} - func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { f.mu.RLock() defer f.mu.RUnlock() @@ -86,12 +95,25 @@ func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { return f.Flow.Consume(p, g) } +func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Produce(v) +} + +func (f *Flow) IsAlive() bool { + return f.isAlive +} + func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) (err error) { if !f.isAlive { return shared.ErrDeadConnection } - data := p.Marshal(g) + marshalled := p.Marshal() + data := proxy.AppendMac(marshalled, g) + err = f.consumeMarshalled(data) if err != nil { f.isAlive = false @@ -112,25 +134,23 @@ func (f *Flow) consumeMarshalled(data []byte) error { return err } -func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { - f.mu.RLock() - defer f.mu.RUnlock() - - return f.Flow.Produce(v) -} - func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { if !f.isAlive { - return proxy.Packet{}, shared.ErrDeadConnection + return nil, shared.ErrDeadConnection } data, err := f.produceMarshalled() if err != nil { f.isAlive = false - return proxy.Packet{}, err + return nil, err } - return proxy.UnmarshalPacket(data, v) + b, err := proxy.StripMac(data, v) + if err != nil { + return nil, err + } + + return proxy.UnmarshalSimplePacket(b) } func (f *Flow) produceMarshalled() ([]byte, error) { diff --git a/tcp/flow_test.go b/tcp/flow_test.go index 88fbf1c..6e0e7b3 100644 --- a/tcp/flow_test.go +++ b/tcp/flow_test.go @@ -2,7 +2,7 @@ package tcp import ( "encoding/binary" - "github.com/go-playground/assert/v2" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "mpbl3p/mocks" "mpbl3p/proxy" @@ -11,11 +11,11 @@ import ( func TestFlow_Consume(t *testing.T) { testContent := []byte("A test string is the content of this packet.") - testPacket := proxy.NewPacket(testContent) + testPacket := proxy.NewSimplePacket(testContent) testMac := mocks.AlmostUselessMac{} t.Run("Length", func(t *testing.T) { - testConn := mocks.NewMockPerfectBiConn(100) + testConn := mocks.NewMockPerfectBiStreamConn(100) flowA := Flow{conn: testConn.SideA(), isAlive: true} @@ -39,7 +39,7 @@ func TestFlow_Produce(t *testing.T) { testMac := mocks.AlmostUselessMac{} t.Run("Length", func(t *testing.T) { - testConn := mocks.NewMockPerfectBiConn(100) + testConn := mocks.NewMockPerfectBiStreamConn(100) flowA := Flow{conn: testConn.SideA(), isAlive: true} @@ -48,11 +48,11 @@ func TestFlow_Produce(t *testing.T) { p, err := flowA.Produce(testMac) require.Nil(t, err) - assert.Equal(t, len(testContent), len(p.Raw())) + assert.Equal(t, len(testContent), len(p.Contents())) }) t.Run("Value", func(t *testing.T) { - testConn := mocks.NewMockPerfectBiConn(100) + testConn := mocks.NewMockPerfectBiStreamConn(100) flowA := Flow{conn: testConn.SideA(), isAlive: true} @@ -61,6 +61,6 @@ func TestFlow_Produce(t *testing.T) { p, err := flowA.Produce(testMac) require.Nil(t, err) - assert.Equal(t, testContent, string(p.Raw())) + assert.Equal(t, testContent, string(p.Contents())) }) } diff --git a/tcp/listener.go b/tcp/listener.go index 5da9761..52892e0 100644 --- a/tcp/listener.go +++ b/tcp/listener.go @@ -31,7 +31,7 @@ func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier) error { f := Flow{conn: conn, isAlive: true} - log.Printf("received new connection: %v\n", f) + log.Printf("received new tcp connection: %v\n", f) p.AddConsumer(&f) p.AddProducer(&f, v) diff --git a/tun/tun.go b/tun/tun.go index 2f5c5cb..1b55eb1 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -61,14 +61,14 @@ func (t *SourceSink) Source() (proxy.Packet, error) { read, err := t.tun.Read(buf) if err != nil { - return proxy.Packet{}, err + return nil, err } if read == 0 { - return proxy.Packet{}, io.EOF + return nil, io.EOF } - return proxy.NewPacket(buf[:read]), nil + return proxy.NewSimplePacket(buf[:read]), nil } var good, bad float64 @@ -79,7 +79,7 @@ func (t *SourceSink) Sink(packet proxy.Packet) error { t.upMu.Unlock() } - _, err := t.tun.Write(packet.Raw()) + _, err := t.tun.Write(packet.Contents()) if err != nil { switch err.(type) { case *os.PathError: diff --git a/udp/congestion.go b/udp/congestion.go new file mode 100644 index 0000000..705a193 --- /dev/null +++ b/udp/congestion.go @@ -0,0 +1,17 @@ +package udp + +import "time" + +type Congestion interface { + Sequence() uint32 + ReceivedPacket(seq uint32) + + ReceivedAck(uint32) + NextAck() uint32 + + ReceivedNack(uint32) + NextNack() uint32 + + AwaitEarlyUpdate(keepalive time.Duration) uint32 + Reset() +} diff --git a/udp/congestion/newreno.go b/udp/congestion/newreno.go new file mode 100644 index 0000000..5b2e179 --- /dev/null +++ b/udp/congestion/newreno.go @@ -0,0 +1,226 @@ +package congestion + +import ( + "fmt" + "log" + "math" + "mpbl3p/utils" + "sync" + "sync/atomic" + "time" +) + +const RttExponentialFactor = 0.1 + +type NewReno struct { + sequence chan uint32 + keepalive chan bool + + outboundTimes, inboundTimes map[uint32]time.Time + outboundTimesLock sync.Mutex + inboundTimesLock sync.RWMutex + + ack, lastAck uint32 + nack, lastNack uint32 + + slowStart bool + rtt float64 + windowSize int32 + windowCount int32 + inFlight int32 + + ackNotifier chan struct{} + + lastSent time.Time + hasAcked bool + + acksToSend utils.Uint32Heap + acksToSendLock sync.Mutex +} + +func (c *NewReno) String() string { + return fmt.Sprintf("{NewReno %t %d %d %d %d}", c.slowStart, c.windowSize, c.inFlight, c.lastAck, c.lastNack) +} + +func NewNewReno() *NewReno { + c := NewReno{ + sequence: make(chan uint32), + ackNotifier: make(chan struct{}), + + outboundTimes: make(map[uint32]time.Time), + inboundTimes: make(map[uint32]time.Time), + + windowSize: 8, + rtt: (1 * time.Millisecond).Seconds(), + slowStart: true, + } + + go func() { + var s uint32 + for { + if s == 0 { + s++ + continue + } + + c.sequence <- s + s++ + } + }() + + return &c +} + +func (c *NewReno) Reset() { + c.outboundTimes = make(map[uint32]time.Time) + c.inboundTimes = make(map[uint32]time.Time) + c.windowSize = 8 + c.rtt = (1 * time.Millisecond).Seconds() + c.slowStart = true + c.hasAcked = false +} + +// It is assumed that ReceivedAck will only be called by one thread +func (c *NewReno) ReceivedAck(ack uint32) { + c.outboundTimesLock.Lock() + defer c.outboundTimesLock.Unlock() + + log.Printf("ack received for %d", ack) + c.hasAcked = true + + // RTT + // Update using an exponential average + rtt := time.Now().Sub(c.outboundTimes[ack]).Seconds() + + delete(c.outboundTimes, ack) + c.rtt = c.rtt*(1-RttExponentialFactor) + rtt*RttExponentialFactor + + // Free Window + atomic.AddInt32(&c.inFlight, -1) + select { + case c.ackNotifier <- struct{}{}: + default: + } + + // GROW + // CASE: exponential. increase window size by one per ack + // CASE: standard. increase window size by one per window of acks + if c.slowStart { + atomic.AddInt32(&c.windowSize, 1) + } else { + c.windowCount++ + if c.windowCount == c.windowSize { + c.windowCount = 0 + atomic.AddInt32(&c.windowSize, 1) + } + } +} + +// It is assumed that ReceivedNack will only be called by one thread +func (c *NewReno) ReceivedNack(nack uint32) { + log.Printf("nack received for %d", nack) + + // End slow start + c.slowStart = false + if s := c.windowSize; s > 1 { + atomic.StoreInt32(&c.windowSize, s/2) + } +} + +func (c *NewReno) ReceivedPacket(seq uint32) { + log.Printf("seq received for %d", seq) + + c.inboundTimes[seq] = time.Now() + + c.acksToSendLock.Lock() + c.acksToSend.Insert(seq) + c.acksToSendLock.Unlock() + + c.updateAckNack() +} + +func (c *NewReno) updateAckNack() { + c.acksToSendLock.Lock() + defer c.acksToSendLock.Unlock() + + c.inboundTimesLock.Lock() + defer c.inboundTimesLock.Unlock() + + findAck := func(start uint32) uint32 { + ack := start + for len(c.acksToSend) > 0 { + if a, _ := c.acksToSend.Peek(); a == ack+1 { + ack, _ = c.acksToSend.Extract() + delete(c.inboundTimes, ack) + } else { + break + } + } + return ack + } + + ack := findAck(c.ack) + if ack == c.ack { + // check if there is a nack to send + // decide this based on whether there have been 3RTTs between the offset packet + if len(c.acksToSend) > 0 { + nextAck, _ := c.acksToSend.Peek() + if time.Now().Sub(c.inboundTimes[nextAck]).Seconds() > c.rtt*3 { + atomic.StoreUint32(&c.nack, nextAck-1) + ack, _ = c.acksToSend.Extract() + ack = findAck(ack) + } + } + } + + atomic.StoreUint32(&c.ack, ack) +} + +func (c *NewReno) Sequence() uint32 { + c.outboundTimesLock.Lock() + defer c.outboundTimesLock.Unlock() + + for c.inFlight >= c.windowSize { + <-c.ackNotifier + } + atomic.AddInt32(&c.inFlight, 1) + + s := <-c.sequence + + n := time.Now() + c.lastSent = n + c.outboundTimes[s] = n + + return s +} + +func (c *NewReno) NextAck() uint32 { + a := c.ack + c.lastAck = a + return a +} + +func (c *NewReno) NextNack() uint32 { + n := c.nack + c.lastNack = n + return n +} + +func (c *NewReno) AwaitEarlyUpdate(keepalive time.Duration) uint32 { + for { + rtt := time.Duration(math.Round(c.rtt * float64(time.Second))) + time.Sleep(rtt) + + c.updateAckNack() + + // CASE 1: waiting ACKs or NACKs and no message sent in the last RTT + if ((c.lastAck != c.ack) || (c.lastNack != c.nack)) && time.Now().After(c.lastSent.Add(rtt)) { + return 0 + } + + // CASE 3: No message sent within the keepalive time + if keepalive != 0 && time.Now().After(c.lastSent.Add(keepalive)) { + return c.Sequence() + } + } +} diff --git a/udp/congestion/none.go b/udp/congestion/none.go new file mode 100644 index 0000000..4640a04 --- /dev/null +++ b/udp/congestion/none.go @@ -0,0 +1,44 @@ +package congestion + +import ( + "fmt" + "time" +) + +type None struct { + sequence chan uint32 +} + +func NewNone() *None { + c := None{ + sequence: make(chan uint32), + } + + go func() { + var s uint32 + for { + if s == 0 { + s++ + continue + } + + c.sequence <- s + s++ + } + }() + + return &c +} + +func (c *None) String() string { + return fmt.Sprintf("{None}") +} + +func (c *None) ReceivedPacket(uint32) {} +func (c *None) ReceivedAck(uint32) {} +func (c *None) ReceivedNack(uint32) {} +func (c *None) Reset() {} +func (c *None) NextNack() uint32 { return 0 } +func (c *None) NextAck() uint32 { return 0 } +func (c *None) AwaitEarlyUpdate(time.Duration) uint32 { select {} } +func (c *None) Sequence() uint32 { return <-c.sequence } diff --git a/udp/flow.go b/udp/flow.go new file mode 100644 index 0000000..e6e01c7 --- /dev/null +++ b/udp/flow.go @@ -0,0 +1,285 @@ +package udp + +import ( + "errors" + "fmt" + "log" + "mpbl3p/proxy" + "mpbl3p/shared" + "net" + "sync" + "time" +) + +type PacketWriter interface { + Write(b []byte) (int, error) + WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) + LocalAddr() net.Addr +} + +type PacketConn interface { + PacketWriter + ReadFromUDP(b []byte) (int, *net.UDPAddr, error) +} + +type InitiatedFlow struct { + Local string + Remote string + + g proxy.MacGenerator + keepalive time.Duration + + mu sync.RWMutex + Flow +} + +func (f *InitiatedFlow) String() string { + return fmt.Sprintf("UdpOutbound{%v -> %v}", f.Local, f.Remote) +} + +type Flow struct { + writer PacketWriter + raddr *net.UDPAddr + + isAlive bool + startup bool + congestion Congestion + + v proxy.MacVerifier + + inboundDatagrams chan []byte +} + +func (f Flow) String() string { + return fmt.Sprintf("UdpInbound{%v -> %v}", f.raddr, f.writer.LocalAddr()) +} + +func InitiateFlow( + local, remote string, + v proxy.MacVerifier, + g proxy.MacGenerator, + c Congestion, + keepalive time.Duration, +) (*InitiatedFlow, error) { + f := InitiatedFlow{ + Local: local, + Remote: remote, + Flow: newFlow(c, v), + g: g, + keepalive: keepalive, + } + + return &f, nil +} + +func newFlow(c Congestion, v proxy.MacVerifier) Flow { + return Flow{ + inboundDatagrams: make(chan []byte), + congestion: c, + v: v, + } +} + +func (f *InitiatedFlow) Reconnect() error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.isAlive { + return nil + } + + localAddr, err := net.ResolveUDPAddr("udp", f.Local) + if err != nil { + return err + } + + remoteAddr, err := net.ResolveUDPAddr("udp", f.Remote) + if err != nil { + return err + } + + conn, err := net.DialUDP("udp", localAddr, remoteAddr) + if err != nil { + return err + } + + f.writer = conn + f.startup = true + + // prod the connection once a second until we get an ack, then consider it alive + go func() { + seq := f.congestion.Sequence() + + for !f.isAlive { + p := Packet{ + ack: 0, + nack: 0, + seq: seq, + data: proxy.NewSimplePacket(nil), + } + + _ = f.sendPacket(p, f.g) + } + }() + + go func() { + _, _ = f.produceInternal(f.v, false) + }() + go f.earlyUpdateLoop(f.g, f.keepalive) + + if err := f.acceptPacket(conn); err != nil { + return err + } + + f.isAlive = true + f.startup = false + + go func() { + lockedAccept := func() { + f.mu.RLock() + defer f.mu.RUnlock() + + if err := f.acceptPacket(conn); err != nil { + log.Println(err) + } + } + + for f.isAlive { + log.Println("alive and listening for packets") + lockedAccept() + } + log.Println("no longer alive") + }() + + return nil +} + +func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Consume(p, g) +} + +func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.Flow.Produce(v) +} + +func (f *Flow) IsAlive() bool { + return f.isAlive +} + +func (f *Flow) Consume(pp proxy.Packet, g proxy.MacGenerator) error { + if !f.isAlive { + return shared.ErrDeadConnection + } + + log.Println(f.congestion) + + // Sequence is the congestion controllers opportunity to block + log.Println("awaiting sequence") + p := Packet{ + seq: f.congestion.Sequence(), + data: pp, + } + log.Println("received sequence") + + // Choose up to date ACK/NACK even after blocking + p.ack = f.congestion.NextAck() + p.nack = f.congestion.NextNack() + + return f.sendPacket(p, g) +} + +func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { + if !f.isAlive { + return nil, shared.ErrDeadConnection + } + + return f.produceInternal(v, true) +} + +func (f *Flow) produceInternal(v proxy.MacVerifier, mustReturn bool) (proxy.Packet, error) { + for once := true; mustReturn || once; once = false { + log.Println(f.congestion) + + b, err := proxy.StripMac(<-f.inboundDatagrams, v) + if err != nil { + return nil, err + } + + p, err := UnmarshalPacket(b) + if err != nil { + return nil, err + } + + // schedule an ack for this sequence number + if p.seq != 0 { + f.congestion.ReceivedPacket(p.seq) + } + // adjust our sending congestion control based on their acks + if p.ack != 0 { + f.congestion.ReceivedAck(p.ack) + } + // adjust our sending congestion control based on their nacks + if p.nack != 0 { + f.congestion.ReceivedNack(p.nack) + } + + // 12 bytes for header + the MAC + a timestamp + if len(b) == 12+f.v.CodeLength()+8 { + log.Println("handled keepalive/ack only packet") + continue + } + + return p, nil + } + + return nil, nil +} + +func (f *Flow) handleDatagram(p []byte) { + f.inboundDatagrams <- p +} + +func (f *Flow) sendPacket(p Packet, g proxy.MacGenerator) error { + b := p.Marshal() + b = proxy.AppendMac(b, g) + + if f.raddr == nil { + _, err := f.writer.Write(b) + return err + } else { + _, err := f.writer.WriteToUDP(b, f.raddr) + return err + } +} + +func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) { + var err error + for !errors.Is(err, shared.ErrDeadConnection) { + seq := f.congestion.AwaitEarlyUpdate(keepalive) + p := Packet{ + ack: f.congestion.NextAck(), + nack: f.congestion.NextNack(), + seq: seq, + data: proxy.NewSimplePacket(nil), + } + + _ = f.sendPacket(p, g) + } +} + +func (f *Flow) acceptPacket(c PacketConn) error { + buf := make([]byte, 6000) + n, _, err := c.ReadFromUDP(buf) + if err != nil { + return err + } + + f.handleDatagram(buf[:n]) + return nil +} diff --git a/udp/flow_test.go b/udp/flow_test.go new file mode 100644 index 0000000..e102ffe --- /dev/null +++ b/udp/flow_test.go @@ -0,0 +1,85 @@ +package udp + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "mpbl3p/mocks" + "mpbl3p/proxy" + "mpbl3p/udp/congestion" + "testing" + "time" +) + +func TestFlow_Consume(t *testing.T) { + testContent := []byte("A test string is the content of this packet.") + testPacket := proxy.NewSimplePacket(testContent) + testMac := mocks.AlmostUselessMac{} + + t.Run("Length", func(t *testing.T) { + testConn := mocks.NewMockPerfectBiPacketConn(10) + + flowA := newFlow(congestion.NewNone(), testMac) + + flowA.writer = testConn.SideB() + flowA.isAlive = true + + err := flowA.Consume(testPacket, testMac) + require.Nil(t, err) + + buf := make([]byte, 100) + n, _, err := testConn.SideA().ReadFromUDP(buf) + require.Nil(t, err) + + // 12 header, 8 timestamp, 4 MAC + assert.Equal(t, len(testContent)+12+8+4, n) + }) +} + +func TestFlow_Produce(t *testing.T) { + testContent := []byte("A test string is the content of this packet.") + testPacket := Packet{ + ack: 42, + nack: 26, + seq: 128, + data: proxy.NewSimplePacket(testContent), + } + testMac := mocks.AlmostUselessMac{} + + testMarshalled := proxy.AppendMac(testPacket.Marshal(), testMac) + + t.Run("Length", func(t *testing.T) { + done := make(chan struct{}) + + go func() { + testConn := mocks.NewMockPerfectBiPacketConn(10) + + _, err := testConn.SideA().Write(testMarshalled) + require.Nil(t, err) + + flowA := newFlow(congestion.NewNone(), testMac) + + flowA.writer = testConn.SideB() + flowA.isAlive = true + + go func() { + err := flowA.acceptPacket(testConn.SideB()) + assert.Nil(t, err) + }() + p, err := flowA.Produce(testMac) + + require.Nil(t, err) + assert.Len(t, p.Contents(), len(testContent)) + + done <- struct{}{} + }() + + timer := time.NewTimer(500 * time.Millisecond) + select { + case <-done: + case <-timer.C: + fmt.Println("timed out") + t.FailNow() + } + }) +} diff --git a/udp/listener.go b/udp/listener.go new file mode 100644 index 0000000..847731d --- /dev/null +++ b/udp/listener.go @@ -0,0 +1,88 @@ +package udp + +import ( + "log" + "mpbl3p/proxy" + "net" +) + +type ComparableUdpAddress struct { + IP [16]byte + Port int + Zone string +} + +func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress { + var ip [16]byte + for i, b := range []byte(address.IP) { + ip[i] = b + } + + return ComparableUdpAddress{ + IP: ip, + Port: address.Port, + Zone: address.Zone, + } +} + +func NewListener(p *proxy.Proxy, local string, v proxy.MacVerifier, g proxy.MacGenerator, c func() Congestion) error { + laddr, err := net.ResolveUDPAddr("udp", local) + if err != nil { + return err + } + + pconn, err := net.ListenUDP("udp", laddr) + if err != nil { + return err + } + + err = pconn.SetWriteBuffer(0) + if err != nil { + panic(err) + } + + receivedConnections := make(map[ComparableUdpAddress]*Flow) + + go func() { + for { + buf := make([]byte, 6000) + + log.Println("listening...") + n, addr, err := pconn.ReadFromUDP(buf) + if err != nil { + panic(err) + } + log.Println("listened") + + raddr := fromUdpAddress(*addr) + if f, exists := receivedConnections[raddr]; exists { + log.Println("existing flow") + log.Println("handling...") + f.handleDatagram(buf[:n]) + log.Println("handled") + continue + } + + f := newFlow(c(), v) + + f.writer = pconn + f.raddr = addr + f.isAlive = true + + log.Printf("received new udp connection: %v\n", f) + + go f.earlyUpdateLoop(g, 0) + + receivedConnections[raddr] = &f + + p.AddConsumer(&f) + p.AddProducer(&f, v) + + log.Println("handling...") + f.handleDatagram(buf[:n]) + log.Println("handled") + } + }() + + return nil +} diff --git a/udp/packet.go b/udp/packet.go new file mode 100644 index 0000000..8ab0091 --- /dev/null +++ b/udp/packet.go @@ -0,0 +1,41 @@ +package udp + +import ( + "encoding/binary" + "mpbl3p/proxy" +) + +type Packet struct { + ack uint32 + nack uint32 + seq uint32 + + data proxy.Packet +} + +func UnmarshalPacket(b []byte) (p Packet, err error) { + p.ack = binary.LittleEndian.Uint32(b[0:4]) + p.nack = binary.LittleEndian.Uint32(b[4:8]) + p.seq = binary.LittleEndian.Uint32(b[8:12]) + + p.data, err = proxy.UnmarshalSimplePacket(b[12:]) + if err != nil { + return Packet{}, err + } + return p, nil +} + +func (p Packet) Marshal() []byte { + data := p.data.Marshal() + header := make([]byte, 12) + + binary.LittleEndian.PutUint32(header[0:4], p.ack) + binary.LittleEndian.PutUint32(header[4:8], p.nack) + binary.LittleEndian.PutUint32(header[8:12], p.seq) + + return append(header, data...) +} + +func (p Packet) Contents() []byte { + return p.data.Contents() +} diff --git a/udp/packet_test.go b/udp/packet_test.go new file mode 100644 index 0000000..8770774 --- /dev/null +++ b/udp/packet_test.go @@ -0,0 +1,59 @@ +package udp + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "mpbl3p/proxy" + "testing" +) + +func TestPacket_Marshal(t *testing.T) { + testContent := []byte("A test string is the content of this packet.") + testPacket := Packet{ + ack: 18, + nack: 29, + seq: 431, + data: proxy.NewSimplePacket(testContent), + } + + t.Run("Length", func(t *testing.T) { + marshalled := testPacket.Marshal() + + // 12 header + 8 timestamp + assert.Len(t, marshalled, len(testContent)+12+8) + }) +} + +func TestUnmarshalPacket(t *testing.T) { + testContent := []byte("A test string is the content of this packet.") + testPacket := Packet{ + ack: 18, + nack: 29, + seq: 431, + data: proxy.NewSimplePacket(testContent), + } + testMarshalled := testPacket.Marshal() + + t.Run("Length", func(t *testing.T) { + p, err := UnmarshalPacket(testMarshalled) + + require.Nil(t, err) + assert.Len(t, p.Contents(), len(testContent)) + }) + + t.Run("Contents", func(t *testing.T) { + p, err := UnmarshalPacket(testMarshalled) + + require.Nil(t, err) + assert.Equal(t, p.Contents(), testContent) + }) + + t.Run("Header", func(t *testing.T) { + p, err := UnmarshalPacket(testMarshalled) + require.Nil(t, err) + + assert.Equal(t, p.ack, uint32(18)) + assert.Equal(t, p.nack, uint32(29)) + assert.Equal(t, p.seq, uint32(431)) + }) +} diff --git a/udp/wireshark_dissector.lua b/udp/wireshark_dissector.lua new file mode 100644 index 0000000..95f1c4b --- /dev/null +++ b/udp/wireshark_dissector.lua @@ -0,0 +1,35 @@ +mpbl3p_udp = Proto("mpbl3p_udp", "Multi Path Proxy Custom UDP") + +ack_F = ProtoField.uint32("mpbl3p_udp.ack", "Acknowledgement") +nack_F = ProtoField.uint32("mpbl3p_udp.nack", "Negative Acknowledgement") +seq_F = ProtoField.uint32("mpbl3p_udp.seq", "Sequence Number") +time_F = ProtoField.absolute_time("mpbl3p_udp.time", "Timestamp") +proxied_F = ProtoField.bytes("mpbl3p_udp.data", "Proxied Data") + +mpbl3p_udp.fields = { ack_F, nack_F, seq_F, time_F, proxied_F } + +function mpbl3p_udp.dissector(buffer, pinfo, tree) + if buffer:len() < 20 then + return + end + + pinfo.cols.protocol = "MPBL3P_UDP" + + local ack = buffer(0, 4):le_uint() + local nack = buffer(4, 4):le_uint() + local seq = buffer(8, 4):le_uint() + + local unix_time = buffer(buffer:len() - 8, 8):le_uint64() + + local subtree = tree:add(mpbl3p_udp, buffer(), "Multi Path Proxy Header, SEQ: " .. seq .. " ACK: " .. ack .. " NACK: " .. nack) + + subtree:add(ack_F, ack) + subtree:add(nack_F, nack) + subtree:add(seq_F, seq) + subtree:add(time_F, NSTime.new(unix_time:tonumber())) + if buffer:len() > 20 then + subtree:add(proxied_F, buffer(12, buffer:len() - 12 - 8)) + end +end + +DissectorTable.get("udp.port"):add(1234, mpbl3p_udp) diff --git a/utils/heap.go b/utils/heap.go new file mode 100644 index 0000000..96f2c34 --- /dev/null +++ b/utils/heap.go @@ -0,0 +1,65 @@ +package utils + +import "errors" + +var ErrorEmptyHeap = errors.New("attempted to extract from empty heap") + +// A MinHeap for Uint64 +type Uint32Heap []uint32 + +func (h *Uint32Heap) swap(x, y int) { + (*h)[x] = (*h)[x] ^ (*h)[y] + (*h)[y] = (*h)[y] ^ (*h)[x] + (*h)[x] = (*h)[x] ^ (*h)[y] +} + +func (h *Uint32Heap) Insert(new uint32) uint32 { + *h = append(*h, new) + + child := len(*h) - 1 + for child != 0 { + parent := (child - 1) / 2 + if (*h)[parent] > (*h)[child] { + h.swap(parent, child) + } else { + break + } + child = parent + } + + return (*h)[0] +} + +func (h *Uint32Heap) Extract() (uint32, error) { + if len(*h) == 0 { + return 0, ErrorEmptyHeap + } + min := (*h)[0] + + (*h)[0] = (*h)[len(*h)-1] + *h = (*h)[:len(*h)-1] + + parent := 0 + for { + left, right := parent*2+1, parent*2+2 + + if (left < len(*h) && (*h)[parent] > (*h)[left]) || (right < len(*h) && (*h)[parent] > (*h)[right]) { + if right < len(*h) && (*h)[left] > (*h)[right] { + h.swap(parent, right) + parent = right + } else { + h.swap(parent, left) + parent = left + } + } else { + return min, nil + } + } +} + +func (h *Uint32Heap) Peek() (uint32, error) { + if len(*h) == 0 { + return 0, ErrorEmptyHeap + } + return (*h)[0], nil +} diff --git a/utils/heap_test.go b/utils/heap_test.go new file mode 100644 index 0000000..61a1bbd --- /dev/null +++ b/utils/heap_test.go @@ -0,0 +1,54 @@ +package utils + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" + "time" +) + +func SlowHeapSort(in []uint32) []uint32 { + out := make([]uint32, len(in)) + + var heap Uint32Heap + + for _, x := range in { + heap.Insert(x) + } + for i := range out { + var err error + out[i], err = heap.Extract() + if err != nil { + panic(err) + } + } + + return out +} + +func TestUint32Heap(t *testing.T) { + t.Run("EquivalentToMerge", func(t *testing.T) { + const ArrayLength = 50 + + sortedArray := make([]uint32, ArrayLength) + array := make([]uint32, ArrayLength) + + for i := range array { + sortedArray[i] = uint32(i) + array[i] = uint32(i) + } + + rand.Seed(time.Now().Unix()) + + for i := 0; i < 100; i++ { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + rand.Shuffle(50, func(i, j int) { array[i], array[j] = array[j], array[i] }) + + heapSorted := SlowHeapSort(array) + + assert.Equal(t, sortedArray, heapSorted) + }) + } + }) +} diff --git a/utils/utils.go b/utils/utils.go deleted file mode 100644 index 2548b3b..0000000 --- a/utils/utils.go +++ /dev/null @@ -1,13 +0,0 @@ -package utils - -var NextId = make(chan int) - -func init() { - go func() { - i := 0 - for { - NextId <- i - i += 1 - } - }() -}