diff --git a/pkg/listenmux/listener.go b/pkg/listenmux/listener.go new file mode 100644 index 000000000..e3c021cb5 --- /dev/null +++ b/pkg/listenmux/listener.go @@ -0,0 +1,58 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package listenmux + +import ( + "net" + "sync" +) + +type listener struct { + addr net.Addr + conns chan net.Conn + once sync.Once + done chan struct{} + err error +} + +func newListener(addr net.Addr) *listener { + return &listener{ + addr: addr, + conns: make(chan net.Conn), + done: make(chan struct{}), + } +} + +// Conns returns the channel of net.Conn that the listener reads from. +func (l *listener) Conns() chan net.Conn { return l.conns } + +// Accept waits for and returns the next connection to the listener. +func (l *listener) Accept() (conn net.Conn, err error) { + select { + case <-l.done: + return nil, l.err + default: + } + select { + case <-l.done: + return nil, l.err + case conn = <-l.conns: + return conn, nil + } +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (l *listener) Close() error { + l.once.Do(func() { + l.err = Closed + close(l.done) + }) + return nil +} + +// Addr returns the listener's network address. +func (l *listener) Addr() net.Addr { + return l.addr +} diff --git a/pkg/listenmux/listener_test.go b/pkg/listenmux/listener_test.go new file mode 100644 index 000000000..460b2bd13 --- /dev/null +++ b/pkg/listenmux/listener_test.go @@ -0,0 +1,39 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package listenmux + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestListener(t *testing.T) { + type addr struct{ net.Addr } + type conn struct{ net.Conn } + + lis := newListener(addr{}) + + { // ensure the addr is the same we passed in + require.Equal(t, lis.Addr(), addr{}) + } + + { // ensure that we can accept a connection from the listener + go func() { lis.Conns() <- conn{} }() + c, err := lis.Accept() + require.NoError(t, err) + require.Equal(t, c, conn{}) + } + + { // ensure that closing the listener is no problem + require.NoError(t, lis.Close()) + } + + { // ensure that accept after close returns the right error + c, err := lis.Accept() + require.Equal(t, err, Closed) + require.Nil(t, c) + } +} diff --git a/pkg/listenmux/mux.go b/pkg/listenmux/mux.go new file mode 100644 index 000000000..1613eee0e --- /dev/null +++ b/pkg/listenmux/mux.go @@ -0,0 +1,167 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package listenmux + +import ( + "context" + "fmt" + "io" + "net" + "sync" + + "github.com/zeebo/errs" +) + +// Closed is returned by routed listeners when the mux is closed. +var Closed = errs.New("listener closed") + +// Mux lets one multiplex a listener into different listeners based on the first +// bytes sent on the connection. +type Mux struct { + base net.Listener + prefixLen int + addr net.Addr + def *listener + + mu sync.Mutex + routes map[string]*listener + + once sync.Once + done chan struct{} + err error +} + +// New creates a mux that reads the prefixLen bytes from any connections Accepted by the +// passed in listener and dispatches to the appropriate route. +func New(base net.Listener, prefixLen int) *Mux { + addr := base.Addr() + return &Mux{ + base: base, + prefixLen: prefixLen, + addr: addr, + def: newListener(addr), + + routes: make(map[string]*listener), + + done: make(chan struct{}), + } +} + +// +// set up the routes +// + +// Default returns the net.Listener that is used if no route matches. +func (m *Mux) Default() net.Listener { return m.def } + +// Route returns a listener that will be used if the first bytes are the given prefix. The +// length of the prefix must match the original passed in prefixLen. +func (m *Mux) Route(prefix string) net.Listener { + m.mu.Lock() + defer m.mu.Unlock() + + if len(prefix) != m.prefixLen { + panic(fmt.Sprintf("invalid prefix: has %d but needs %d bytes", len(prefix), m.prefixLen)) + } + + lis, ok := m.routes[prefix] + if !ok { + lis = newListener(m.addr) + m.routes[prefix] = lis + go m.monitorListener(prefix, lis) + } + return lis +} + +// +// run the muxer +// + +// Run calls listen on the provided listener and passes connections to the routed +// listeners. +func (m *Mux) Run(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go m.monitorContext(ctx) + go m.monitorBase() + + <-m.done + + m.mu.Lock() + defer m.mu.Unlock() + + for _, lis := range m.routes { + <-lis.done + } + + return m.err +} + +func (m *Mux) monitorContext(ctx context.Context) { + <-ctx.Done() + m.once.Do(func() { + _ = m.base.Close() // TODO(jeff): do we care about this error? + close(m.done) + }) +} + +func (m *Mux) monitorBase() { + for { + conn, err := m.base.Accept() + if err != nil { + // TODO(jeff): temporary errors? + m.once.Do(func() { + m.err = err + close(m.done) + }) + return + } + go m.routeConn(conn) + } +} + +func (m *Mux) monitorListener(prefix string, lis *listener) { + select { + case <-m.done: + lis.once.Do(func() { + if m.err != nil { + lis.err = m.err + } else { + lis.err = Closed + } + close(lis.done) + }) + case <-lis.done: + } + m.mu.Lock() + delete(m.routes, prefix) + m.mu.Unlock() +} + +func (m *Mux) routeConn(conn net.Conn) { + buf := make([]byte, m.prefixLen) + if _, err := io.ReadFull(conn, buf); err != nil { + // TODO(jeff): how to handle these errors? + _ = conn.Close() + return + } + + m.mu.Lock() + lis, ok := m.routes[string(buf)] + if !ok { + lis = m.def + conn = newPrefixConn(buf, conn) + } + m.mu.Unlock() + + // TODO(jeff): a timeout for the listener to get to the conn? + + select { + case <-lis.done: + // TODO(jeff): better way to signal to the caller the listener is closed? + _ = conn.Close() + case lis.Conns() <- conn: + } +} diff --git a/pkg/listenmux/mux_test.go b/pkg/listenmux/mux_test.go new file mode 100644 index 000000000..63e956849 --- /dev/null +++ b/pkg/listenmux/mux_test.go @@ -0,0 +1,104 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package listenmux + +import ( + "context" + "io" + "net" + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeebo/errs" + "golang.org/x/sync/errgroup" +) + +func TestMux(t *testing.T) { + expect := func(lis net.Listener, data string) func() error { + return func() error { + conn, err := lis.Accept() + if err != nil { + return err + } + + buf := make([]byte, len(data)) + _, err = io.ReadFull(conn, buf) + if err != nil { + return err + } + + require.Equal(t, data, string(buf)) + return nil + } + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lis := newFakeListener( + newPrefixConn([]byte("prefix1data1"), nil), + newPrefixConn([]byte("prefix2data2"), nil), + newPrefixConn([]byte("prefix3data3"), nil), + ) + + mux := New(lis, len("prefixN")) + + var muxGroup errgroup.Group + muxGroup.Go(func() error { return mux.Run(ctx) }) + + var lisGroup errgroup.Group + lisGroup.Go(expect(mux.Route("prefix1"), "data1")) + lisGroup.Go(expect(mux.Route("prefix2"), "data2")) + lisGroup.Go(expect(mux.Default(), "prefix3data3")) + require.NoError(t, lisGroup.Wait()) + + cancel() + require.Equal(t, nil, muxGroup.Wait()) +} + +func TestMuxAcceptError(t *testing.T) { + err := errs.New("problem") + mux := New(newErrorListener(err), 0) + require.Equal(t, mux.Run(context.Background()), err) +} + +// +// fake listener +// + +type fakeListener struct { + done chan struct{} + err error + conns []net.Conn +} + +func (fl *fakeListener) Addr() net.Addr { return nil } + +func (fl *fakeListener) Close() error { + close(fl.done) + return nil +} + +func (fl *fakeListener) Accept() (c net.Conn, err error) { + if fl.err != nil { + return nil, fl.err + } + if len(fl.conns) == 0 { + <-fl.done + return nil, Closed + } + c, fl.conns = fl.conns[0], fl.conns[1:] + return c, nil +} + +func newFakeListener(conns ...net.Conn) *fakeListener { + return &fakeListener{ + done: make(chan struct{}), + conns: conns, + } +} + +func newErrorListener(err error) *fakeListener { + return &fakeListener{err: err} +} diff --git a/pkg/listenmux/prefixconn.go b/pkg/listenmux/prefixconn.go new file mode 100644 index 000000000..086cc3a85 --- /dev/null +++ b/pkg/listenmux/prefixconn.go @@ -0,0 +1,26 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package listenmux + +import ( + "bytes" + "io" + "net" +) + +type prefixConn struct { + io.Reader + net.Conn +} + +func newPrefixConn(data []byte, conn net.Conn) *prefixConn { + return &prefixConn{ + Reader: io.MultiReader(bytes.NewReader(data), conn), + Conn: conn, + } +} + +func (pc *prefixConn) Read(p []byte) (n int, err error) { + return pc.Reader.Read(p) +}