storj/pkg/peertls/peertls_test.go
2019-01-24 15:15:10 -05:00

744 lines
19 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package peertls_test
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/gob"
"io/ioutil"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/zeebo/errs"
"storj.io/storj/internal/testpeertls"
"storj.io/storj/pkg/peertls"
)
// TestNewCert_CA creates a certificate from a fresh key and the CA template,
// passing no parent key or parent certificate, and checks that the resulting
// certificate's signature verifies against itself.
func TestNewCert_CA(t *testing.T) {
	key, err := peertls.NewKey()
	assert.NoError(t, err)

	tmpl, err := peertls.CATemplate()
	assert.NoError(t, err)

	// nil parent key and nil parent certificate.
	cert, err := peertls.NewCert(key, nil, tmpl, nil)
	assert.NoError(t, err)

	assert.NotEmpty(t, key)
	assert.NotEmpty(t, cert)
	assert.NotEmpty(t, cert.PublicKey.(*ecdsa.PublicKey))

	assert.NoError(t, cert.CheckSignatureFrom(cert))
}
// TestNewCert_Leaf creates a CA certificate and then a leaf certificate whose
// parent is that CA, and checks that both the CA's own signature and the
// leaf's signature verify against the CA certificate.
func TestNewCert_Leaf(t *testing.T) {
	caKey, err := peertls.NewKey()
	assert.NoError(t, err)

	caTemplate, err := peertls.CATemplate()
	assert.NoError(t, err)

	caCert, err := peertls.NewCert(caKey, nil, caTemplate, nil)
	assert.NoError(t, err)

	leafKey, err := peertls.NewKey()
	assert.NoError(t, err)

	leafTemplate, err := peertls.LeafTemplate()
	assert.NoError(t, err)

	leafCert, err := peertls.NewCert(leafKey, caKey, leafTemplate, caCert)
	assert.NoError(t, err)

	assert.NotEmpty(t, caKey)
	// The leaf key is the key under test here; previously only caKey was
	// asserted, leaving leafKey unchecked.
	assert.NotEmpty(t, leafKey)
	assert.NotEmpty(t, leafCert)
	assert.NotEmpty(t, leafCert.PublicKey.(*ecdsa.PublicKey))

	err = caCert.CheckSignatureFrom(caCert)
	assert.NoError(t, err)

	err = leafCert.CheckSignatureFrom(caCert)
	assert.NoError(t, err)
}
// TestVerifyPeerFunc checks that the function returned by
// peertls.VerifyPeerFunc hands the wrapped callback both the raw chain it was
// given and a parsed chain whose certificates match the originals.
func TestVerifyPeerFunc(t *testing.T) {
	_, chain, err := testpeertls.NewCertChain(2)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	leafCert, caCert := chain[0], chain[1]

	// Fail the test cleanly, rather than panic, if the generated leaf key is
	// not an ECDSA key.
	leafPub, ok := leafCert.PublicKey.(*ecdsa.PublicKey)
	if !assert.True(t, ok) {
		t.FailNow()
	}

	// rawChain is named distinctly so it does not shadow the outer chain.
	testFunc := func(rawChain [][]byte, parsedChains [][]*x509.Certificate) error {
		switch {
		case !bytes.Equal(rawChain[1], caCert.Raw):
			return errs.New("CA cert doesn't match")
		case !bytes.Equal(rawChain[0], leafCert.Raw):
			return errs.New("leaf cert doesn't match")
		}

		parsedLeaf, parsedCA := parsedChains[0][0], parsedChains[0][1]
		parsedPub, ok := parsedLeaf.PublicKey.(*ecdsa.PublicKey)
		switch {
		case !ok:
			return errs.New("leaf public key doesn't match")
		case leafPub.Curve != parsedPub.Curve:
			return errs.New("leaf public key doesn't match")
		case leafPub.X.Cmp(parsedPub.X) != 0:
			return errs.New("leaf public key doesn't match")
		case leafPub.Y.Cmp(parsedPub.Y) != 0:
			return errs.New("leaf public key doesn't match")
		case !bytes.Equal(parsedCA.Raw, caCert.Raw):
			return errs.New("parsed CA cert doesn't match")
		case !bytes.Equal(parsedLeaf.Raw, leafCert.Raw):
			return errs.New("parsed leaf cert doesn't match")
		}
		return nil
	}

	err = peertls.VerifyPeerFunc(testFunc)([][]byte{leafCert.Raw, caCert.Raw}, nil)
	assert.NoError(t, err)
}
// TestVerifyPeerCertChains verifies a freshly generated two-certificate chain,
// then re-issues the leaf signed by an unrelated key and expects verification
// of that chain to fail with both ErrVerifyPeerCert and
// ErrVerifyCertificateChain.
func TestVerifyPeerCertChains(t *testing.T) {
	keys, chain, err := testpeertls.NewCertChain(2)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	leafKey, leafCert, caCert := keys[1], chain[0], chain[1]

	// verify builds a new verification function on every call and runs it
	// over the given raw certificates.
	verify := func(raw ...[]byte) error {
		return peertls.VerifyPeerFunc(peertls.VerifyPeerCertChains)(raw, nil)
	}

	assert.NoError(t, verify(leafCert.Raw, caCert.Raw))

	wrongKey, err := peertls.NewKey()
	assert.NoError(t, err)

	// Re-issue the leaf with caCert as parent but signed by wrongKey.
	badLeaf, err := peertls.NewCert(leafKey, wrongKey, leafCert, caCert)
	assert.NoError(t, err)

	verifyErr := verify(badLeaf.Raw, caCert.Raw)
	assert.True(t, peertls.ErrVerifyPeerCert.Has(verifyErr))
	assert.True(t, peertls.ErrVerifyCertificateChain.Has(verifyErr))
}
func TestVerifyCAWhitelist(t *testing.T) {
_, chain2, err := testpeertls.NewCertChain(2)
if !assert.NoError(t, err) {
t.FailNow()
}
leafCert, caCert := chain2[0], chain2[1]
t.Run("empty whitelist", func(t *testing.T) {
err = peertls.VerifyPeerFunc(peertls.VerifyCAWhitelist(nil))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.NoError(t, err)
})
t.Run("whitelist contains ca", func(t *testing.T) {
err = peertls.VerifyPeerFunc(peertls.VerifyCAWhitelist([]*x509.Certificate{caCert}))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.NoError(t, err)
})
_, unrelatedChain, err := testpeertls.NewCertChain(1)
if !assert.NoError(t, err) {
t.FailNow()
}
unrelatedCert := unrelatedChain[0]
t.Run("no valid signed extension, non-empty whitelist", func(t *testing.T) {
err = peertls.VerifyPeerFunc(peertls.VerifyCAWhitelist([]*x509.Certificate{unrelatedCert}))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.True(t, peertls.ErrVerifyCAWhitelist.Has(err))
})
t.Run("last cert in whitelist is signer", func(t *testing.T) {
err = peertls.VerifyPeerFunc(peertls.VerifyCAWhitelist([]*x509.Certificate{unrelatedCert, caCert}))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.NoError(t, err)
})
t.Run("first cert in whitelist is signer", func(t *testing.T) {
err = peertls.VerifyPeerFunc(peertls.VerifyCAWhitelist([]*x509.Certificate{caCert, unrelatedCert}))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.NoError(t, err)
})
_, chain3, err := testpeertls.NewCertChain(3)
if !assert.NoError(t, err) {
t.FailNow()
}
leaf2Cert, ca2Cert, rootCert := chain3[0], chain3[1], chain3[2]
t.Run("length 3 chain - first cert in whitelist is signer", func(t *testing.T) {
err = peertls.VerifyPeerFunc(peertls.VerifyCAWhitelist([]*x509.Certificate{rootCert, unrelatedCert}))([][]byte{leaf2Cert.Raw, ca2Cert.Raw, unrelatedCert.Raw}, nil)
assert.NoError(t, err)
})
t.Run("length 3 chain - last cert in whitelist is signer", func(t *testing.T) {
err = peertls.VerifyPeerFunc(peertls.VerifyCAWhitelist([]*x509.Certificate{unrelatedCert, rootCert}))([][]byte{leaf2Cert.Raw, ca2Cert.Raw, unrelatedCert.Raw}, nil)
assert.NoError(t, err)
})
}
// TestAddExtension adds an extension with random content to a certificate and
// checks that it ends up as the certificate's only ExtraExtensions entry.
func TestAddExtension(t *testing.T) {
	_, chain, err := testpeertls.NewCertChain(1)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	cert := chain[0]

	// NB: there's nothing special about length 32
	const valueLen = 32
	value := make([]byte, valueLen)
	oid := asn1.ObjectIdentifier{2, 999}

	n, err := rand.Read(value)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	assert.Equal(t, valueLen, n)

	ext := pkix.Extension{Id: oid, Value: value}

	assert.NoError(t, peertls.AddExtension(cert, ext))
	assert.Len(t, cert.ExtraExtensions, 1)
	assert.Equal(t, ext, cert.ExtraExtensions[0])
}
// TestAddSignedCertExt adds a signed-cert extension to a single-certificate
// chain using that chain's key, then checks the extension's ID and that its
// value is a valid signature over the certificate's RawTBSCertificate.
func TestAddSignedCertExt(t *testing.T) {
	keys, chain, err := testpeertls.NewCertChain(1)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	key, cert := keys[0], chain[0]

	assert.NoError(t, peertls.AddSignedCertExt(key, cert))

	assert.Len(t, cert.ExtraExtensions, 1)
	assert.Equal(t, peertls.ExtensionIDs[peertls.SignedCertExtID], cert.ExtraExtensions[0].Id)

	ecKey, isECDSA := key.(*ecdsa.PrivateKey)
	if !assert.True(t, isECDSA) {
		t.FailNow()
	}

	signature := cert.ExtraExtensions[0].Value
	assert.NoError(t, peertls.VerifySignature(signature, cert.RawTBSCertificate, &ecKey.PublicKey))
}
// TestSignLeafExt adds a signed-cert extension to the leaf of a two-certificate
// chain using keys[0], then checks the extension's ID and that its value
// verifies against the public half of that key.
func TestSignLeafExt(t *testing.T) {
	keys, chain, err := testpeertls.NewCertChain(2)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	signerKey, leaf := keys[0], chain[0]

	assert.NoError(t, peertls.AddSignedCertExt(signerKey, leaf))

	assert.Equal(t, 1, len(leaf.ExtraExtensions))
	wantID := peertls.ExtensionIDs[peertls.SignedCertExtID]
	assert.True(t, wantID.Equal(leaf.ExtraExtensions[0].Id))

	signerECKey, isECDSA := signerKey.(*ecdsa.PrivateKey)
	if !assert.True(t, isECDSA) {
		t.FailNow()
	}

	signature := leaf.ExtraExtensions[0].Value
	assert.NoError(t, peertls.VerifySignature(signature, leaf.RawTBSCertificate, &signerECKey.PublicKey))
}
// TestRevocation_Sign builds a Revocation for the leaf certificate's hash,
// signs it with keys[0], and checks that a signature was produced.
func TestRevocation_Sign(t *testing.T) {
	keys, chain, err := testpeertls.NewCertChain(2)
	assert.NoError(t, err)
	leaf, signingKey := chain[0], keys[0]

	hash, err := peertls.SHA256Hash(leaf.Raw)
	assert.NoError(t, err)

	// Give the revocation its own copy of the hash bytes.
	certHash := make([]byte, len(hash))
	copy(certHash, hash)

	rev := peertls.Revocation{
		Timestamp: time.Now().Unix(),
		CertHash:  certHash,
	}

	assert.NoError(t, rev.Sign(signingKey))
	assert.NotEmpty(t, rev.Signature)
}
// TestRevocation_Verify signs a Revocation for the leaf certificate's hash
// with keys[0] and checks that it verifies against chain[1].
func TestRevocation_Verify(t *testing.T) {
	keys, chain, err := testpeertls.NewCertChain(2)
	assert.NoError(t, err)
	leaf, signerCert, signingKey := chain[0], chain[1], keys[0]

	hash, err := peertls.SHA256Hash(leaf.Raw)
	assert.NoError(t, err)

	// Give the revocation its own copy of the hash bytes.
	certHash := make([]byte, len(hash))
	copy(certHash, hash)

	rev := peertls.Revocation{
		Timestamp: time.Now().Unix(),
		CertHash:  certHash,
	}

	assert.NoError(t, rev.Sign(signingKey))
	assert.NotEmpty(t, rev.Signature)

	assert.NoError(t, rev.Verify(signerCert))
}
// TestRevocation_Marshal signs a Revocation, marshals it, and checks that
// decoding the marshaled bytes with encoding/gob yields an equal Revocation.
func TestRevocation_Marshal(t *testing.T) {
	keys, chain, err := testpeertls.NewCertChain(2)
	assert.NoError(t, err)
	leaf, signingKey := chain[0], keys[0]

	hash, err := peertls.SHA256Hash(leaf.Raw)
	assert.NoError(t, err)

	// Give the revocation its own copy of the hash bytes.
	certHash := make([]byte, len(hash))
	copy(certHash, hash)

	rev := peertls.Revocation{
		Timestamp: time.Now().Unix(),
		CertHash:  certHash,
	}

	assert.NoError(t, rev.Sign(signingKey))
	assert.NotEmpty(t, rev.Signature)

	marshaled, err := rev.Marshal()
	assert.NoError(t, err)
	assert.NotEmpty(t, marshaled)

	// Decode independently of Revocation.Unmarshal.
	var decoded peertls.Revocation
	err = gob.NewDecoder(bytes.NewBuffer(marshaled)).Decode(&decoded)
	assert.NoError(t, err)
	assert.Equal(t, rev, decoded)
}
// TestRevocation_Unmarshal signs a Revocation, encodes it with encoding/gob,
// and checks that Revocation.Unmarshal reconstructs an equal value from those
// bytes.
func TestRevocation_Unmarshal(t *testing.T) {
	keys, chain, err := testpeertls.NewCertChain(2)
	// Stop here on failure: indexing chain/keys below would otherwise panic.
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	leafCert, caKey := chain[0], keys[0]

	leafHash, err := peertls.SHA256Hash(leafCert.Raw)
	assert.NoError(t, err)

	rev := peertls.Revocation{
		Timestamp: time.Now().Unix(),
		CertHash:  make([]byte, len(leafHash)),
	}
	copy(rev.CertHash, leafHash)

	err = rev.Sign(caKey)
	assert.NoError(t, err)
	assert.NotEmpty(t, rev.Signature)

	// Encode independently of Revocation.Marshal.
	encodedRev := new(bytes.Buffer)
	err = gob.NewEncoder(encodedRev).Encode(rev)
	assert.NoError(t, err)

	unmarshaledRev := new(peertls.Revocation)
	err = unmarshaledRev.Unmarshal(encodedRev.Bytes())
	assert.NoError(t, err)

	// The former assert.NotNil(t, rev) was dropped: rev is a struct value
	// and can never be nil. The equality check below is the real assertion.
	assert.Equal(t, rev, *unmarshaledRev)
}
// TestNewRevocationExt creates a revocation extension for chain[0] with
// keys[0], unmarshals the extension value into a Revocation, and checks that
// the revocation verifies against chain[1].
func TestNewRevocationExt(t *testing.T) {
	keys, chain, err := testpeertls.NewCertChain(2)
	assert.NoError(t, err)

	revocationExt, err := peertls.NewRevocationExt(keys[0], chain[0])
	assert.NoError(t, err)

	rev := peertls.Revocation{}
	assert.NoError(t, rev.Unmarshal(revocationExt.Value))
	assert.NoError(t, rev.Verify(chain[1]))
}
func TestRevocationDB_Get(t *testing.T) {
tmp, err := ioutil.TempDir("", "TestRevocationDB_Get")
defer func() { _ = os.RemoveAll(tmp) }()
keys, chain, err := testpeertls.NewCertChain(2)
if !assert.NoError(t, err) {
t.FailNow()
}
ext, err := peertls.NewRevocationExt(keys[0], chain[0])
if !assert.NoError(t, err) {
t.FailNow()
}
revDB, err := peertls.NewRevocationDBBolt(filepath.Join(tmp, "revocations.db"))
if !assert.NoError(t, err) {
t.FailNow()
}
var rev *peertls.Revocation
t.Run("missing key", func(t *testing.T) {
rev, err = revDB.Get(chain)
assert.NoError(t, err)
assert.Nil(t, rev)
})
caHash, err := peertls.SHA256Hash(chain[1].Raw)
if !assert.NoError(t, err) {
t.FailNow()
}
err = revDB.DB.Put(caHash, ext.Value)
if !assert.NoError(t, err) {
t.FailNow()
}
t.Run("existing key", func(t *testing.T) {
rev, err = revDB.Get(chain)
assert.NoError(t, err)
revBytes, err := rev.Marshal()
assert.NoError(t, err)
assert.True(t, bytes.Equal(ext.Value, revBytes))
})
}
func TestRevocationDB_Put(t *testing.T) {
tmp, err := ioutil.TempDir("", "TestRevocationDB_Put")
defer func() { _ = os.RemoveAll(tmp) }()
keys, chain, err := testpeertls.NewCertChain(2)
if !assert.NoError(t, err) {
t.FailNow()
}
olderExt, err := peertls.NewRevocationExt(keys[0], chain[0])
assert.NoError(t, err)
time.Sleep(1 * time.Second)
ext, err := peertls.NewRevocationExt(keys[0], chain[0])
if !assert.NoError(t, err) {
t.FailNow()
}
time.Sleep(1 * time.Second)
newerExt, err := peertls.NewRevocationExt(keys[0], chain[0])
assert.NoError(t, err)
revDB, err := peertls.NewRevocationDBBolt(filepath.Join(tmp, "revocations.db"))
if !assert.NoError(t, err) {
t.FailNow()
}
cases := []struct {
testID string
ext pkix.Extension
errClass *errs.Class
err error
}{
{
"new key",
ext,
nil,
nil,
},
{
"existing key - older timestamp",
olderExt,
&peertls.ErrExtension,
peertls.ErrRevocationTimestamp,
},
{
"existing key - newer timestamp",
newerExt,
nil,
nil,
},
// TODO(bryanchriswhite): test empty/garbage cert/timestamp/sig
}
for _, c := range cases {
t.Run(c.testID, func(t2 *testing.T) {
if !assert.NotNil(t, c.ext) {
t2.Fail()
t.FailNow()
}
err = revDB.Put(chain, c.ext)
if c.errClass != nil {
assert.True(t, c.errClass.Has(err))
}
if c.err != nil {
assert.Equal(t, c.err, err)
}
if c.err == nil && c.errClass == nil {
if !assert.NoError(t2, err) {
t2.Fail()
t.FailNow()
}
func(t2 *testing.T, ext pkix.Extension) {
caHash, err := peertls.SHA256Hash(chain[1].Raw)
if !assert.NoError(t2, err) {
t2.FailNow()
}
revBytes, err := revDB.DB.Get(caHash)
if !assert.NoError(t2, err) {
t2.FailNow()
}
rev := new(peertls.Revocation)
err = rev.Unmarshal(revBytes)
assert.NoError(t2, err)
assert.True(t2, bytes.Equal(ext.Value, revBytes))
}(t2, c.ext)
}
})
}
}
// extensionHandlerMock is a testify mock whose verify method stands in for an
// extension handler's verification function, so tests can set expectations on
// the calls made to it.
type extensionHandlerMock struct {
mock.Mock
}
// verify records the call with its arguments on the embedded mock and returns
// the error configured as the first return value for that call.
func (m *extensionHandlerMock) verify(ext pkix.Extension, chain [][]*x509.Certificate) error {
	return m.Called(ext, chain).Error(0)
}
// TestExtensionHandlers_VerifyFunc registers a mocked verify function for the
// revocation and signed-cert extension IDs and checks that
// ExtensionHandlers.VerifyFunc invokes it once for each of the leaf's first
// two extra extensions, passing the parsed chains.
func TestExtensionHandlers_VerifyFunc(t *testing.T) {
	keys, chain, err := newRevokedLeafChain()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	chains := [][]*x509.Certificate{chain}

	if err = peertls.AddSignedCertExt(keys[0], chain[0]); !assert.NoError(t, err) {
		t.FailNow()
	}

	extMock := new(extensionHandlerMock)
	// Both handlers delegate to the same mock.
	mockedVerify := func(ext pkix.Extension, chain [][]*x509.Certificate) error {
		return extMock.verify(ext, chain)
	}

	handlers := peertls.ExtensionHandlers{
		{ID: peertls.ExtensionIDs[peertls.RevocationExtID], Verify: mockedVerify},
		{ID: peertls.ExtensionIDs[peertls.SignedCertExtID], Verify: mockedVerify},
	}

	leafExts := chains[0][peertls.LeafIndex].ExtraExtensions
	firstExt, secondExt := leafExts[0], leafExts[1]

	extMock.On("verify", firstExt, chains).Return(nil)
	extMock.On("verify", secondExt, chains).Return(nil)

	assert.NoError(t, handlers.VerifyFunc()(nil, chains))

	extMock.AssertCalled(t, "verify", firstExt, chains)
	extMock.AssertCalled(t, "verify", secondExt, chains)
	extMock.AssertExpectations(t)

	// TODO: test error scenario(s)
}
// TestParseExtensions is a table-driven test for peertls.ParseExtensions.
// For each case it builds the handlers for a TLSExtConfig and whitelist,
// checks how many handlers were produced, and runs their combined VerifyFunc
// over the case's certificate chain. All cases share one bolt-backed
// revocation database created in a temporary directory.
func TestParseExtensions(t *testing.T) {
// Shared fixtures: a chain whose leaf carries a revocation extension, a
// length-3 chain whose leaf carries a signed-cert extension, and an
// unrelated single-certificate chain.
revokedLeafKeys, revokedLeafChain, err := newRevokedLeafChain()
assert.NoError(t, err)
whitelistSignedKeys, whitelistSignedChain, err := testpeertls.NewCertChain(3)
assert.NoError(t, err)
err = peertls.AddSignedCertExt(whitelistSignedKeys[0], whitelistSignedChain[0])
assert.NoError(t, err)
_, unrelatedChain, err := testpeertls.NewCertChain(1)
assert.NoError(t, err)
tmp, err := ioutil.TempDir("", "TestParseExtensions")
if err != nil {
t.FailNow()
}
defer func() { _ = os.RemoveAll(tmp) }()
revDB, err := peertls.NewRevocationDBBolt(filepath.Join(tmp, "revocations.db"))
assert.NoError(t, err)
// NOTE: the func literals used as certChain values below are invoked while
// this table is being constructed, i.e. in declaration order and before any
// subtest runs. Some of them sleep or write to revDB, so their order matters.
cases := []struct {
testID string
config peertls.TLSExtConfig
// extLen is the expected number of handlers returned by ParseExtensions.
extLen int
certChain []*x509.Certificate
whitelist []*x509.Certificate
errClass *errs.Class
err error
}{
{
"leaf whitelist signature - success",
peertls.TLSExtConfig{WhitelistSignedLeaf: true},
1,
whitelistSignedChain,
[]*x509.Certificate{whitelistSignedChain[2]},
nil,
nil,
},
{
"leaf whitelist signature - failure (empty whitelist)",
peertls.TLSExtConfig{WhitelistSignedLeaf: true},
1,
whitelistSignedChain,
nil,
&peertls.ErrVerifyCAWhitelist,
nil,
},
{
"leaf whitelist signature - failure",
peertls.TLSExtConfig{WhitelistSignedLeaf: true},
1,
whitelistSignedChain,
unrelatedChain,
&peertls.ErrVerifyCAWhitelist,
nil,
},
{
"certificate revocation - single revocation ",
peertls.TLSExtConfig{Revocation: true},
1,
revokedLeafChain,
nil,
nil,
nil,
},
{
"certificate revocation - serial revocations",
peertls.TLSExtConfig{Revocation: true},
1,
// Revokes the already-revoked leaf a second time, one second later.
func() []*x509.Certificate {
rev := new(peertls.Revocation)
// NOTE(review): presumably needed so the second revocation gets a later
// whole-second timestamp than the first — confirm.
time.Sleep(1 * time.Second)
_, chain, err := revokeLeaf(revokedLeafKeys, revokedLeafChain)
assert.NoError(t, err)
// Sanity check that the new leaf's first extension unmarshals as a Revocation.
err = rev.Unmarshal(chain[0].ExtraExtensions[0].Value)
assert.NoError(t, err)
return chain
}(),
nil,
nil,
nil,
},
{
"certificate revocation - serial revocations error (older timestamp)",
peertls.TLSExtConfig{Revocation: true},
1,
// Builds a fresh revoked chain, then stores in revDB a re-signed copy of its
// revocation with the timestamp moved 300 seconds forward, so the revocation
// carried by the chain itself is older than the stored one.
func() []*x509.Certificate {
keys, chain, err := newRevokedLeafChain()
assert.NoError(t, err)
rev := new(peertls.Revocation)
err = rev.Unmarshal(chain[0].ExtraExtensions[0].Value)
assert.NoError(t, err)
rev.Timestamp = rev.Timestamp + 300
err = rev.Sign(keys[0])
assert.NoError(t, err)
revBytes, err := rev.Marshal()
assert.NoError(t, err)
err = revDB.Put(chain, pkix.Extension{
Id: peertls.ExtensionIDs[peertls.RevocationExtID],
Value: revBytes,
})
assert.NoError(t, err)
return chain
}(),
nil,
&peertls.ErrExtension,
peertls.ErrRevocationTimestamp,
},
{
"certificate revocation and leaf whitelist signature",
peertls.TLSExtConfig{Revocation: true, WhitelistSignedLeaf: true},
2,
// A revoked-leaf chain whose leaf additionally gets a signed-cert extension
// made with whitelistSignedKeys[0].
func() []*x509.Certificate {
_, chain, err := newRevokedLeafChain()
assert.NoError(t, err)
err = peertls.AddSignedCertExt(whitelistSignedKeys[0], chain[0])
assert.NoError(t, err)
return chain
}(),
[]*x509.Certificate{whitelistSignedChain[2]},
nil,
nil,
},
}
for _, c := range cases {
t.Run(c.testID, func(t *testing.T) {
opts := peertls.ParseExtOptions{
CAWhitelist: c.whitelist,
RevDB: revDB,
}
handlers := peertls.ParseExtensions(c.config, opts)
assert.Equal(t, c.extLen, len(handlers))
err := handlers.VerifyFunc()(nil, [][]*x509.Certificate{c.certChain})
if c.errClass != nil {
assert.True(t, c.errClass.Has(err))
}
// Only the presence of an error is checked here; err is not compared
// against c.err.
if c.err != nil {
assert.NotNil(t, err)
}
if c.errClass == nil && c.err == nil {
assert.NoError(t, err)
}
})
}
}
// revokeLeaf issues a new leaf certificate signed by keys[0] with chain[1] as
// its parent, attaches to it a revocation extension for chain[0] (also made
// with keys[0]), and returns the unchanged keys together with a new chain in
// which the new certificate replaces chain[0]. The first error encountered is
// returned with nil keys and chain.
func revokeLeaf(keys []crypto.PrivateKey, chain []*x509.Certificate) ([]crypto.PrivateKey, []*x509.Certificate, error) {
	newLeafKey, err := peertls.NewKey()
	if err != nil {
		return nil, nil, err
	}

	leafTmpl, err := peertls.LeafTemplate()
	if err != nil {
		return nil, nil, err
	}

	newLeaf, err := peertls.NewCert(newLeafKey, keys[0], leafTmpl, chain[1])
	if err != nil {
		return nil, nil, err
	}

	if err := peertls.AddRevocationExt(keys[0], chain[0], newLeaf); err != nil {
		return nil, nil, err
	}

	// Build the result in a fresh slice so the caller's chain is not modified.
	revokedChain := make([]*x509.Certificate, 0, len(chain))
	revokedChain = append(revokedChain, newLeaf)
	revokedChain = append(revokedChain, chain[1:]...)
	return keys, revokedChain, nil
}
// newRevokedLeafChain generates a two-certificate chain and passes it to
// revokeLeaf, returning that call's result. If chain generation fails, the
// error is returned with nil keys and chain.
func newRevokedLeafChain() ([]crypto.PrivateKey, []*x509.Certificate, error) {
	keys, chain, err := testpeertls.NewCertChain(2)
	if err == nil {
		return revokeLeaf(keys, chain)
	}
	return nil, nil, err
}