dial node/address with bad server cert (#1342)

This commit is contained in:
Bryan White 2019-02-26 19:35:16 +01:00 committed by GitHub
parent cefdff535a
commit fde0020c68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 236 additions and 41 deletions

View File

@ -40,7 +40,7 @@ func main() {
dialOption := clientOptions.DialOption(storj.NodeID{})
conn, err := grpc.Dial(*targetAddr, dialOption)
conn, err := grpc.Dial(*targetAddr, dialOption, grpc.WithInsecure())
if err != nil {
panic(err)
}

View File

@ -70,18 +70,6 @@ type Config struct {
Reconfigure Reconfigure
}
// Reconfigure allows to change node configurations
type Reconfigure struct {
NewBootstrapDB func(index int) (bootstrap.DB, error)
Bootstrap func(index int, config *bootstrap.Config)
NewSatelliteDB func(log *zap.Logger, index int) (satellite.DB, error)
Satellite func(log *zap.Logger, index int, config *satellite.Config)
NewStorageNodeDB func(index int) (storagenode.DB, error)
StorageNode func(index int, config *storagenode.Config)
}
// Planet is a full storj system setup.
type Planet struct {
log *zap.Logger

View File

@ -0,0 +1,38 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package testplanet
import (
"go.uber.org/zap"
"storj.io/storj/bootstrap"
"storj.io/storj/satellite"
"storj.io/storj/storagenode"
)
// Reconfigure allows to change node configurations
type Reconfigure struct {
NewBootstrapDB func(index int) (bootstrap.DB, error)
Bootstrap func(index int, config *bootstrap.Config)
NewSatelliteDB func(log *zap.Logger, index int) (satellite.DB, error)
Satellite func(log *zap.Logger, index int, config *satellite.Config)
NewStorageNodeDB func(index int) (storagenode.DB, error)
StorageNode func(index int, config *storagenode.Config)
}
// DisablePeerCAWhitelist returns a `Reconfigure` that sets `UsePeerCAWhitelist` for
// all node types that use kademlia.
var DisablePeerCAWhitelist = Reconfigure{
Bootstrap: func(index int, config *bootstrap.Config) {
config.Server.UsePeerCAWhitelist = false
},
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
config.Server.UsePeerCAWhitelist = false
},
StorageNode: func(index int, config *storagenode.Config) {
config.Server.UsePeerCAWhitelist = false
},
}

View File

@ -43,13 +43,13 @@ func VerifyPeerFunc(next ...PeerCertVerificationFunc) PeerCertVerificationFunc {
return func(chain [][]byte, _ [][]*x509.Certificate) error {
c, err := pkcrypto.CertsFromDER(chain)
if err != nil {
return ErrVerifyPeerCert.Wrap(err)
return NewNonTemporaryError(ErrVerifyPeerCert.Wrap(err))
}
for _, n := range next {
if n != nil {
if err := n(chain, [][]*x509.Certificate{c}); err != nil {
return ErrVerifyPeerCert.Wrap(err)
return NewNonTemporaryError(ErrVerifyPeerCert.Wrap(err))
}
}
}

View File

@ -14,6 +14,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeebo/errs"
"storj.io/storj/internal/testpeertls"
@ -112,8 +113,10 @@ func TestVerifyPeerCertChains(t *testing.T) {
assert.NoError(t, err)
err = peertls.VerifyPeerFunc(peertls.VerifyPeerCertChains)([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.True(t, peertls.ErrVerifyPeerCert.Has(err))
assert.True(t, peertls.ErrVerifyCertificateChain.Has(err))
nonTempErr, ok := err.(peertls.NonTemporaryError)
require.True(t, ok)
assert.True(t, peertls.ErrVerifyPeerCert.Has(nonTempErr.Err()))
assert.True(t, peertls.ErrVerifyCertificateChain.Has(nonTempErr.Err()))
}
func TestVerifyCAWhitelist(t *testing.T) {
@ -141,7 +144,9 @@ func TestVerifyCAWhitelist(t *testing.T) {
t.Run("no valid signed extension, non-empty whitelist", func(t *testing.T) {
err = peertls.VerifyPeerFunc(peertls.VerifyCAWhitelist([]*x509.Certificate{unrelatedCert}))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.True(t, peertls.ErrVerifyCAWhitelist.Has(err))
nonTempErr, ok := err.(peertls.NonTemporaryError)
require.True(t, ok)
assert.True(t, peertls.ErrVerifyCAWhitelist.Has(nonTempErr.Err()))
})
t.Run("last cert in whitelist is signer", func(t *testing.T) {

View File

@ -16,7 +16,7 @@ import (
)
// ServerOption returns a grpc `ServerOption` for incoming connections
// to the node with this full identity
// to the node with this full identity.
func (opts *Options) ServerOption() grpc.ServerOption {
pcvFuncs := append(
[]peertls.PeerCertVerificationFunc{
@ -37,9 +37,20 @@ func (opts *Options) ServerOption() grpc.ServerOption {
}
// DialOption returns a grpc `DialOption` for making outgoing connections
// to the node with this peer identity
// id is an optional id of the node we are dialing
// to the node with this peer identity.
// id is an optional id of the node we are dialing.
func (opts *Options) DialOption(id storj.NodeID) grpc.DialOption {
return grpc.WithTransportCredentials(opts.TransportCredentials(id))
}
// TransportCredentials returns a grpc `credentials.TransportCredentials`
// implementation for use within peertls.
func (opts *Options) TransportCredentials(id storj.NodeID) credentials.TransportCredentials {
return credentials.NewTLS(opts.TLSConfig(id))
}
// TLSConfig returns a TSLConfig for use in handshaking with a peer.
func (opts *Options) TLSConfig(id storj.NodeID) *tls.Config {
pcvFuncs := append(
[]peertls.PeerCertVerificationFunc{
peertls.VerifyPeerCertChains,
@ -47,15 +58,13 @@ func (opts *Options) DialOption(id storj.NodeID) grpc.DialOption {
},
opts.PCVFuncs...,
)
tlsConfig := &tls.Config{
return &tls.Config{
Certificates: []tls.Certificate{*opts.Cert},
InsecureSkipVerify: true,
VerifyPeerCertificate: peertls.VerifyPeerFunc(
pcvFuncs...,
),
}
return grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))
}
func verifyIdentity(id storj.NodeID) peertls.PeerCertVerificationFunc {

View File

@ -22,6 +22,30 @@ import (
"storj.io/storj/pkg/pkcrypto"
)
// NonTemporaryError is an error with a `Temporary` method which always returns false.
// It is intended for use with grpc.
//
// (see https://godoc.org/google.golang.org/grpc#WithDialer
// and https://godoc.org/google.golang.org/grpc#FailOnNonTempDialError).
type NonTemporaryError struct{ error }
// NewNonTemporaryError returns a new temporary error for use with grpc.
func NewNonTemporaryError(err error) NonTemporaryError {
return NonTemporaryError{
error: errs.Wrap(err),
}
}
// Temporary returns false to indicate that is is a non-temporary error
func (nte NonTemporaryError) Temporary() bool {
return false
}
// Err returns the underlying error
func (nte NonTemporaryError) Err() error {
return nte.error
}
func verifyChainSignatures(certs []*x509.Certificate) error {
for i, cert := range certs {
j := len(certs)

View File

@ -64,9 +64,11 @@ func (transport *Transport) DialNode(ctx context.Context, node *pb.Node, opts ..
return nil, Error.New("no address")
}
// add ID of node we are wanting to connect to
dialOpt := transport.tlsOpts.DialOption(node.Id)
options := append([]grpc.DialOption{dialOpt, grpc.WithBlock(), grpc.FailOnNonTempDialError(true)}, opts...)
options := append([]grpc.DialOption{
transport.tlsOpts.DialOption(node.Id),
grpc.WithBlock(),
grpc.FailOnNonTempDialError(true),
}, opts...)
ctx, cf := context.WithTimeout(ctx, timeout)
defer cf()
@ -89,8 +91,11 @@ func (transport *Transport) DialNode(ctx context.Context, node *pb.Node, opts ..
func (transport *Transport) DialAddress(ctx context.Context, address string, opts ...grpc.DialOption) (conn *grpc.ClientConn, err error) {
defer mon.Task()(&ctx)(&err)
dialOpt := transport.tlsOpts.DialOption(storj.NodeID{})
options := append([]grpc.DialOption{dialOpt, grpc.WithBlock(), grpc.FailOnNonTempDialError(true)}, opts...)
options := append([]grpc.DialOption{
transport.tlsOpts.DialOption(storj.NodeID{}),
grpc.WithBlock(),
grpc.FailOnNonTempDialError(true),
}, opts...)
conn, err = grpc.DialContext(ctx, address, options...)
if err == context.Canceled {

View File

@ -10,10 +10,13 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"storj.io/storj/internal/testcontext"
"storj.io/storj/internal/testplanet"
"storj.io/storj/pkg/pb"
"storj.io/storj/pkg/peertls/tlsopts"
"storj.io/storj/pkg/storj"
)
@ -27,11 +30,29 @@ func TestDialNode(t *testing.T) {
}
defer ctx.Check(planet.Shutdown)
whitelistPath, err := planet.WriteWhitelist()
require.NoError(t, err)
planet.Start(ctx)
client := planet.StorageNodes[0].Transport
{ // DialNode with invalid targets
unsignedIdent, err := testplanet.PregeneratedIdentity(0)
require.NoError(t, err)
signedIdent, err := testplanet.PregeneratedSignedIdentity(0)
require.NoError(t, err)
opts, err := tlsopts.NewOptions(signedIdent, tlsopts.Config{
UsePeerCAWhitelist: true,
PeerCAWhitelistPath: whitelistPath,
})
require.NoError(t, err)
unsignedClientOpts, err := tlsopts.NewOptions(unsignedIdent, tlsopts.Config{})
require.NoError(t, err)
t.Run("DialNode with invalid targets", func(t *testing.T) {
targets := []*pb.Node{
{
Id: storj.NodeID{},
@ -64,34 +85,139 @@ func TestDialNode(t *testing.T) {
assert.Error(t, err, tag)
assert.Nil(t, conn, tag)
}
}
})
{ // DialNode with valid target
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := client.DialNode(timedCtx, &pb.Node{
t.Run("DialNode with valid target", func(t *testing.T) {
target := &pb.Node{
Id: planet.StorageNodes[1].ID(),
Address: &pb.NodeAddress{
Transport: pb.NodeTransport_TCP_TLS_GRPC,
Address: planet.StorageNodes[1].Addr(),
},
Type: pb.NodeType_STORAGE,
})
}
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
dialOption := opts.DialOption(storj.NodeID{})
conn, err := client.DialNode(timedCtx, target, dialOption)
cancel()
assert.NoError(t, err)
assert.NotNil(t, conn)
require.NotNil(t, conn)
assert.NoError(t, conn.Close())
}
})
{ // DialAddress with valid address
t.Run("DialNode with bad client certificate", func(t *testing.T) {
target := &pb.Node{
Id: planet.StorageNodes[1].ID(),
Address: &pb.NodeAddress{
Transport: pb.NodeTransport_TCP_TLS_GRPC,
Address: planet.StorageNodes[1].Addr(),
},
Type: pb.NodeType_STORAGE,
}
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
dialOption := unsignedClientOpts.DialOption(storj.NodeID{})
conn, err := client.DialNode(
timedCtx, target, dialOption,
)
cancel()
tag := fmt.Sprintf("%+v", target)
assert.Nil(t, conn, tag)
require.Error(t, err)
assert.Contains(t, err.Error(), "bad certificate")
})
t.Run("DialAddress with bad client certificate", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
dialOption := unsignedClientOpts.DialOption(storj.NodeID{})
conn, err := client.DialAddress(
timedCtx, planet.StorageNodes[1].Addr(), dialOption,
)
cancel()
assert.Nil(t, conn)
require.Error(t, err)
assert.Contains(t, err.Error(), "bad certificate")
})
t.Run("DialAddress with valid address", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := client.DialAddress(timedCtx, planet.StorageNodes[1].Addr())
cancel()
assert.NoError(t, err)
assert.NotNil(t, conn)
require.NotNil(t, conn)
assert.NoError(t, conn.Close())
}
})
}
func TestDialNode_BadServerCertificate(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
planet, err := testplanet.NewCustom(
zap.L(),
testplanet.Config{
SatelliteCount: 0,
StorageNodeCount: 2,
UplinkCount: 0,
Reconfigure: testplanet.DisablePeerCAWhitelist,
Identities: testplanet.NewPregeneratedIdentities(),
},
)
if err != nil {
t.Fatal(err)
}
defer ctx.Check(planet.Shutdown)
whitelistPath, err := planet.WriteWhitelist()
require.NoError(t, err)
planet.Start(ctx)
client := planet.StorageNodes[0].Transport
ident, err := testplanet.PregeneratedSignedIdentity(0)
require.NoError(t, err)
opts, err := tlsopts.NewOptions(ident, tlsopts.Config{
UsePeerCAWhitelist: true,
PeerCAWhitelistPath: whitelistPath,
})
require.NoError(t, err)
t.Run("DialNode with bad server certificate", func(t *testing.T) {
target := &pb.Node{
Id: planet.StorageNodes[1].ID(),
Address: &pb.NodeAddress{
Transport: pb.NodeTransport_TCP_TLS_GRPC,
Address: planet.StorageNodes[1].Addr(),
},
Type: pb.NodeType_STORAGE,
}
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
dialOption := opts.DialOption(storj.NodeID{})
conn, err := client.DialNode(timedCtx, target, dialOption)
cancel()
tag := fmt.Sprintf("%+v", target)
assert.Nil(t, conn, tag)
require.Error(t, err, tag)
assert.Contains(t, err.Error(), "not signed by any CA in the whitelist")
})
t.Run("DialAddress with bad server certificate", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
dialOption := opts.DialOption(storj.NodeID{})
conn, err := client.DialAddress(timedCtx, planet.StorageNodes[1].Addr(), dialOption)
cancel()
assert.Nil(t, conn)
require.Error(t, err)
assert.Contains(t, err.Error(), "not signed by any CA in the whitelist")
})
}