[V3-1320] fix empty node ID verification non-error (#1395)

* small identity refactor:

+ Optimize? iterative cert chain methods to use array instead of slice
+ Add `ToChain` helper for converting 1d to 2d cert chain
  TODO: replace literal declarations with this
+ rename `ChainRaw/RestChainRaw` to `RawChain/RawRestChain`
  (adjective noun, instead of nound adjective)

* add regression tests for V3-1320

* fix V3-1320

* separate `DialUnverifiedIDOption` from `DialOption`

* separate `PingNode` and `DialNode` from `PingAddress` and `DialAddress`

* update node ID while bootstrapping

* goimports & fix comment

* add test case
This commit is contained in:
Bryan White 2019-03-04 21:03:33 +01:00 committed by Dennis Coyle
parent 588e2a51d2
commit 675e0ef683
22 changed files with 265 additions and 61 deletions

View File

@ -99,13 +99,8 @@ func checkCAChain(opts checkOpts, errFmt string) {
} }
func checkIdentContainsCA(opts checkOpts, errFmt string) { func checkIdentContainsCA(opts checkOpts, errFmt string) {
identChainBytes := append([][]byte{ identChainBytes := opts.identity.RawChain()
opts.identity.Leaf.Raw, caChainBytes := opts.ca.RawChain()
opts.identity.CA.Raw,
}, opts.ca.RestChainRaw()...)
caChainBytes := append([][]byte{
opts.ca.Cert.Raw,
}, opts.ca.RestChainRaw()...)
for i, caCert := range caChainBytes { for i, caCert := range caChainBytes {
j := i + 1 j := i + 1

View File

@ -13,7 +13,6 @@ import (
"storj.io/storj/pkg/cfgstruct" "storj.io/storj/pkg/cfgstruct"
"storj.io/storj/pkg/identity" "storj.io/storj/pkg/identity"
"storj.io/storj/pkg/peertls/tlsopts" "storj.io/storj/pkg/peertls/tlsopts"
"storj.io/storj/pkg/storj"
) )
var ( var (
@ -38,7 +37,7 @@ func main() {
panic(err) panic(err)
} }
dialOption := clientOptions.DialOption(storj.NodeID{}) dialOption := clientOptions.DialUnverifiedIDOption()
conn, err := grpc.Dial(*targetAddr, dialOption, grpc.WithInsecure()) conn, err := grpc.Dial(*targetAddr, dialOption, grpc.WithInsecure())
if err != nil { if err != nil {

View File

@ -63,7 +63,7 @@ func SignMessage(msg SignableMessage, ID identity.FullIdentity) error {
return ErrSign.Wrap(err) return ErrSign.Wrap(err)
} }
msg.SetSignature(signature) msg.SetSignature(signature)
msg.SetCerts(ID.ChainRaw()) msg.SetCerts(ID.RawChain())
return nil return nil
} }

View File

@ -195,7 +195,7 @@ func testDatabase(ctx context.Context, t *testing.T, db satellite.DB) {
// Generate a new keypair for self signing bwagreements // Generate a new keypair for self signing bwagreements
manipID, err := testidentity.NewTestIdentity(ctx) manipID, err := testidentity.NewTestIdentity(ctx)
assert.NoError(t, err) assert.NoError(t, err)
manipCerts := manipID.ChainRaw() manipCerts := manipID.RawChain()
manipPrivKey := manipID.Key manipPrivKey := manipID.Key
/* Storage node can't manipulate the bwagreement size (or any other field) /* Storage node can't manipulate the bwagreement size (or any other field)

View File

@ -227,7 +227,7 @@ func (c CertificateSigner) Sign(ctx context.Context, req *pb.SigningRequest) (*p
} }
signedChainBytes := [][]byte{signedPeerCA.Raw, c.signer.Cert.Raw} signedChainBytes := [][]byte{signedPeerCA.Raw, c.signer.Cert.Raw}
signedChainBytes = append(signedChainBytes, c.signer.RestChainRaw()...) signedChainBytes = append(signedChainBytes, c.signer.RawRestChain()...)
err = c.authDB.Claim(&ClaimOpts{ err = c.authDB.Claim(&ClaimOpts{
Req: req, Req: req,
Peer: grpcPeer, Peer: grpcPeer,

View File

@ -675,7 +675,7 @@ func TestCertificateSigner_Sign_E2E(t *testing.T) {
assert.Equal(t, clientIdent.CA.RawTBSCertificate, signedChain[0].RawTBSCertificate) assert.Equal(t, clientIdent.CA.RawTBSCertificate, signedChain[0].RawTBSCertificate)
assert.Equal(t, signingCA.Cert.Raw, signedChainBytes[1]) assert.Equal(t, signingCA.Cert.Raw, signedChainBytes[1])
// TODO: test scenario with rest chain // TODO: test scenario with rest chain
//assert.Equal(t, signingCA.RestChainRaw(), signedChainBytes[1:]) //assert.Equal(t, signingCA.RawRestChain(), signedChainBytes[1:])
err = signedChain[0].CheckSignatureFrom(signingCA.Cert) err = signedChain[0].CheckSignatureFrom(signingCA.Cert)
assert.NoError(t, err) assert.NoError(t, err)
@ -844,7 +844,7 @@ func TestCertificateSigner_Sign(t *testing.T) {
assert.Equal(t, clientIdent.CA.RawTBSCertificate, signedChain[0].RawTBSCertificate) assert.Equal(t, clientIdent.CA.RawTBSCertificate, signedChain[0].RawTBSCertificate)
assert.Equal(t, signingCA.Cert.Raw, signedChain[1].Raw) assert.Equal(t, signingCA.Cert.Raw, signedChain[1].Raw)
// TODO: test scenario with rest chain // TODO: test scenario with rest chain
//assert.Equal(t, signingCA.RestChainRaw(), res.Chain[1:]) //assert.Equal(t, signingCA.RawRestChain(), res.Chain[1:])
err = signedChain[0].CheckSignatureFrom(signingCA.Cert) err = signedChain[0].CheckSignatureFrom(signingCA.Cert)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -380,8 +380,23 @@ func (ca *FullCertificateAuthority) NewIdentity() (*FullIdentity, error) {
} }
// RestChainRaw returns the rest (excluding leaf and CA) of the certificate chain as a 2d byte slice // Chain returns the CA's certificate chain
func (ca *FullCertificateAuthority) RestChainRaw() [][]byte { func (ca *FullCertificateAuthority) Chain() []*x509.Certificate {
return append([]*x509.Certificate{ca.Cert}, ca.RestChain...)
}
// RawChain returns the CA's certificate chain as a 2d byte slice
func (ca *FullCertificateAuthority) RawChain() [][]byte {
chain := ca.Chain()
rawChain := make([][]byte, len(chain))
for i, cert := range chain {
rawChain[i] = cert.Raw
}
return rawChain
}
// RawRestChain returns the "rest" (excluding `ca.Cert`) of the certificate chain as a 2d byte slice
func (ca *FullCertificateAuthority) RawRestChain() [][]byte {
var chain [][]byte var chain [][]byte
for _, cert := range ca.RestChain { for _, cert := range ca.RestChain {
chain = append(chain, cert.Raw) chain = append(chain, cert.Raw)

View File

@ -253,6 +253,15 @@ func NewFullIdentity(ctx context.Context, difficulty uint16, concurrency uint) (
return identity, err return identity, err
} }
// ToChains takes a number of certificate chains and returns them as a 2d slice of chains of certificates.
func ToChains(chains ...[]*x509.Certificate) [][]*x509.Certificate {
combinedChains := make([][]*x509.Certificate, len(chains))
for i, chain := range chains {
combinedChains[i] = chain
}
return combinedChains
}
// Status returns the status of the identity cert/key files for the config // Status returns the status of the identity cert/key files for the config
func (is SetupConfig) Status() TLSFilesStatus { func (is SetupConfig) Status() TLSFilesStatus {
return statTLSFiles(is.CertPath, is.KeyPath) return statTLSFiles(is.CertPath, is.KeyPath)
@ -390,22 +399,28 @@ func (ic PeerConfig) SaveBackup(pi *PeerIdentity) error {
}.Save(pi) }.Save(pi)
} }
// ChainRaw returns all of the certificate chain as a 2d byte slice // Chain returns the Identity's certificate chain
func (fi *FullIdentity) ChainRaw() [][]byte { func (fi *FullIdentity) Chain() []*x509.Certificate {
chain := [][]byte{fi.Leaf.Raw, fi.CA.Raw} return append([]*x509.Certificate{fi.Leaf, fi.CA}, fi.RestChain...)
for _, cert := range fi.RestChain {
chain = append(chain, cert.Raw)
}
return chain
} }
// RestChainRaw returns the rest (excluding leaf and CA) of the certificate chain as a 2d byte slice // RawChain returns all of the certificate chain as a 2d byte slice
func (fi *FullIdentity) RestChainRaw() [][]byte { func (fi *FullIdentity) RawChain() [][]byte {
var chain [][]byte chain := fi.Chain()
for _, cert := range fi.RestChain { rawChain := make([][]byte, len(chain))
chain = append(chain, cert.Raw) for i, cert := range chain {
rawChain[i] = cert.Raw
} }
return chain return rawChain
}
// RawRestChain returns the rest (excluding leaf and CA) of the certificate chain as a 2d byte slice
func (fi *FullIdentity) RawRestChain() [][]byte {
rawChain := make([][]byte, len(fi.RestChain))
for _, cert := range fi.RestChain {
rawChain = append(rawChain, cert.Raw)
}
return rawChain
} }
// PeerIdentity converts a FullIdentity into a PeerIdentity // PeerIdentity converts a FullIdentity into a PeerIdentity

View File

@ -53,7 +53,7 @@ func (dialer *Dialer) Lookup(ctx context.Context, self pb.Node, ask pb.Node, fin
} }
defer dialer.limit.Unlock() defer dialer.limit.Unlock()
conn, err := dialer.dial(ctx, ask) conn, err := dialer.dialNode(ctx, ask)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,14 +71,14 @@ func (dialer *Dialer) Lookup(ctx context.Context, self pb.Node, ask pb.Node, fin
return resp.Response, conn.disconnect() return resp.Response, conn.disconnect()
} }
// Ping pings target. // PingNode pings target.
func (dialer *Dialer) Ping(ctx context.Context, target pb.Node) (bool, error) { func (dialer *Dialer) PingNode(ctx context.Context, target pb.Node) (bool, error) {
if !dialer.limit.Lock() { if !dialer.limit.Lock() {
return false, context.Canceled return false, context.Canceled
} }
defer dialer.limit.Unlock() defer dialer.limit.Unlock()
conn, err := dialer.dial(ctx, target) conn, err := dialer.dialNode(ctx, target)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -88,6 +88,22 @@ func (dialer *Dialer) Ping(ctx context.Context, target pb.Node) (bool, error) {
return err == nil, errs.Combine(err, conn.disconnect()) return err == nil, errs.Combine(err, conn.disconnect())
} }
// PingAddress pings target by address (no node ID verification).
func (dialer *Dialer) PingAddress(ctx context.Context, address string, opts ...grpc.CallOption) (bool, error) {
if !dialer.limit.Lock() {
return false, context.Canceled
}
defer dialer.limit.Unlock()
conn, err := dialer.dialAddress(ctx, address)
if err != nil {
return false, err
}
_, err = conn.client.Ping(ctx, &pb.PingRequest{}, opts...)
return err == nil, errs.Combine(err, conn.disconnect())
}
// FetchPeerIdentity connects to a node and returns its peer identity // FetchPeerIdentity connects to a node and returns its peer identity
func (dialer *Dialer) FetchPeerIdentity(ctx context.Context, target pb.Node) (pID *identity.PeerIdentity, err error) { func (dialer *Dialer) FetchPeerIdentity(ctx context.Context, target pb.Node) (pID *identity.PeerIdentity, err error) {
if !dialer.limit.Lock() { if !dialer.limit.Lock() {
@ -95,7 +111,7 @@ func (dialer *Dialer) FetchPeerIdentity(ctx context.Context, target pb.Node) (pI
} }
defer dialer.limit.Unlock() defer dialer.limit.Unlock()
conn, err := dialer.dial(ctx, target) conn, err := dialer.dialNode(ctx, target)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -116,7 +132,7 @@ func (dialer *Dialer) FetchInfo(ctx context.Context, target pb.Node) (*pb.InfoRe
} }
defer dialer.limit.Unlock() defer dialer.limit.Unlock()
conn, err := dialer.dial(ctx, target) conn, err := dialer.dialNode(ctx, target)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -126,8 +142,8 @@ func (dialer *Dialer) FetchInfo(ctx context.Context, target pb.Node) (*pb.InfoRe
return resp, errs.Combine(err, conn.disconnect()) return resp, errs.Combine(err, conn.disconnect())
} }
// dial dials the specified node. // dialNode dials the specified node.
func (dialer *Dialer) dial(ctx context.Context, target pb.Node) (*Conn, error) { func (dialer *Dialer) dialNode(ctx context.Context, target pb.Node) (*Conn, error) {
grpcconn, err := dialer.transport.DialNode(ctx, &target) grpcconn, err := dialer.transport.DialNode(ctx, &target)
return &Conn{ return &Conn{
conn: grpcconn, conn: grpcconn,
@ -135,6 +151,15 @@ func (dialer *Dialer) dial(ctx context.Context, target pb.Node) (*Conn, error) {
}, err }, err
} }
// dialAddress dials the specified node by address (no node ID verification)
func (dialer *Dialer) dialAddress(ctx context.Context, address string) (*Conn, error) {
grpcconn, err := dialer.transport.DialAddress(ctx, address)
return &Conn{
conn: grpcconn,
client: pb.NewNodesClient(grpcconn),
}, err
}
// disconnect disconnects this connection. // disconnect disconnects this connection.
func (conn *Conn) disconnect() error { func (conn *Conn) disconnect() error {
return conn.conn.Close() return conn.conn.Close()

View File

@ -27,7 +27,7 @@ func TestDialer(t *testing.T) {
// TODO: also use satellites // TODO: also use satellites
peers := planet.StorageNodes peers := planet.StorageNodes
{ // Ping: storage node pings all other storage nodes { // PingNode: storage node pings all other storage nodes
self := planet.StorageNodes[0] self := planet.StorageNodes[0]
dialer := kademlia.NewDialer(zaptest.NewLogger(t), self.Transport) dialer := kademlia.NewDialer(zaptest.NewLogger(t), self.Transport)
@ -38,7 +38,7 @@ func TestDialer(t *testing.T) {
for _, peer := range peers { for _, peer := range peers {
peer := peer peer := peer
group.Go(func() error { group.Go(func() error {
pinged, err := dialer.Ping(ctx, peer.Local()) pinged, err := dialer.PingNode(ctx, peer.Local())
var pingErr error var pingErr error
if !pinged { if !pinged {
pingErr = fmt.Errorf("ping to %s should have succeeded", peer.ID()) pingErr = fmt.Errorf("ping to %s should have succeeded", peer.ID())

View File

@ -11,6 +11,8 @@ import (
"github.com/zeebo/errs" "github.com/zeebo/errs"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
"storj.io/storj/internal/sync2" "storj.io/storj/internal/sync2"
"storj.io/storj/pkg/identity" "storj.io/storj/pkg/identity"
@ -120,13 +122,28 @@ func (k *Kademlia) Bootstrap(ctx context.Context) error {
} }
var errs errs.Group var errs errs.Group
for _, node := range k.bootstrapNodes { for i, node := range k.bootstrapNodes {
if ctx.Err() != nil { if ctx.Err() != nil {
errs.Add(ctx.Err()) errs.Add(ctx.Err())
return errs.Err() return errs.Err()
} }
_, err := k.dialer.Ping(ctx, node) p := &peer.Peer{}
pCall := grpc.Peer(p)
_, err := k.dialer.PingAddress(ctx, node.Address.Address, pCall)
if err != nil {
errs.Add(err)
}
ident, err := identity.PeerIdentityFromPeer(p)
if err != nil {
errs.Add(err)
}
k.routingTable.mutex.Lock()
node.Id = ident.ID
k.bootstrapNodes[i] = node
k.routingTable.mutex.Unlock()
if err == nil { if err == nil {
// We have pinged successfully one bootstrap node. // We have pinged successfully one bootstrap node.
// Clear any errors and break the cycle. // Clear any errors and break the cycle.
@ -181,7 +198,7 @@ func (k *Kademlia) Ping(ctx context.Context, node pb.Node) (pb.Node, error) {
} }
defer k.lookups.Done() defer k.lookups.Done()
ok, err := k.dialer.Ping(ctx, node) ok, err := k.dialer.PingNode(ctx, node)
if err != nil { if err != nil {
return pb.Node{}, NodeErr.Wrap(err) return pb.Node{}, NodeErr.Wrap(err)
} }

View File

@ -103,7 +103,7 @@ func (lookup *peerDiscovery) Run(ctx context.Context) (target *pb.Node, err erro
if !ok { if !ok {
lookup.log.Debug("connecting to node failed", lookup.log.Debug("connecting to node failed",
zap.Any("target", lookup.target), zap.Any("target", lookup.target),
zap.Any("dial", next.Id), zap.Any("dial-node", next.Id),
zap.Any("dial-address", next.Address.Address), zap.Any("dial-address", next.Address.Address),
zap.Error(err), zap.Error(err),
) )

View File

@ -13,6 +13,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"storj.io/storj/pkg/dht" "storj.io/storj/pkg/dht"
"storj.io/storj/pkg/pb" "storj.io/storj/pkg/pb"
"storj.io/storj/pkg/storj" "storj.io/storj/pkg/storj"

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"storj.io/storj/internal/testcontext" "storj.io/storj/internal/testcontext"
"storj.io/storj/pkg/dht" "storj.io/storj/pkg/dht"
"storj.io/storj/pkg/kademlia/testrouting" "storj.io/storj/pkg/kademlia/testrouting"

View File

@ -88,6 +88,6 @@ func (opts *Options) configure(c Config) (err error) {
opts.PCVFuncs = pcvs opts.PCVFuncs = pcvs
opts.Cert, err = peertls.TLSCert(opts.Ident.ChainRaw(), opts.Ident.Leaf, opts.Ident.Key) opts.Cert, err = peertls.TLSCert(opts.Ident.RawChain(), opts.Ident.Leaf, opts.Ident.Key)
return err return err
} }

View File

@ -15,6 +15,7 @@ import (
"storj.io/storj/internal/testplanet" "storj.io/storj/internal/testplanet"
"storj.io/storj/pkg/peertls" "storj.io/storj/pkg/peertls"
"storj.io/storj/pkg/peertls/tlsopts" "storj.io/storj/pkg/peertls/tlsopts"
"storj.io/storj/pkg/storj"
) )
func TestNewOptions(t *testing.T) { func TestNewOptions(t *testing.T) {
@ -105,3 +106,26 @@ func TestNewOptions(t *testing.T) {
assert.Len(t, opts.PCVFuncs, c.pcvFuncsLen) assert.Len(t, opts.PCVFuncs, c.pcvFuncsLen)
} }
} }
func TestOptions_DialOption_error_on_empty_ID(t *testing.T) {
ident, err := testplanet.PregeneratedIdentity(0)
require.NoError(t, err)
opts, err := tlsopts.NewOptions(ident, tlsopts.Config{})
require.NoError(t, err)
dialOption, err := opts.DialOption(storj.NodeID{})
assert.Nil(t, dialOption)
assert.Error(t, err)
}
func TestOptions_DialUnverifiedIDOption(t *testing.T) {
ident, err := testplanet.PregeneratedIdentity(0)
require.NoError(t, err)
opts, err := tlsopts.NewOptions(ident, tlsopts.Config{})
require.NoError(t, err)
dialOption := opts.DialUnverifiedIDOption()
assert.NotNil(t, dialOption)
}

View File

@ -38,9 +38,16 @@ func (opts *Options) ServerOption() grpc.ServerOption {
// DialOption returns a grpc `DialOption` for making outgoing connections // DialOption returns a grpc `DialOption` for making outgoing connections
// to the node with this peer identity. // to the node with this peer identity.
// id is an optional id of the node we are dialing. func (opts *Options) DialOption(id storj.NodeID) (grpc.DialOption, error) {
func (opts *Options) DialOption(id storj.NodeID) grpc.DialOption { if id.IsZero() {
return grpc.WithTransportCredentials(opts.TransportCredentials(id)) return nil, Error.New("no ID specified for DialOption")
}
return grpc.WithTransportCredentials(opts.TransportCredentials(id)), nil
}
// DialUnverifiedIDOption returns a grpc `DialUnverifiedIDOption`
func (opts *Options) DialUnverifiedIDOption() grpc.DialOption {
return grpc.WithTransportCredentials(opts.TransportCredentials(storj.NodeID{}))
} }
// TransportCredentials returns a grpc `credentials.TransportCredentials` // TransportCredentials returns a grpc `credentials.TransportCredentials`
@ -54,10 +61,12 @@ func (opts *Options) TLSConfig(id storj.NodeID) *tls.Config {
pcvFuncs := append( pcvFuncs := append(
[]peertls.PeerCertVerificationFunc{ []peertls.PeerCertVerificationFunc{
peertls.VerifyPeerCertChains, peertls.VerifyPeerCertChains,
verifyIdentity(id),
}, },
opts.PCVFuncs..., opts.PCVFuncs...,
) )
if !id.IsZero() {
pcvFuncs = append(pcvFuncs, verifyIdentity(id))
}
return &tls.Config{ return &tls.Config{
Certificates: []tls.Certificate{*opts.Cert}, Certificates: []tls.Certificate{*opts.Cert},
InsecureSkipVerify: true, InsecureSkipVerify: true,
@ -70,10 +79,6 @@ func (opts *Options) TLSConfig(id storj.NodeID) *tls.Config {
func verifyIdentity(id storj.NodeID) peertls.PeerCertVerificationFunc { func verifyIdentity(id storj.NodeID) peertls.PeerCertVerificationFunc {
return func(_ [][]byte, parsedChains [][]*x509.Certificate) (err error) { return func(_ [][]byte, parsedChains [][]*x509.Certificate) (err error) {
defer mon.TaskNamed("verifyIdentity")(nil)(&err) defer mon.TaskNamed("verifyIdentity")(nil)(&err)
if id == (storj.NodeID{}) {
return nil
}
peer, err := identity.PeerIdentityFromCerts(parsedChains[0][0], parsedChains[0][1], parsedChains[0][2:]) peer, err := identity.PeerIdentityFromCerts(parsedChains[0][0], parsedChains[0][1], parsedChains[0][2:])
if err != nil { if err != nil {
return err return err

View File

@ -0,0 +1,6 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package tlsopts
var VerifyIdentity = verifyIdentity

View File

@ -0,0 +1,60 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package tlsopts_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"storj.io/storj/internal/testplanet"
"storj.io/storj/pkg/identity"
"storj.io/storj/pkg/peertls/tlsopts"
"storj.io/storj/pkg/storj"
)
func TestVerifyIdentity_success(t *testing.T) {
for i := 0; i < 50; i++ {
ident, err := testplanet.PregeneratedIdentity(i)
require.NoError(t, err)
err = tlsopts.VerifyIdentity(ident.ID)(nil, identity.ToChains(ident.Chain()))
assert.NoError(t, err)
}
}
func TestVerifyIdentity_success_signed(t *testing.T) {
for i := 0; i < 50; i++ {
ident, err := testplanet.PregeneratedSignedIdentity(i)
require.NoError(t, err)
err = tlsopts.VerifyIdentity(ident.ID)(nil, identity.ToChains(ident.Chain()))
assert.NoError(t, err)
}
}
func TestVerifyIdentity_error(t *testing.T) {
ident, err := testplanet.PregeneratedIdentity(0)
require.NoError(t, err)
identTheftVictim, err := testplanet.PregeneratedIdentity(1)
require.NoError(t, err)
cases := []struct {
test string
nodeID storj.NodeID
}{
{"empty node ID", storj.NodeID{}},
{"garbage node ID", storj.NodeID{0, 1, 2, 3}},
{"wrong node ID", identTheftVictim.ID},
}
for _, c := range cases {
t.Run(c.test, func(t *testing.T) {
err := tlsopts.VerifyIdentity(c.nodeID)(nil, identity.ToChains(ident.Chain()))
assert.Error(t, err)
})
}
}

View File

@ -520,7 +520,8 @@ func NewTest(ctx context.Context, t *testing.T, snID, upID *identity.FullIdentit
tlsOptions, err := tlsopts.NewOptions(upID, tlsopts.Config{}) tlsOptions, err := tlsopts.NewOptions(upID, tlsopts.Config{})
require.NoError(t, err) require.NoError(t, err)
conn, err := grpc.Dial(listener.Addr().String(), tlsOptions.DialOption(storj.NodeID{})) // TODO: why aren't we using transport client here?
conn, err := grpc.Dial(listener.Addr().String(), tlsOptions.DialUnverifiedIDOption())
require.NoError(t, err) require.NoError(t, err)
psClient := pb.NewPieceStoreRoutesClient(conn) psClient := pb.NewPieceStoreRoutesClient(conn)
//cleanup callback //cleanup callback

View File

@ -14,7 +14,6 @@ import (
"storj.io/storj/pkg/identity" "storj.io/storj/pkg/identity"
"storj.io/storj/pkg/pb" "storj.io/storj/pkg/pb"
"storj.io/storj/pkg/peertls/tlsopts" "storj.io/storj/pkg/peertls/tlsopts"
"storj.io/storj/pkg/storj"
) )
var ( var (
@ -68,8 +67,13 @@ func (transport *Transport) DialNode(ctx context.Context, node *pb.Node, opts ..
return nil, Error.New("no address") return nil, Error.New("no address")
} }
dialOption, err := transport.tlsOpts.DialOption(node.Id)
if err != nil {
return nil, err
}
options := append([]grpc.DialOption{ options := append([]grpc.DialOption{
transport.tlsOpts.DialOption(node.Id), dialOption,
grpc.WithBlock(), grpc.WithBlock(),
grpc.FailOnNonTempDialError(true), grpc.FailOnNonTempDialError(true),
}, opts...) }, opts...)
@ -100,7 +104,7 @@ func (transport *Transport) DialAddress(ctx context.Context, address string, opt
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
options := append([]grpc.DialOption{ options := append([]grpc.DialOption{
transport.tlsOpts.DialOption(storj.NodeID{}), transport.tlsOpts.DialUnverifiedIDOption(),
grpc.WithBlock(), grpc.WithBlock(),
grpc.FailOnNonTempDialError(true), grpc.FailOnNonTempDialError(true),
}, opts...) }, opts...)

View File

@ -74,6 +74,14 @@ func TestDialNode(t *testing.T) {
}, },
Type: pb.NodeType_STORAGE, Type: pb.NodeType_STORAGE,
}, },
{
Id: storj.NodeID{},
Address: &pb.NodeAddress{
Transport: pb.NodeTransport_TCP_TLS_GRPC,
Address: planet.StorageNodes[1].Addr(),
},
Type: pb.NodeType_STORAGE,
},
} }
for _, target := range targets { for _, target := range targets {
@ -98,7 +106,29 @@ func TestDialNode(t *testing.T) {
} }
timedCtx, cancel := context.WithTimeout(ctx, time.Second) timedCtx, cancel := context.WithTimeout(ctx, time.Second)
dialOption := opts.DialOption(storj.NodeID{}) conn, err := client.DialNode(timedCtx, target)
cancel()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.NoError(t, conn.Close())
})
t.Run("DialNode with valid signed target", func(t *testing.T) {
target := &pb.Node{
Id: planet.StorageNodes[1].ID(),
Address: &pb.NodeAddress{
Transport: pb.NodeTransport_TCP_TLS_GRPC,
Address: planet.StorageNodes[1].Addr(),
},
Type: pb.NodeType_STORAGE,
}
dialOption, err := opts.DialOption(target.Id)
require.NoError(t, err)
timedCtx, cancel := context.WithTimeout(ctx, time.Second)
conn, err := client.DialNode(timedCtx, target, dialOption) conn, err := client.DialNode(timedCtx, target, dialOption)
cancel() cancel()
@ -119,7 +149,9 @@ func TestDialNode(t *testing.T) {
} }
timedCtx, cancel := context.WithTimeout(ctx, time.Second) timedCtx, cancel := context.WithTimeout(ctx, time.Second)
dialOption := unsignedClientOpts.DialOption(storj.NodeID{}) dialOption, err := unsignedClientOpts.DialOption(target.Id)
require.NoError(t, err)
conn, err := client.DialNode( conn, err := client.DialNode(
timedCtx, target, dialOption, timedCtx, target, dialOption,
) )
@ -133,7 +165,7 @@ func TestDialNode(t *testing.T) {
t.Run("DialAddress with bad client certificate", func(t *testing.T) { t.Run("DialAddress with bad client certificate", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second) timedCtx, cancel := context.WithTimeout(ctx, time.Second)
dialOption := unsignedClientOpts.DialOption(storj.NodeID{}) dialOption := unsignedClientOpts.DialUnverifiedIDOption()
conn, err := client.DialAddress( conn, err := client.DialAddress(
timedCtx, planet.StorageNodes[1].Addr(), dialOption, timedCtx, planet.StorageNodes[1].Addr(), dialOption,
) )
@ -200,7 +232,9 @@ func TestDialNode_BadServerCertificate(t *testing.T) {
} }
timedCtx, cancel := context.WithTimeout(ctx, time.Second) timedCtx, cancel := context.WithTimeout(ctx, time.Second)
dialOption := opts.DialOption(storj.NodeID{}) dialOption, err := opts.DialOption(target.Id)
require.NoError(t, err)
conn, err := client.DialNode(timedCtx, target, dialOption) conn, err := client.DialNode(timedCtx, target, dialOption)
cancel() cancel()
@ -212,7 +246,9 @@ func TestDialNode_BadServerCertificate(t *testing.T) {
t.Run("DialAddress with bad server certificate", func(t *testing.T) { t.Run("DialAddress with bad server certificate", func(t *testing.T) {
timedCtx, cancel := context.WithTimeout(ctx, time.Second) timedCtx, cancel := context.WithTimeout(ctx, time.Second)
dialOption := opts.DialOption(storj.NodeID{}) dialOption, err := opts.DialOption(planet.StorageNodes[1].ID())
require.NoError(t, err)
conn, err := client.DialAddress(timedCtx, planet.StorageNodes[1].Addr(), dialOption) conn, err := client.DialAddress(timedCtx, planet.StorageNodes[1].Addr(), dialOption)
cancel() cancel()