pkg/server,private/testplanet: start to listen on quic

This PR introduces a new listener that can listen for quic traffic on
both storagenodes and satellites.

Change-Id: I5eb5bc82c37dde20d3be2ec8fa5f69c18fae0af0
This commit is contained in:
Yingrong Zhao 2021-01-19 11:33:50 -05:00
parent f18cb24522
commit 02845e7b8f
6 changed files with 188 additions and 104 deletions

1
go.sum
View File

@ -147,6 +147,7 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk=
github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=

View File

@ -174,9 +174,9 @@ type closeTrackingConn struct {
rpc.ConnectorConn rpc.ConnectorConn
} }
// trackClose wraps the conn and sets a finalizer on the returned value to // TrackClose wraps the conn and sets a finalizer on the returned value to
// close the conn and monitor that it was leaked. // close the conn and monitor that it was leaked.
func trackClose(conn rpc.ConnectorConn) rpc.ConnectorConn { func TrackClose(conn rpc.ConnectorConn) rpc.ConnectorConn {
tracked := &closeTrackingConn{ConnectorConn: conn} tracked := &closeTrackingConn{ConnectorConn: conn}
runtime.SetFinalizer(tracked, (*closeTrackingConn).finalize) runtime.SetFinalizer(tracked, (*closeTrackingConn).finalize)
return tracked return tracked

View File

@ -61,7 +61,7 @@ func (c Connector) DialContext(ctx context.Context, tlsConfig *tls.Config, addre
} }
return &timedConn{ return &timedConn{
ConnectorConn: trackClose(conn), ConnectorConn: TrackClose(conn),
rate: c.transferRate, rate: c.transferRate,
}, nil }, nil
} }

View File

@ -10,6 +10,8 @@ import (
"github.com/zeebo/errs" "github.com/zeebo/errs"
"storj.io/common/netutil" "storj.io/common/netutil"
"storj.io/common/rpc"
"storj.io/storj/pkg/quic"
) )
// defaultUserTimeout is the value we use for the TCP_USER_TIMEOUT setting. // defaultUserTimeout is the value we use for the TCP_USER_TIMEOUT setting.
@ -19,24 +21,27 @@ const defaultUserTimeout = 60 * time.Second
// and monitors if the returned connections are closed or leaked. // and monitors if the returned connections are closed or leaked.
func wrapListener(lis net.Listener) net.Listener { func wrapListener(lis net.Listener) net.Listener {
if lis, ok := lis.(*net.TCPListener); ok { if lis, ok := lis.(*net.TCPListener); ok {
return newUserTimeoutListener(lis) return newTCPUserTimeoutListener(lis)
}
if lis, ok := lis.(*quic.Listener); ok {
return newQUICTrackedListener(lis)
} }
return lis return lis
} }
// userTimeoutListener wraps a tcp listener so that it sets the TCP_USER_TIMEOUT // tcpUserTimeoutListener wraps a tcp listener so that it sets the TCP_USER_TIMEOUT
// value for each socket it returns. // value for each socket it returns.
type userTimeoutListener struct { type tcpUserTimeoutListener struct {
lis *net.TCPListener lis *net.TCPListener
} }
// newUserTimeoutListener wraps the tcp listener in a userTimeoutListener. // newTCPUserTimeoutListener wraps the tcp listener in a userTimeoutListener.
func newUserTimeoutListener(lis *net.TCPListener) *userTimeoutListener { func newTCPUserTimeoutListener(lis *net.TCPListener) *tcpUserTimeoutListener {
return &userTimeoutListener{lis: lis} return &tcpUserTimeoutListener{lis: lis}
} }
// Accept waits for and returns the next connection to the listener. // Accept waits for and returns the next connection to the listener.
func (lis *userTimeoutListener) Accept() (net.Conn, error) { func (lis *tcpUserTimeoutListener) Accept() (net.Conn, error) {
conn, err := lis.lis.AcceptTCP() conn, err := lis.lis.AcceptTCP()
if err != nil { if err != nil {
return nil, err return nil, err
@ -50,11 +55,44 @@ func (lis *userTimeoutListener) Accept() (net.Conn, error) {
// Close closes the listener. // Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors. // Any blocked Accept operations will be unblocked and return errors.
func (lis *userTimeoutListener) Close() error { func (lis *tcpUserTimeoutListener) Close() error {
return lis.lis.Close() return lis.lis.Close()
} }
// Addr returns the listener's network address. // Addr returns the listener's network address.
func (lis *userTimeoutListener) Addr() net.Addr { func (lis *tcpUserTimeoutListener) Addr() net.Addr {
return lis.lis.Addr()
}
type quicTrackedListener struct {
lis *quic.Listener
}
func newQUICTrackedListener(lis *quic.Listener) *quicTrackedListener {
return &quicTrackedListener{lis: lis}
}
func (lis *quicTrackedListener) Accept() (net.Conn, error) {
conn, err := lis.lis.Accept()
if err != nil {
return nil, err
}
connectorConn, ok := conn.(rpc.ConnectorConn)
if !ok {
return nil, Error.New("quic connection doesn't implement required methods")
}
return quic.TrackClose(connectorConn), nil
}
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
func (lis *quicTrackedListener) Close() error {
return lis.lis.Close()
}
// Addr returns the listener's network address.
func (lis *quicTrackedListener) Addr() net.Addr {
return lis.lis.Addr() return lis.lis.Addr()
} }

View File

@ -9,6 +9,7 @@ import (
"net" "net"
"sync" "sync"
quicgo "github.com/lucas-clemente/quic-go"
"github.com/zeebo/errs" "github.com/zeebo/errs"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@ -21,6 +22,7 @@ import (
"storj.io/drpc/drpcserver" "storj.io/drpc/drpcserver"
jaeger "storj.io/monkit-jaeger" jaeger "storj.io/monkit-jaeger"
"storj.io/storj/pkg/listenmux" "storj.io/storj/pkg/listenmux"
"storj.io/storj/pkg/quic"
) )
// Config holds server specific configuration parameters. // Config holds server specific configuration parameters.
@ -33,9 +35,10 @@ type Config struct {
} }
type public struct { type public struct {
listener net.Listener tcpListener net.Listener
drpc *drpcserver.Server quicListener net.Listener
mux *drpcmux.Mux drpc *drpcserver.Server
mux *drpcmux.Mux
} }
type private struct { type private struct {
@ -71,22 +74,28 @@ func New(log *zap.Logger, tlsOptions *tlsopts.Options, publicAddr, privateAddr s
Manager: rpc.NewDefaultManagerOptions(), Manager: rpc.NewDefaultManagerOptions(),
} }
publicListener, err := net.Listen("tcp", publicAddr) publicTCPListener, err := net.Listen("tcp", publicAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
publicQUICListener, err := quic.NewListener(tlsOptions.ServerTLSConfig(), publicTCPListener.Addr().String(), &quicgo.Config{MaxIdleTimeout: defaultUserTimeout})
if err != nil {
return nil, errs.Combine(err, publicTCPListener.Close())
}
publicMux := drpcmux.New() publicMux := drpcmux.New()
publicTracingHandler := rpctracing.NewHandler(publicMux, jaeger.RemoteTraceHandler) publicTracingHandler := rpctracing.NewHandler(publicMux, jaeger.RemoteTraceHandler)
server.public = public{ server.public = public{
listener: wrapListener(publicListener), tcpListener: wrapListener(publicTCPListener),
drpc: drpcserver.NewWithOptions(publicTracingHandler, serverOptions), quicListener: wrapListener(publicQUICListener),
mux: publicMux, drpc: drpcserver.NewWithOptions(publicTracingHandler, serverOptions),
mux: publicMux,
} }
privateListener, err := net.Listen("tcp", privateAddr) privateListener, err := net.Listen("tcp", privateAddr)
if err != nil { if err != nil {
return nil, errs.Combine(err, publicListener.Close()) return nil, errs.Combine(err, publicTCPListener.Close(), publicQUICListener.Close())
} }
privateMux := drpcmux.New() privateMux := drpcmux.New()
privateTracingHandler := rpctracing.NewHandler(privateMux, jaeger.RemoteTraceHandler) privateTracingHandler := rpctracing.NewHandler(privateMux, jaeger.RemoteTraceHandler)
@ -103,7 +112,7 @@ func New(log *zap.Logger, tlsOptions *tlsopts.Options, publicAddr, privateAddr s
func (p *Server) Identity() *identity.FullIdentity { return p.tlsOptions.Ident } func (p *Server) Identity() *identity.FullIdentity { return p.tlsOptions.Ident }
// Addr returns the server's public listener address. // Addr returns the server's public listener address.
func (p *Server) Addr() net.Addr { return p.public.listener.Addr() } func (p *Server) Addr() net.Addr { return p.public.tcpListener.Addr() }
// PrivateAddr returns the server's private listener address. // PrivateAddr returns the server's private listener address.
func (p *Server) PrivateAddr() net.Addr { return p.private.listener.Addr() } func (p *Server) PrivateAddr() net.Addr { return p.private.listener.Addr() }
@ -127,7 +136,8 @@ func (p *Server) Close() error {
// We ignore these errors because there's not really anything to do // We ignore these errors because there's not really anything to do
// even if they happen, and they'll just be errors due to duplicate // even if they happen, and they'll just be errors due to duplicate
// closes anyway. // closes anyway.
_ = p.public.listener.Close() _ = p.public.quicListener.Close()
_ = p.public.tcpListener.Close()
_ = p.private.listener.Close() _ = p.private.listener.Close()
return nil return nil
} }
@ -156,7 +166,7 @@ func (p *Server) Run(ctx context.Context) (err error) {
// a chance to be notified that they're done running. // a chance to be notified that they're done running.
const drpcHeader = "DRPC!!!1" const drpcHeader = "DRPC!!!1"
publicMux := listenmux.New(p.public.listener, len(drpcHeader)) publicMux := listenmux.New(p.public.tcpListener, len(drpcHeader))
publicDRPCListener := tls.NewListener(publicMux.Route(drpcHeader), p.tlsOptions.ServerTLSConfig()) publicDRPCListener := tls.NewListener(publicMux.Route(drpcHeader), p.tlsOptions.ServerTLSConfig())
privateMux := listenmux.New(p.private.listener, len(drpcHeader)) privateMux := listenmux.New(p.private.listener, len(drpcHeader))
@ -197,6 +207,10 @@ func (p *Server) Run(ctx context.Context) (err error) {
defer cancel() defer cancel()
return p.public.drpc.Serve(ctx, publicDRPCListener) return p.public.drpc.Serve(ctx, publicDRPCListener)
}) })
group.Go(func() error {
defer cancel()
return p.public.drpc.Serve(ctx, p.public.quicListener)
})
group.Go(func() error { group.Go(func() error {
defer cancel() defer cancel()
return p.private.drpc.Serve(ctx, privateDRPCListener) return p.private.drpc.Serve(ctx, privateDRPCListener)

View File

@ -18,6 +18,7 @@ import (
"storj.io/common/rpc" "storj.io/common/rpc"
"storj.io/common/storj" "storj.io/common/storj"
"storj.io/common/testcontext" "storj.io/common/testcontext"
"storj.io/storj/pkg/quic"
"storj.io/storj/private/testplanet" "storj.io/storj/private/testplanet"
"storj.io/storj/satellite" "storj.io/storj/satellite"
"storj.io/storj/storagenode" "storj.io/storj/storagenode"
@ -43,82 +44,99 @@ func TestDialNodeURL(t *testing.T) {
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
dialer := rpc.NewDefaultDialer(tlsOptions) tcpDialer := rpc.NewDefaultDialer(tlsOptions)
quicDialer := rpc.NewDefaultDialer(tlsOptions)
quicDialer.Connector = quic.NewDefaultConnector(nil)
unsignedClientOpts, err := tlsopts.NewOptions(unsignedIdent, tlsopts.Config{ unsignedClientOpts, err := tlsopts.NewOptions(unsignedIdent, tlsopts.Config{
PeerIDVersions: "*", PeerIDVersions: "*",
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
unsignedDialer := rpc.NewDefaultDialer(unsignedClientOpts) unsignedTCPDialer := rpc.NewDefaultDialer(unsignedClientOpts)
unsignedQUICDialer := rpc.NewDefaultDialer(unsignedClientOpts)
unsignedQUICDialer.Connector = quic.NewDefaultConnector(nil)
t.Run("DialNodeURL with invalid targets", func(t *testing.T) { test := func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet, dialer rpc.Dialer, unsignedDialer rpc.Dialer) {
targets := []storj.NodeURL{ t.Run("DialNodeURL with invalid targets", func(t *testing.T) {
{ targets := []storj.NodeURL{
ID: storj.NodeID{}, {
Address: "", ID: storj.NodeID{},
}, Address: "",
{ },
ID: storj.NodeID{123}, {
Address: "127.0.0.1:100", ID: storj.NodeID{123},
}, Address: "127.0.0.1:100",
{ },
ID: storj.NodeID{}, {
Address: planet.StorageNodes[1].Addr(), ID: storj.NodeID{},
}, Address: planet.StorageNodes[1].Addr(),
} },
}
for _, target := range targets { for _, target := range targets {
tag := fmt.Sprintf("%+v", target) tag := fmt.Sprintf("%+v", target)
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, target)
cancel()
assert.Error(t, err, tag)
assert.Nil(t, conn, tag)
}
})
t.Run("DialNode with valid signed target", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second) timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, target) conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel() cancel()
assert.Error(t, err, tag)
assert.Nil(t, conn, tag) assert.NoError(t, err)
} require.NotNil(t, conn)
assert.NoError(t, conn.Close())
})
t.Run("DialNode with unsigned identity", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := unsignedDialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.NotNil(t, conn)
require.NoError(t, err)
assert.NoError(t, conn.Close())
})
t.Run("DialAddress with unsigned identity", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := unsignedDialer.DialAddressInsecure(timedCtx, planet.StorageNodes[1].Addr())
cancel()
assert.NotNil(t, conn)
require.NoError(t, err)
assert.NoError(t, conn.Close())
})
t.Run("DialAddress with valid address", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialAddressInsecure(timedCtx, planet.StorageNodes[1].Addr())
cancel()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.NoError(t, conn.Close())
})
}
// test with tcp
t.Run("TCP", func(t *testing.T) {
test(t, ctx, planet, tcpDialer, unsignedTCPDialer)
})
// test with quic
t.Run("QUIC", func(t *testing.T) {
test(t, ctx, planet, quicDialer, unsignedQUICDialer)
}) })
t.Run("DialNode with valid signed target", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.NoError(t, conn.Close())
})
t.Run("DialNode with unsigned identity", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := unsignedDialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.NotNil(t, conn)
require.NoError(t, err)
assert.NoError(t, conn.Close())
})
t.Run("DialAddress with unsigned identity", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := unsignedDialer.DialAddressInsecure(timedCtx, planet.StorageNodes[1].Addr())
cancel()
assert.NotNil(t, conn)
require.NoError(t, err)
assert.NoError(t, conn.Close())
})
t.Run("DialAddress with valid address", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialAddressInsecure(timedCtx, planet.StorageNodes[1].Addr())
cancel()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.NoError(t, conn.Close())
})
}) })
} }
@ -150,27 +168,40 @@ func TestDialNode_BadServerCertificate(t *testing.T) {
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
dialer := rpc.NewDefaultDialer(tlsOptions) tcpDialer := rpc.NewDefaultDialer(tlsOptions)
quicDialer := rpc.NewDefaultDialer(tlsOptions)
quicDialer.Connector = quic.NewDefaultConnector(nil)
t.Run("DialNodeURL with bad server certificate", func(t *testing.T) { test := func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet, dialer rpc.Dialer) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second) t.Run("DialNodeURL with bad server certificate", func(t *testing.T) {
conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL()) timedCtx, cancel := context.WithTimeout(ctx, time.Second)
cancel() conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
tag := fmt.Sprintf("%+v", planet.StorageNodes[1].NodeURL()) tag := fmt.Sprintf("%+v", planet.StorageNodes[1].NodeURL())
assert.Nil(t, conn, tag) assert.Nil(t, conn, tag)
require.Error(t, err, tag) require.Error(t, err, tag)
assert.Contains(t, err.Error(), "not signed by any CA in the whitelist") assert.Contains(t, err.Error(), "not signed by any CA in the whitelist")
})
t.Run("DialAddress with bad server certificate", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.Nil(t, conn)
require.Error(t, err)
assert.Contains(t, err.Error(), "not signed by any CA in the whitelist")
})
}
// test with tcp
t.Run("TCP", func(t *testing.T) {
test(t, ctx, planet, tcpDialer)
}) })
// test with quic
t.Run("DialAddress with bad server certificate", func(t *testing.T) { t.Run("QUIC", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second) test(t, ctx, planet, quicDialer)
conn, err := dialer.DialNodeURL(timedCtx, planet.StorageNodes[1].NodeURL())
cancel()
assert.Nil(t, conn)
require.Error(t, err)
assert.Contains(t, err.Error(), "not signed by any CA in the whitelist")
}) })
}) })
} }