diff --git a/proxy/packet.go b/proxy/packet.go index 260b9f9..f24115f 100644 --- a/proxy/packet.go +++ b/proxy/packet.go @@ -10,49 +10,26 @@ type Packet interface { Contents() []byte } -type SimplePacket struct { - Data []byte - timestamp time.Time -} - -// create a packet from the raw data of an IP packet -func NewSimplePacket(data []byte) Packet { - return SimplePacket{ - Data: data, - timestamp: time.Now(), - } -} - -// rebuild a packet from the wrapped format -func UnmarshalSimplePacket(data []byte) (SimplePacket, error) { - p := SimplePacket{ - Data: data[:len(data)-8], - } - - unixTime := int64(binary.LittleEndian.Uint64(data[len(data)-8:])) - p.timestamp = time.Unix(unixTime, 0) - - return p, nil -} +type SimplePacket []byte // get the raw data of the IP packet func (p SimplePacket) Marshal() []byte { - footer := make([]byte, 8) - - unixTime := uint64(p.timestamp.Unix()) - binary.LittleEndian.PutUint64(footer, unixTime) - - return append(p.Data, footer...) + return p } func (p SimplePacket) Contents() []byte { - return p.Data + return p } func AppendMac(b []byte, g MacGenerator) []byte { + footer := make([]byte, 8) + unixTime := uint64(time.Now().Unix()) + binary.LittleEndian.PutUint64(footer, unixTime) + + b = append(b, footer...) + mac := g.Generate(b) - b = append(b, mac...) - return b + return append(b, mac...) } func StripMac(b []byte, v MacVerifier) ([]byte, error) { @@ -63,5 +40,5 @@ func StripMac(b []byte, v MacVerifier) ([]byte, error) { return nil, err } - return data, nil + return data[:len(data)-8], nil } diff --git a/proxy/packet_test.go b/proxy/packet_test.go index d97a79e..af2d626 100644 --- a/proxy/packet_test.go +++ b/proxy/packet_test.go @@ -8,51 +8,20 @@ import ( "testing" ) -func TestPacket_Marshal(t *testing.T) { - testContent := []byte("A test string is the content of this packet.") - testPacket := NewSimplePacket(testContent) - - t.Run("Length", func(t *testing.T) { - marshalled := testPacket.Marshal() - - assert.Len(t, marshalled, len(testContent)+8) - }) -} - -func TestUnmarshalPacket(t *testing.T) { - testContent := []byte("A test string is the content of this packet.") - testPacket := NewSimplePacket(testContent) - testMarshalled := testPacket.Marshal() - - t.Run("Length", func(t *testing.T) { - p, err := UnmarshalSimplePacket(testMarshalled) - - require.Nil(t, err) - assert.Len(t, p.Contents(), len(testContent)) - }) - - t.Run("Contents", func(t *testing.T) { - p, err := UnmarshalSimplePacket(testMarshalled) - - require.Nil(t, err) - assert.Equal(t, p.Contents(), testContent) - }) -} - func TestAppendMac(t *testing.T) { testContent := []byte("A test string is the content of this packet.") testMac := mocks.AlmostUselessMac{} - testPacket := NewSimplePacket(testContent) + testPacket := SimplePacket(testContent) testMarshalled := testPacket.Marshal() appended := AppendMac(testMarshalled, testMac) t.Run("Length", func(t *testing.T) { - assert.Len(t, appended, len(testMarshalled)+4) + assert.Len(t, appended, len(testMarshalled)+8+4) }) t.Run("Mac", func(t *testing.T) { - assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled):]) + assert.Equal(t, []byte{'a', 'b', 'c', 'd'}, appended[len(testMarshalled)+8:]) }) t.Run("Original", func(t *testing.T) { @@ -63,7 +32,7 @@ func TestAppendMac(t *testing.T) { func TestStripMac(t *testing.T) { testContent := []byte("A test string is the content of this packet.") testMac := mocks.AlmostUselessMac{} - testPacket := NewSimplePacket(testContent) + testPacket := SimplePacket(testContent) testMarshalled := testPacket.Marshal() appended := AppendMac(testMarshalled, testMac) diff --git a/shared/errors.go b/shared/errors.go index aaae8ba..0db0c92 100644 --- a/shared/errors.go +++ b/shared/errors.go @@ -4,3 +4,4 @@ import "errors" var ErrBadChecksum = errors.New("the packet had a bad checksum") var ErrDeadConnection = errors.New("the connection is dead") +var ErrNotEnoughBytes = errors.New("not enough bytes") diff --git a/tcp/flow.go b/tcp/flow.go index c626e49..057b233 100644 --- a/tcp/flow.go +++ b/tcp/flow.go @@ -2,7 +2,6 @@ package tcp import ( "encoding/binary" - "errors" "fmt" "io" "mpbl3p/proxy" @@ -12,8 +11,6 @@ import ( "time" ) -var ErrNotEnoughBytes = errors.New("not enough bytes") - type Conn interface { Read(b []byte) (n int, err error) Write(b []byte) (n int, err error) @@ -150,7 +147,7 @@ func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) { return nil, err } - return proxy.UnmarshalSimplePacket(b) + return proxy.SimplePacket(b), nil } func (f *Flow) produceMarshalled() ([]byte, error) { @@ -158,7 +155,7 @@ func (f *Flow) produceMarshalled() ([]byte, error) { if n, err := io.LimitReader(f.conn, 4).Read(lengthBytes); err != nil { return nil, err } else if n != 4 { - return nil, ErrNotEnoughBytes + return nil, shared.ErrNotEnoughBytes } length := binary.LittleEndian.Uint32(lengthBytes) diff --git a/tcp/flow_test.go b/tcp/flow_test.go index 6e0e7b3..17214e0 100644 --- a/tcp/flow_test.go +++ b/tcp/flow_test.go @@ -11,7 +11,7 @@ import ( func TestFlow_Consume(t *testing.T) { testContent := []byte("A test string is the content of this packet.") - testPacket := proxy.NewSimplePacket(testContent) + testPacket := proxy.SimplePacket(testContent) testMac := mocks.AlmostUselessMac{} t.Run("Length", func(t *testing.T) { diff --git a/tun/tun.go b/tun/tun.go index 1b55eb1..b4141f2 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -68,7 +68,7 @@ func (t *SourceSink) Source() (proxy.Packet, error) { return nil, io.EOF } - return proxy.NewSimplePacket(buf[:read]), nil + return proxy.SimplePacket(buf[:read]), nil } var good, bad float64 diff --git a/udp/flow.go b/udp/flow.go index e6e01c7..d8c842e 100644 --- a/udp/flow.go +++ b/udp/flow.go @@ -115,7 +115,7 @@ func (f *InitiatedFlow) Reconnect() error { ack: 0, nack: 0, seq: seq, - data: proxy.NewSimplePacket(nil), + data: proxy.SimplePacket(nil), } _ = f.sendPacket(p, f.g) @@ -266,7 +266,7 @@ func (f *Flow) earlyUpdateLoop(g proxy.MacGenerator, keepalive time.Duration) { ack: f.congestion.NextAck(), nack: f.congestion.NextNack(), seq: seq, - data: proxy.NewSimplePacket(nil), + data: proxy.SimplePacket(nil), } _ = f.sendPacket(p, g) diff --git a/udp/flow_test.go b/udp/flow_test.go index e102ffe..b2f20c2 100644 --- a/udp/flow_test.go +++ b/udp/flow_test.go @@ -13,7 +13,7 @@ import ( func TestFlow_Consume(t *testing.T) { testContent := []byte("A test string is the content of this packet.") - testPacket := proxy.NewSimplePacket(testContent) + testPacket := proxy.SimplePacket(testContent) testMac := mocks.AlmostUselessMac{} t.Run("Length", func(t *testing.T) { @@ -42,7 +42,7 @@ func TestFlow_Produce(t *testing.T) { ack: 42, nack: 26, seq: 128, - data: proxy.NewSimplePacket(testContent), + data: proxy.SimplePacket(testContent), } testMac := mocks.AlmostUselessMac{} diff --git a/udp/packet.go b/udp/packet.go index 8ab0091..08757d8 100644 --- a/udp/packet.go +++ b/udp/packet.go @@ -3,6 +3,7 @@ package udp import ( "encoding/binary" "mpbl3p/proxy" + "mpbl3p/shared" ) type Packet struct { @@ -14,14 +15,15 @@ type Packet struct { } func UnmarshalPacket(b []byte) (p Packet, err error) { + if len(b) < 12 { + return Packet{}, shared.ErrNotEnoughBytes + } + p.ack = binary.LittleEndian.Uint32(b[0:4]) p.nack = binary.LittleEndian.Uint32(b[4:8]) p.seq = binary.LittleEndian.Uint32(b[8:12]) - p.data, err = proxy.UnmarshalSimplePacket(b[12:]) - if err != nil { - return Packet{}, err - } + p.data = proxy.SimplePacket(b[12:]) return p, nil } diff --git a/udp/packet_test.go b/udp/packet_test.go index 8770774..e1ded3d 100644 --- a/udp/packet_test.go +++ b/udp/packet_test.go @@ -13,14 +13,14 @@ func TestPacket_Marshal(t *testing.T) { ack: 18, nack: 29, seq: 431, - data: proxy.NewSimplePacket(testContent), + data: proxy.SimplePacket(testContent), } t.Run("Length", func(t *testing.T) { marshalled := testPacket.Marshal() // 12 header + 8 timestamp - assert.Len(t, marshalled, len(testContent)+12+8) + assert.Len(t, marshalled, len(testContent)+12) }) } @@ -30,7 +30,7 @@ func TestUnmarshalPacket(t *testing.T) { ack: 18, nack: 29, seq: 431, - data: proxy.NewSimplePacket(testContent), + data: proxy.SimplePacket(testContent), } testMarshalled := testPacket.Marshal()