diff --git a/.gitignore b/.gitignore index 7d235d7..42e21ef 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -config.ini +*.conf logs/ # Created by https://www.toptal.com/developers/gitignore/api/intellij+all,go diff --git a/config/builder.go b/config/builder.go index b83995c..c5bea4f 100644 --- a/config/builder.go +++ b/config/builder.go @@ -75,13 +75,17 @@ func buildTcp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac return err } - p.AddConsumer(ctx, f, g()) - p.AddProducer(ctx, f, v()) + if !peer.DisableConsumer { + p.AddConsumer(ctx, f, g()) + } + if !peer.DisableProducer { + p.AddProducer(ctx, f, v()) + } return nil } - err := tcp.NewListener(ctx, p, laddr(), v, g) + err := tcp.NewListener(ctx, p, laddr(), v, g, !peer.DisableConsumer, !peer.DisableProducer) if err != nil { return err } @@ -121,13 +125,17 @@ func buildUdp(ctx context.Context, p *proxy.Proxy, peer Peer, g func() proxy.Mac return err } - p.AddConsumer(ctx, f, g()) - p.AddProducer(ctx, f, v()) + if !peer.DisableConsumer { + p.AddConsumer(ctx, f, g()) + } + if !peer.DisableProducer { + p.AddProducer(ctx, f, v()) + } return nil } - err := udp.NewListener(ctx, p, laddr(), v, g, c) + err := udp.NewListener(ctx, p, laddr(), v, g, c, !peer.DisableConsumer, !peer.DisableProducer) if err != nil { return err } diff --git a/config/config.go b/config/config.go index 1328116..ed7bea5 100644 --- a/config/config.go +++ b/config/config.go @@ -57,6 +57,9 @@ type Peer struct { KeepAlive uint Timeout uint RetryWait uint + + DisableConsumer bool `validate:"omitempty,nefield=DisableProducer"` + DisableProducer bool `validate:"omitempty,nefield=DisableConsumer"` } func (p Peer) GetLocalHost() string { diff --git a/crypto/sharedkey/blake2s_test.go b/crypto/sharedkey/blake2s_test.go new file mode 100644 index 0000000..f6cc546 --- /dev/null +++ b/crypto/sharedkey/blake2s_test.go @@ -0,0 +1,45 @@ +package sharedkey + +import ( + "github.com/stretchr/testify/assert" + "math/rand" + "mpbl3p/shared" + "testing" +) + +func TestBlake2s_GenerateVerify(t *testing.T) { + t.Run("GeneratedVerifies", func(t *testing.T) { + // ASSIGN + key := make([]byte, 16) + rand.Read(key) + buf := make([]byte, 500) + rand.Read(buf) + + // ACT + b := Blake2s{key} + code := b.Generate(buf) + + // ASSERT + err := b.Verify(buf, code) + assert.Nil(t, err) + }) + + t.Run("FlippedBitFailsVerify", func(t *testing.T) { + // ASSIGN + key := make([]byte, 16) + rand.Read(key) + buf := make([]byte, 500) + rand.Read(buf) + + // ACT + b := Blake2s{key} + code := b.Generate(buf) + + offset := rand.Intn(len(buf) * 8) + buf[offset/8] ^= 1 << (offset % 8) + + // ASSERT + err := b.Verify(buf, code) + assert.Equal(t, shared.ErrBadChecksum, err) + }) +} diff --git a/go.mod b/go.mod index 7682057..7a73cee 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module mpbl3p go 1.15 require ( - github.com/go-playground/validator/v10 v10.4.1 + github.com/go-playground/validator/v10 v10.6.0 github.com/jessevdk/go-flags v1.5.0 github.com/smartystreets/goconvey v1.6.4 // indirect github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index 04dd5cd..a201b0b 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87 github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/go-playground/validator/v10 v10.6.0 h1:UGIt4xR++fD9QrBOoo/ascJfGe3AGHEB9s6COnss4Rk= +github.com/go-playground/validator/v10 v10.6.0/go.mod h1:xm76BBt941f7yWdGnI2DVPFFg1UK3YY04qifoXU3lOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= @@ -49,6 +51,7 @@ golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXR golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= diff --git a/main.go b/main.go index 0c31e25..db3e4b4 100644 --- a/main.go +++ b/main.go @@ -42,7 +42,8 @@ func main() { c, err := config.LoadConfig(o.ConfigFile) if err != nil { - panic(err) + log.Fatalf("error validating config: %s", err.Error()) + return } log.Println("creating tun adapter...") diff --git a/tcp/flow.go b/tcp/flow.go index 01d76c2..c466377 100644 --- a/tcp/flow.go +++ b/tcp/flow.go @@ -53,7 +53,7 @@ func NewFlow() Flow { } } -func NewFlowConn(conn Conn) Flow { +func NewFlowConn(ctx context.Context, conn Conn) Flow { f := Flow{ conn: conn, isAlive: true, @@ -64,8 +64,8 @@ func NewFlowConn(conn Conn) Flow { produceErrors: make(chan error), } - go f.produceMarshalled() - go f.consumeMarshalled() + go f.produceMarshalled(ctx) + go f.consumeMarshalled(ctx) return f } @@ -89,7 +89,7 @@ func InitiateFlow(local func() string, remote string) (*InitiatedFlow, error) { return &f, nil } -func (f *InitiatedFlow) Reconnect() error { +func (f *InitiatedFlow) Reconnect(ctx context.Context) error { f.mu.Lock() defer f.mu.Unlock() @@ -115,8 +115,8 @@ func (f *InitiatedFlow) Reconnect() error { f.conn = conn f.isAlive = true - go f.produceMarshalled() - go f.consumeMarshalled() + go f.produceMarshalled(ctx) + go f.consumeMarshalled(ctx) return nil } @@ -187,7 +187,7 @@ func (f *Flow) Produce(ctx context.Context, v proxy.MacVerifier) (proxy.Packet, return proxy.SimplePacket(b), nil } -func (f *Flow) consumeMarshalled() { +func (f *Flow) consumeMarshalled(ctx context.Context) { for { data := <-f.toConsume @@ -204,7 +204,7 @@ func (f *Flow) consumeMarshalled() { } } -func (f *Flow) produceMarshalled() { +func (f *Flow) produceMarshalled(ctx context.Context) { buf := bufio.NewReader(f.conn) for { diff --git a/tcp/flow_test.go b/tcp/flow_test.go index 3d0b49c..eba5a0a 100644 --- a/tcp/flow_test.go +++ b/tcp/flow_test.go @@ -18,7 +18,7 @@ func TestFlow_Consume(t *testing.T) { t.Run("Length", func(t *testing.T) { testConn := mocks.NewMockPerfectBiStreamConn(100) - flowA := NewFlowConn(testConn.SideA()) + flowA := NewFlowConn(context.Background(), testConn.SideA()) err := flowA.Consume(context.Background(), testPacket, testMac) require.Nil(t, err) @@ -42,7 +42,7 @@ func TestFlow_Produce(t *testing.T) { t.Run("Length", func(t *testing.T) { testConn := mocks.NewMockPerfectBiStreamConn(100) - flowA := NewFlowConn(testConn.SideA()) + flowA := NewFlowConn(context.Background(), testConn.SideA()) _, err := testConn.SideB().Write(testMarshalled) require.Nil(t, err) @@ -55,7 +55,7 @@ func TestFlow_Produce(t *testing.T) { t.Run("Value", func(t *testing.T) { testConn := mocks.NewMockPerfectBiStreamConn(100) - flowA := NewFlowConn(testConn.SideA()) + flowA := NewFlowConn(context.Background(), testConn.SideA()) _, err := testConn.SideB().Write(testMarshalled) require.Nil(t, err) diff --git a/tcp/listener.go b/tcp/listener.go index 0d907ed..c99152d 100644 --- a/tcp/listener.go +++ b/tcp/listener.go @@ -7,7 +7,7 @@ import ( "net" ) -func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator) error { +func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, enableConsumers bool, enableProducers bool) error { laddr, err := net.ResolveTCPAddr("tcp", local) if err != nil { return err @@ -25,12 +25,16 @@ func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() pro panic(err) } - f := NewFlowConn(conn) + f := NewFlowConn(ctx, conn) log.Printf("received new tcp connection: %v\n", f) - p.AddConsumer(ctx, &f, g()) - p.AddProducer(ctx, &f, v()) + if enableConsumers { + p.AddConsumer(ctx, &f, g()) + } + if enableProducers { + p.AddProducer(ctx, &f, v()) + } } }() diff --git a/tun/tun.go b/tun/tun.go index 7cdecad..252140e 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -10,7 +10,6 @@ import ( type SourceSink struct { tun wgtun.Device - mtu int } func NewTun(name string, mtu int) (t wgtun.Device, err error) { @@ -26,7 +25,6 @@ func NewFromFile(fd uintptr, mtu int) (ss *SourceSink, err error) { return } - ss.mtu = mtu return } @@ -35,7 +33,12 @@ func (t *SourceSink) Close() error { } func (t *SourceSink) Source() (proxy.Packet, error) { - buf := make([]byte, t.mtu) + mtu, err := t.tun.MTU() + if err != nil { + return nil, err + } + + buf := make([]byte, mtu+4) read, err := t.tun.Read(buf, 4) if err != nil { @@ -46,7 +49,7 @@ func (t *SourceSink) Source() (proxy.Packet, error) { return nil, io.EOF } - return proxy.SimplePacket(buf[4:read]), nil + return proxy.SimplePacket(buf[4 : read+4]), nil } var good, bad float64 diff --git a/udp/listener.go b/udp/listener.go index 0a1e661..fd0b70d 100644 --- a/udp/listener.go +++ b/udp/listener.go @@ -27,7 +27,7 @@ func fromUdpAddress(address net.UDPAddr) ComparableUdpAddress { } } -func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion) error { +func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() proxy.MacVerifier, g func() proxy.MacGenerator, c func() Congestion, enableConsumers bool, enableProducers bool) error { laddr, err := net.ResolveUDPAddr("udp", local) if err != nil { return err @@ -80,8 +80,12 @@ func NewListener(ctx context.Context, p *proxy.Proxy, local string, v func() pro receivedConnections[raddr] = &f - p.AddConsumer(ctx, &f, g) - p.AddProducer(ctx, &f, v) + if enableConsumers { + p.AddConsumer(ctx, &f, g) + } + if enableProducers { + p.AddProducer(ctx, &f, v) + } log.Println("handling...") if err := f.queueDatagram(ctx, buf[:n]); err != nil {