pkg/listenmux: multiplex listener based on first bytes
Change-Id: If96bfd216e55f9950a42ab5be712a3cdec257a10
This commit is contained in:
parent
40ff56f6c7
commit
708c95d044
58
pkg/listenmux/listener.go
Normal file
58
pkg/listenmux/listener.go
Normal file
@ -0,0 +1,58 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package listenmux
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type listener struct {
|
||||
addr net.Addr
|
||||
conns chan net.Conn
|
||||
once sync.Once
|
||||
done chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func newListener(addr net.Addr) *listener {
|
||||
return &listener{
|
||||
addr: addr,
|
||||
conns: make(chan net.Conn),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Conns returns the channel of net.Conn that the listener reads from.
|
||||
func (l *listener) Conns() chan net.Conn { return l.conns }
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (l *listener) Accept() (conn net.Conn, err error) {
|
||||
select {
|
||||
case <-l.done:
|
||||
return nil, l.err
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-l.done:
|
||||
return nil, l.err
|
||||
case conn = <-l.conns:
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener.
|
||||
// Any blocked Accept operations will be unblocked and return errors.
|
||||
func (l *listener) Close() error {
|
||||
l.once.Do(func() {
|
||||
l.err = Closed
|
||||
close(l.done)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l *listener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
39
pkg/listenmux/listener_test.go
Normal file
39
pkg/listenmux/listener_test.go
Normal file
@ -0,0 +1,39 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package listenmux
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestListener(t *testing.T) {
|
||||
type addr struct{ net.Addr }
|
||||
type conn struct{ net.Conn }
|
||||
|
||||
lis := newListener(addr{})
|
||||
|
||||
{ // ensure the addr is the same we passed in
|
||||
require.Equal(t, lis.Addr(), addr{})
|
||||
}
|
||||
|
||||
{ // ensure that we can accept a connection from the listener
|
||||
go func() { lis.Conns() <- conn{} }()
|
||||
c, err := lis.Accept()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c, conn{})
|
||||
}
|
||||
|
||||
{ // ensure that closing the listener is no problem
|
||||
require.NoError(t, lis.Close())
|
||||
}
|
||||
|
||||
{ // ensure that accept after close returns the right error
|
||||
c, err := lis.Accept()
|
||||
require.Equal(t, err, Closed)
|
||||
require.Nil(t, c)
|
||||
}
|
||||
}
|
167
pkg/listenmux/mux.go
Normal file
167
pkg/listenmux/mux.go
Normal file
@ -0,0 +1,167 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package listenmux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/zeebo/errs"
|
||||
)
|
||||
|
||||
// Closed is returned by routed listeners when the mux is closed.
|
||||
var Closed = errs.New("listener closed")
|
||||
|
||||
// Mux lets one multiplex a listener into different listeners based on the first
|
||||
// bytes sent on the connection.
|
||||
type Mux struct {
|
||||
base net.Listener
|
||||
prefixLen int
|
||||
addr net.Addr
|
||||
def *listener
|
||||
|
||||
mu sync.Mutex
|
||||
routes map[string]*listener
|
||||
|
||||
once sync.Once
|
||||
done chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
// New creates a mux that reads the prefixLen bytes from any connections Accepted by the
|
||||
// passed in listener and dispatches to the appropriate route.
|
||||
func New(base net.Listener, prefixLen int) *Mux {
|
||||
addr := base.Addr()
|
||||
return &Mux{
|
||||
base: base,
|
||||
prefixLen: prefixLen,
|
||||
addr: addr,
|
||||
def: newListener(addr),
|
||||
|
||||
routes: make(map[string]*listener),
|
||||
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// set up the routes
|
||||
//
|
||||
|
||||
// Default returns the net.Listener that is used if no route matches.
|
||||
func (m *Mux) Default() net.Listener { return m.def }
|
||||
|
||||
// Route returns a listener that will be used if the first bytes are the given prefix. The
|
||||
// length of the prefix must match the original passed in prefixLen.
|
||||
func (m *Mux) Route(prefix string) net.Listener {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if len(prefix) != m.prefixLen {
|
||||
panic(fmt.Sprintf("invalid prefix: has %d but needs %d bytes", len(prefix), m.prefixLen))
|
||||
}
|
||||
|
||||
lis, ok := m.routes[prefix]
|
||||
if !ok {
|
||||
lis = newListener(m.addr)
|
||||
m.routes[prefix] = lis
|
||||
go m.monitorListener(prefix, lis)
|
||||
}
|
||||
return lis
|
||||
}
|
||||
|
||||
//
|
||||
// run the muxer
|
||||
//
|
||||
|
||||
// Run calls listen on the provided listener and passes connections to the routed
|
||||
// listeners.
|
||||
func (m *Mux) Run(ctx context.Context) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go m.monitorContext(ctx)
|
||||
go m.monitorBase()
|
||||
|
||||
<-m.done
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, lis := range m.routes {
|
||||
<-lis.done
|
||||
}
|
||||
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *Mux) monitorContext(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
m.once.Do(func() {
|
||||
_ = m.base.Close() // TODO(jeff): do we care about this error?
|
||||
close(m.done)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Mux) monitorBase() {
|
||||
for {
|
||||
conn, err := m.base.Accept()
|
||||
if err != nil {
|
||||
// TODO(jeff): temporary errors?
|
||||
m.once.Do(func() {
|
||||
m.err = err
|
||||
close(m.done)
|
||||
})
|
||||
return
|
||||
}
|
||||
go m.routeConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mux) monitorListener(prefix string, lis *listener) {
|
||||
select {
|
||||
case <-m.done:
|
||||
lis.once.Do(func() {
|
||||
if m.err != nil {
|
||||
lis.err = m.err
|
||||
} else {
|
||||
lis.err = Closed
|
||||
}
|
||||
close(lis.done)
|
||||
})
|
||||
case <-lis.done:
|
||||
}
|
||||
m.mu.Lock()
|
||||
delete(m.routes, prefix)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *Mux) routeConn(conn net.Conn) {
|
||||
buf := make([]byte, m.prefixLen)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
// TODO(jeff): how to handle these errors?
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
lis, ok := m.routes[string(buf)]
|
||||
if !ok {
|
||||
lis = m.def
|
||||
conn = newPrefixConn(buf, conn)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// TODO(jeff): a timeout for the listener to get to the conn?
|
||||
|
||||
select {
|
||||
case <-lis.done:
|
||||
// TODO(jeff): better way to signal to the caller the listener is closed?
|
||||
_ = conn.Close()
|
||||
case lis.Conns() <- conn:
|
||||
}
|
||||
}
|
104
pkg/listenmux/mux_test.go
Normal file
104
pkg/listenmux/mux_test.go
Normal file
@ -0,0 +1,104 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package listenmux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zeebo/errs"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func TestMux(t *testing.T) {
|
||||
expect := func(lis net.Listener, data string) func() error {
|
||||
return func() error {
|
||||
conn, err := lis.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, len(data))
|
||||
_, err = io.ReadFull(conn, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
require.Equal(t, data, string(buf))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
lis := newFakeListener(
|
||||
newPrefixConn([]byte("prefix1data1"), nil),
|
||||
newPrefixConn([]byte("prefix2data2"), nil),
|
||||
newPrefixConn([]byte("prefix3data3"), nil),
|
||||
)
|
||||
|
||||
mux := New(lis, len("prefixN"))
|
||||
|
||||
var muxGroup errgroup.Group
|
||||
muxGroup.Go(func() error { return mux.Run(ctx) })
|
||||
|
||||
var lisGroup errgroup.Group
|
||||
lisGroup.Go(expect(mux.Route("prefix1"), "data1"))
|
||||
lisGroup.Go(expect(mux.Route("prefix2"), "data2"))
|
||||
lisGroup.Go(expect(mux.Default(), "prefix3data3"))
|
||||
require.NoError(t, lisGroup.Wait())
|
||||
|
||||
cancel()
|
||||
require.Equal(t, nil, muxGroup.Wait())
|
||||
}
|
||||
|
||||
func TestMuxAcceptError(t *testing.T) {
|
||||
err := errs.New("problem")
|
||||
mux := New(newErrorListener(err), 0)
|
||||
require.Equal(t, mux.Run(context.Background()), err)
|
||||
}
|
||||
|
||||
//
|
||||
// fake listener
|
||||
//
|
||||
|
||||
type fakeListener struct {
|
||||
done chan struct{}
|
||||
err error
|
||||
conns []net.Conn
|
||||
}
|
||||
|
||||
func (fl *fakeListener) Addr() net.Addr { return nil }
|
||||
|
||||
func (fl *fakeListener) Close() error {
|
||||
close(fl.done)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fl *fakeListener) Accept() (c net.Conn, err error) {
|
||||
if fl.err != nil {
|
||||
return nil, fl.err
|
||||
}
|
||||
if len(fl.conns) == 0 {
|
||||
<-fl.done
|
||||
return nil, Closed
|
||||
}
|
||||
c, fl.conns = fl.conns[0], fl.conns[1:]
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func newFakeListener(conns ...net.Conn) *fakeListener {
|
||||
return &fakeListener{
|
||||
done: make(chan struct{}),
|
||||
conns: conns,
|
||||
}
|
||||
}
|
||||
|
||||
func newErrorListener(err error) *fakeListener {
|
||||
return &fakeListener{err: err}
|
||||
}
|
26
pkg/listenmux/prefixconn.go
Normal file
26
pkg/listenmux/prefixconn.go
Normal file
@ -0,0 +1,26 @@
|
||||
// Copyright (C) 2019 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package listenmux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type prefixConn struct {
|
||||
io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func newPrefixConn(data []byte, conn net.Conn) *prefixConn {
|
||||
return &prefixConn{
|
||||
Reader: io.MultiReader(bytes.NewReader(data), conn),
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (pc *prefixConn) Read(p []byte) (n int, err error) {
|
||||
return pc.Reader.Read(p)
|
||||
}
|
Loading…
Reference in New Issue
Block a user