pkg/listenmux: multiplex listener based on first bytes

Change-Id: If96bfd216e55f9950a42ab5be712a3cdec257a10
This commit is contained in:
Jeff Wendling 2019-09-06 17:59:55 -06:00
parent 40ff56f6c7
commit 708c95d044
5 changed files with 394 additions and 0 deletions

58
pkg/listenmux/listener.go Normal file
View File

@ -0,0 +1,58 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package listenmux
import (
"net"
"sync"
)
type listener struct {
addr net.Addr
conns chan net.Conn
once sync.Once
done chan struct{}
err error
}
func newListener(addr net.Addr) *listener {
return &listener{
addr: addr,
conns: make(chan net.Conn),
done: make(chan struct{}),
}
}
// Conns returns the channel of net.Conn that the listener reads from.
func (l *listener) Conns() chan net.Conn { return l.conns }
// Accept waits for and returns the next connection to the listener.
func (l *listener) Accept() (conn net.Conn, err error) {
select {
case <-l.done:
return nil, l.err
default:
}
select {
case <-l.done:
return nil, l.err
case conn = <-l.conns:
return conn, nil
}
}
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
func (l *listener) Close() error {
l.once.Do(func() {
l.err = Closed
close(l.done)
})
return nil
}
// Addr returns the listener's network address.
func (l *listener) Addr() net.Addr {
return l.addr
}

View File

@ -0,0 +1,39 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package listenmux
import (
"net"
"testing"
"github.com/stretchr/testify/require"
)
func TestListener(t *testing.T) {
type addr struct{ net.Addr }
type conn struct{ net.Conn }
lis := newListener(addr{})
{ // ensure the addr is the same we passed in
require.Equal(t, lis.Addr(), addr{})
}
{ // ensure that we can accept a connection from the listener
go func() { lis.Conns() <- conn{} }()
c, err := lis.Accept()
require.NoError(t, err)
require.Equal(t, c, conn{})
}
{ // ensure that closing the listener is no problem
require.NoError(t, lis.Close())
}
{ // ensure that accept after close returns the right error
c, err := lis.Accept()
require.Equal(t, err, Closed)
require.Nil(t, c)
}
}

167
pkg/listenmux/mux.go Normal file
View File

@ -0,0 +1,167 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package listenmux
import (
"context"
"fmt"
"io"
"net"
"sync"
"github.com/zeebo/errs"
)
// Closed is returned by routed listeners when the mux is closed.
var Closed = errs.New("listener closed")
// Mux lets one multiplex a listener into different listeners based on the first
// bytes sent on the connection.
type Mux struct {
base net.Listener
prefixLen int
addr net.Addr
def *listener
mu sync.Mutex
routes map[string]*listener
once sync.Once
done chan struct{}
err error
}
// New creates a mux that reads the prefixLen bytes from any connections Accepted by the
// passed in listener and dispatches to the appropriate route.
func New(base net.Listener, prefixLen int) *Mux {
addr := base.Addr()
return &Mux{
base: base,
prefixLen: prefixLen,
addr: addr,
def: newListener(addr),
routes: make(map[string]*listener),
done: make(chan struct{}),
}
}
//
// set up the routes
//
// Default returns the net.Listener that is used if no route matches.
func (m *Mux) Default() net.Listener { return m.def }
// Route returns a listener that will be used if the first bytes are the given prefix. The
// length of the prefix must match the original passed in prefixLen.
func (m *Mux) Route(prefix string) net.Listener {
m.mu.Lock()
defer m.mu.Unlock()
if len(prefix) != m.prefixLen {
panic(fmt.Sprintf("invalid prefix: has %d but needs %d bytes", len(prefix), m.prefixLen))
}
lis, ok := m.routes[prefix]
if !ok {
lis = newListener(m.addr)
m.routes[prefix] = lis
go m.monitorListener(prefix, lis)
}
return lis
}
//
// run the muxer
//
// Run calls listen on the provided listener and passes connections to the routed
// listeners.
func (m *Mux) Run(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go m.monitorContext(ctx)
go m.monitorBase()
<-m.done
m.mu.Lock()
defer m.mu.Unlock()
for _, lis := range m.routes {
<-lis.done
}
return m.err
}
func (m *Mux) monitorContext(ctx context.Context) {
<-ctx.Done()
m.once.Do(func() {
_ = m.base.Close() // TODO(jeff): do we care about this error?
close(m.done)
})
}
func (m *Mux) monitorBase() {
for {
conn, err := m.base.Accept()
if err != nil {
// TODO(jeff): temporary errors?
m.once.Do(func() {
m.err = err
close(m.done)
})
return
}
go m.routeConn(conn)
}
}
func (m *Mux) monitorListener(prefix string, lis *listener) {
select {
case <-m.done:
lis.once.Do(func() {
if m.err != nil {
lis.err = m.err
} else {
lis.err = Closed
}
close(lis.done)
})
case <-lis.done:
}
m.mu.Lock()
delete(m.routes, prefix)
m.mu.Unlock()
}
func (m *Mux) routeConn(conn net.Conn) {
buf := make([]byte, m.prefixLen)
if _, err := io.ReadFull(conn, buf); err != nil {
// TODO(jeff): how to handle these errors?
_ = conn.Close()
return
}
m.mu.Lock()
lis, ok := m.routes[string(buf)]
if !ok {
lis = m.def
conn = newPrefixConn(buf, conn)
}
m.mu.Unlock()
// TODO(jeff): a timeout for the listener to get to the conn?
select {
case <-lis.done:
// TODO(jeff): better way to signal to the caller the listener is closed?
_ = conn.Close()
case lis.Conns() <- conn:
}
}

104
pkg/listenmux/mux_test.go Normal file
View File

@ -0,0 +1,104 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package listenmux
import (
"context"
"io"
"net"
"testing"
"github.com/stretchr/testify/require"
"github.com/zeebo/errs"
"golang.org/x/sync/errgroup"
)
func TestMux(t *testing.T) {
expect := func(lis net.Listener, data string) func() error {
return func() error {
conn, err := lis.Accept()
if err != nil {
return err
}
buf := make([]byte, len(data))
_, err = io.ReadFull(conn, buf)
if err != nil {
return err
}
require.Equal(t, data, string(buf))
return nil
}
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
lis := newFakeListener(
newPrefixConn([]byte("prefix1data1"), nil),
newPrefixConn([]byte("prefix2data2"), nil),
newPrefixConn([]byte("prefix3data3"), nil),
)
mux := New(lis, len("prefixN"))
var muxGroup errgroup.Group
muxGroup.Go(func() error { return mux.Run(ctx) })
var lisGroup errgroup.Group
lisGroup.Go(expect(mux.Route("prefix1"), "data1"))
lisGroup.Go(expect(mux.Route("prefix2"), "data2"))
lisGroup.Go(expect(mux.Default(), "prefix3data3"))
require.NoError(t, lisGroup.Wait())
cancel()
require.Equal(t, nil, muxGroup.Wait())
}
func TestMuxAcceptError(t *testing.T) {
err := errs.New("problem")
mux := New(newErrorListener(err), 0)
require.Equal(t, mux.Run(context.Background()), err)
}
//
// fake listener
//
type fakeListener struct {
done chan struct{}
err error
conns []net.Conn
}
func (fl *fakeListener) Addr() net.Addr { return nil }
func (fl *fakeListener) Close() error {
close(fl.done)
return nil
}
func (fl *fakeListener) Accept() (c net.Conn, err error) {
if fl.err != nil {
return nil, fl.err
}
if len(fl.conns) == 0 {
<-fl.done
return nil, Closed
}
c, fl.conns = fl.conns[0], fl.conns[1:]
return c, nil
}
func newFakeListener(conns ...net.Conn) *fakeListener {
return &fakeListener{
done: make(chan struct{}),
conns: conns,
}
}
func newErrorListener(err error) *fakeListener {
return &fakeListener{err: err}
}

View File

@ -0,0 +1,26 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package listenmux
import (
"bytes"
"io"
"net"
)
type prefixConn struct {
io.Reader
net.Conn
}
func newPrefixConn(data []byte, conn net.Conn) *prefixConn {
return &prefixConn{
Reader: io.MultiReader(bytes.NewReader(data), conn),
Conn: conn,
}
}
func (pc *prefixConn) Read(p []byte) (n int, err error) {
return pc.Reader.Read(p)
}