pkg/server,private/testplanet: start to listen on quic

This PR introduces a new listener that can listen for quic traffic on
both storagenodes and satellites.

Change-Id: I5eb5bc82c37dde20d3be2ec8fa5f69c18fae0af0
This commit is contained in:
Yingrong Zhao 2021-01-19 11:33:50 -05:00
parent f18cb24522
commit 02845e7b8f
6 changed files with 188 additions and 104 deletions

1
go.sum
View File

@ -147,6 +147,7 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk=
github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=

View File

@ -174,9 +174,9 @@ type closeTrackingConn struct {
rpc.ConnectorConn
}
// trackClose wraps the conn and sets a finalizer on the returned value to
// TrackClose wraps the conn and sets a finalizer on the returned value to
// close the conn and monitor that it was leaked.
func trackClose(conn rpc.ConnectorConn) rpc.ConnectorConn {
func TrackClose(conn rpc.ConnectorConn) rpc.ConnectorConn {
tracked := &closeTrackingConn{ConnectorConn: conn}
runtime.SetFinalizer(tracked, (*closeTrackingConn).finalize)
return tracked

View File

@ -61,7 +61,7 @@ func (c Connector) DialContext(ctx context.Context, tlsConfig *tls.Config, addre
}
return &timedConn{
ConnectorConn: trackClose(conn),
ConnectorConn: TrackClose(conn),
rate: c.transferRate,
}, nil
}

View File

@ -10,6 +10,8 @@ import (
"github.com/zeebo/errs"
"storj.io/common/netutil"
"storj.io/common/rpc"
"storj.io/storj/pkg/quic"
)
// defaultUserTimeout is the value we use for the TCP_USER_TIMEOUT setting.
@ -19,24 +21,27 @@ const defaultUserTimeout = 60 * time.Second
// and monitors if the returned connections are closed or leaked.
func wrapListener(lis net.Listener) net.Listener {
if lis, ok := lis.(*net.TCPListener); ok {
return newUserTimeoutListener(lis)
return newTCPUserTimeoutListener(lis)
}
if lis, ok := lis.(*quic.Listener); ok {
return newQUICTrackedListener(lis)
}
return lis
}
// userTimeoutListener wraps a tcp listener so that it sets the TCP_USER_TIMEOUT
// tcpUserTimeoutListener wraps a tcp listener so that it sets the TCP_USER_TIMEOUT
// value for each socket it returns.
type userTimeoutListener struct {
type tcpUserTimeoutListener struct {
lis *net.TCPListener
}
// newUserTimeoutListener wraps the tcp listener in a userTimeoutListener.
func newUserTimeoutListener(lis *net.TCPListener) *userTimeoutListener {
return &userTimeoutListener{lis: lis}
// newTCPUserTimeoutListener wraps the tcp listener in a userTimeoutListener.
func newTCPUserTimeoutListener(lis *net.TCPListener) *tcpUserTimeoutListener {
return &tcpUserTimeoutListener{lis: lis}
}
// Accept waits for and returns the next connection to the listener.
func (lis *userTimeoutListener) Accept() (net.Conn, error) {
func (lis *tcpUserTimeoutListener) Accept() (net.Conn, error) {
conn, err := lis.lis.AcceptTCP()
if err != nil {
return nil, err
@ -50,11 +55,44 @@ func (lis *userTimeoutListener) Accept() (net.Conn, error) {
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
func (lis *userTimeoutListener) Close() error {
func (lis *tcpUserTimeoutListener) Close() error {
return lis.lis.Close()
}
// Addr returns the listener's network address.
func (lis *userTimeoutListener) Addr() net.Addr {
func (lis *tcpUserTimeoutListener) Addr() net.Addr {
return lis.lis.Addr()
}
type quicTrackedListener struct {
lis *quic.Listener
}
func newQUICTrackedListener(lis *quic.Listener) *quicTrackedListener {
return &quicTrackedListener{lis: lis}
}
func (lis *quicTrackedListener) Accept() (net.Conn, error) {
conn, err := lis.lis.Accept()
if err != nil {
return nil, err
}
connectorConn, ok := conn.(rpc.ConnectorConn)
if !ok {
return nil, Error.New("quic connection doesn't implement required methods")
}
return quic.TrackClose(connectorConn), nil
}
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
func (lis *quicTrackedListener) Close() error {
return lis.lis.Close()
}
// Addr returns the listener's network address.
func (lis *quicTrackedListener) Addr() net.Addr {
return lis.lis.Addr()
}

View File

@ -9,6 +9,7 @@ import (
"net"
"sync"
quicgo "github.com/lucas-clemente/quic-go"
"github.com/zeebo/errs"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
@ -21,6 +22,7 @@ import (
"storj.io/drpc/drpcserver"
jaeger "storj.io/monkit-jaeger"
"storj.io/storj/pkg/listenmux"
"storj.io/storj/pkg/quic"
)
// Config holds server specific configuration parameters.
@ -33,9 +35,10 @@ type Config struct {
}
type public struct {
listener net.Listener
drpc *drpcserver.Server
mux *drpcmux.Mux
tcpListener net.Listener
quicListener net.Listener
drpc *drpcserver.Server
mux *drpcmux.Mux
}
type private struct {
@ -71,22 +74,28 @@ func New(log *zap.Logger, tlsOptions *tlsopts.Options, publicAddr, privateAddr s
Manager: rpc.NewDefaultManagerOptions(),
}
publicListener, err := net.Listen("tcp", publicAddr)
publicTCPListener, err := net.Listen("tcp", publicAddr)
if err != nil {
return nil, err
}
publicQUICListener, err := quic.NewListener(tlsOptions.ServerTLSConfig(), publicTCPListener.Addr().String(), &quicgo.Config{MaxIdleTimeout: defaultUserTimeout})
if err != nil {
return nil, errs.Combine(err, publicTCPListener.Close())
}
publicMux := drpcmux.New()
publicTracingHandler := rpctracing.NewHandler(publicMux, jaeger.RemoteTraceHandler)
server.public = public{
listener: wrapListener(publicListener),
drpc: drpcserver.NewWithOptions(publicTracingHandler, serverOptions),
mux: publicMux,
tcpListener: wrapListener(publicTCPListener),
quicListener: wrapListener(publicQUICListener),
drpc: drpcserver.NewWithOptions(publicTracingHandler, serverOptions),
mux: publicMux,
}
privateListener, err := net.Listen("tcp", privateAddr)
if err != nil {
return nil, errs.Combine(err, publicListener.Close())
return nil, errs.Combine(err, publicTCPListener.Close(), publicQUICListener.Close())
}
privateMux := drpcmux.New()
privateTracingHandler := rpctracing.NewHandler(privateMux, jaeger.RemoteTraceHandler)
@ -103,7 +112,7 @@ func New(log *zap.Logger, tlsOptions *tlsopts.Options, publicAddr, privateAddr s
func (p *Server) Identity() *identity.FullIdentity { return p.tlsOptions.Ident }
// Addr returns the server's public listener address.
func (p *Server) Addr() net.Addr { return p.public.listener.Addr() }
func (p *Server) Addr() net.Addr { return p.public.tcpListener.Addr() }
// PrivateAddr returns the server's private listener address.
func (p *Server) PrivateAddr() net.Addr { return p.private.listener.Addr() }
@ -127,7 +136,8 @@ func (p *Server) Close() error {
// We ignore these errors because there's not really anything to do
// even if they happen, and they'll just be errors due to duplicate
// closes anyway.
_ = p.public.listener.Close()
_ = p.public.quicListener.Close()
_ = p.public.tcpListener.Close()
_ = p.private.listener.Close()
return nil
}
@ -156,7 +166,7 @@ func (p *Server) Run(ctx context.Context) (err error) {
// a chance to be notified that they're done running.
const drpcHeader = "DRPC!!!1"
publicMux := listenmux.New(p.public.listener, len(drpcHeader))
publicMux := listenmux.New(p.public.tcpListener, len(drpcHeader))
publicDRPCListener := tls.NewListener(publicMux.Route(drpcHeader), p.tlsOptions.ServerTLSConfig())
privateMux := listenmux.New(p.private.listener, len(drpcHeader))
@ -197,6 +207,10 @@ func (p *Server) Run(ctx context.Context) (err error) {
defer cancel()
return p.public.drpc.Serve(ctx, publicDRPCListener)
})
group.Go(func() error {
defer cancel()
return p.public.drpc.Serve(ctx, p.public.quicListener)
})
group.Go(func() error {
defer cancel()
return p.private.drpc.Serve(ctx, privateDRPCListener)

View File

@ -18,6 +18,7 @@ import (
"storj.io/common/rpc"
"storj.io/common/storj"
"storj.io/common/testcontext"
"storj.io/storj/pkg/quic"
"storj.io/storj/private/testplanet"
"storj.io/storj/satellite"
"storj.io/storj/storagenode"
@ -43,82 +44,99 @@ func TestDialNodeURL(t *testing.T) {
}, nil)
require.NoError(t, err)
dialer := rpc.NewDefaultDialer(tlsOptions)
tcpDialer := rpc.NewDefaultDialer(tlsOptions)
quicDialer := rpc.NewDefaultDialer(tlsOptions)
quicDialer.Connector = quic.NewDefaultConnector(nil)
unsignedClientOpts, err := tlsopts.NewOptions(unsignedIdent, tlsopts.Config{
PeerIDVersions: "*",
}, nil)
require.NoError(t, err)
unsignedDialer := rpc.NewDefaultDialer(unsignedClientOpts)
unsignedTCPDialer := rpc.NewDefaultDialer(unsignedClientOpts)
unsignedQUICDialer := rpc.NewDefaultDialer(unsignedClientOpts)
unsignedQUICDialer.Connector = quic.NewDefaultConnector(nil)
t.Run("DialNodeURL with invalid targets", func(t *testing.T) {
targets := []storj.NodeURL{
{
ID: storj.NodeID{},
Address: "",
},
{
ID: storj.NodeID{123},
Address: "127.0.0.1:100",
},
{
ID: storj.NodeID{},
Address: planet.StorageNodes[1].Addr(),
},
}
test := func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet, dialer rpc.Dialer, unsignedDialer rpc.Dialer) {
t.Run("DialNodeURL with invalid targets", func(t *testing.T) {
targets := []storj.NodeURL{
{
ID: storj.NodeID{},
Address: "",
},
{
ID: storj.NodeID{123},
Address: "127.0.0.1:100",
},
{
ID: storj.NodeID{},
Address: planet.StorageNodes[1].Addr(),
},
}
for _, target := range targets {
tag := fmt.Sprintf("%+v", target)
for _, target := range targets {
tag := fmt.Sprintf("%+v", target)
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, target)
cancel()
assert.Error(t, err, tag)
assert.Nil(t, conn, tag)
}
})
t.Run("DialNode with valid signed target", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, target)
conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.Error(t, err, tag)
assert.Nil(t, conn, tag)
}
assert.NoError(t, err)
require.NotNil(t, conn)
assert.NoError(t, conn.Close())
})
t.Run("DialNode with unsigned identity", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := unsignedDialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.NotNil(t, conn)
require.NoError(t, err)
assert.NoError(t, conn.Close())
})
t.Run("DialAddress with unsigned identity", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := unsignedDialer.DialAddressInsecure(timedCtx, planet.StorageNodes[1].Addr())
cancel()
assert.NotNil(t, conn)
require.NoError(t, err)
assert.NoError(t, conn.Close())
})
t.Run("DialAddress with valid address", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialAddressInsecure(timedCtx, planet.StorageNodes[1].Addr())
cancel()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.NoError(t, conn.Close())
})
}
// test with tcp
t.Run("TCP", func(t *testing.T) {
test(t, ctx, planet, tcpDialer, unsignedTCPDialer)
})
// test with quic
t.Run("QUIC", func(t *testing.T) {
test(t, ctx, planet, quicDialer, unsignedQUICDialer)
})
t.Run("DialNode with valid signed target", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.NoError(t, conn.Close())
})
t.Run("DialNode with unsigned identity", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := unsignedDialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.NotNil(t, conn)
require.NoError(t, err)
assert.NoError(t, conn.Close())
})
t.Run("DialAddress with unsigned identity", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := unsignedDialer.DialAddressInsecure(timedCtx, planet.StorageNodes[1].Addr())
cancel()
assert.NotNil(t, conn)
require.NoError(t, err)
assert.NoError(t, conn.Close())
})
t.Run("DialAddress with valid address", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialAddressInsecure(timedCtx, planet.StorageNodes[1].Addr())
cancel()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.NoError(t, conn.Close())
})
})
}
@ -150,27 +168,40 @@ func TestDialNode_BadServerCertificate(t *testing.T) {
}, nil)
require.NoError(t, err)
dialer := rpc.NewDefaultDialer(tlsOptions)
tcpDialer := rpc.NewDefaultDialer(tlsOptions)
quicDialer := rpc.NewDefaultDialer(tlsOptions)
quicDialer.Connector = quic.NewDefaultConnector(nil)
t.Run("DialNodeURL with bad server certificate", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
test := func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet, dialer rpc.Dialer) {
t.Run("DialNodeURL with bad server certificate", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
tag := fmt.Sprintf("%+v", planet.StorageNodes[1].NodeURL())
assert.Nil(t, conn, tag)
require.Error(t, err, tag)
assert.Contains(t, err.Error(), "not signed by any CA in the whitelist")
tag := fmt.Sprintf("%+v", planet.StorageNodes[1].NodeURL())
assert.Nil(t, conn, tag)
require.Error(t, err, tag)
assert.Contains(t, err.Error(), "not signed by any CA in the whitelist")
})
t.Run("DialAddress with bad server certificate", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.Nil(t, conn)
require.Error(t, err)
assert.Contains(t, err.Error(), "not signed by any CA in the whitelist")
})
}
// test with tcp
t.Run("TCP", func(t *testing.T) {
test(t, ctx, planet, tcpDialer)
})
t.Run("DialAddress with bad server certificate", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.Nil(t, conn)
require.Error(t, err)
assert.Contains(t, err.Error(), "not signed by any CA in the whitelist")
// test with quic
t.Run("QUIC", func(t *testing.T) {
test(t, ctx, planet, quicDialer)
})
})
}