storj/pkg/peertls/peertls_test.go
Bryan White 2016ce9fd6
Certificate revocation (#836)
* wip certificate revocation

* refactor tests

* wip testing

* testing

* review fixes

* integration fix attempt #1

* review fixes

* integration fix attempt #2

* linter fixes

* add copyright

* integration fix attempt #3

* more testing

* more tests

* go mod tidy

* review fixes

* linter fixes
2018-12-13 21:01:43 +01:00

774 lines
18 KiB
Go

// Copyright (C) 2018 Storj Labs, Inc.
// See LICENSE for copying information.
package peertls
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/gob"
"io/ioutil"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/zeebo/errs"
)
// TestNewCert_CA checks that a certificate created from CATemplate without a
// parent key or parent certificate verifies against its own public key, i.e.
// that it is self-signed.
func TestNewCert_CA(t *testing.T) {
	key, err := NewKey()
	assert.NoError(t, err)

	tmpl, err := CATemplate()
	assert.NoError(t, err)

	// With no parent key/cert the result is expected to be self-signed.
	cert, err := NewCert(key, nil, tmpl, nil)
	assert.NoError(t, err)

	ecKey := key.(*ecdsa.PrivateKey)
	assert.NotEmpty(t, ecKey)
	assert.NotEmpty(t, cert)
	ecPub := cert.PublicKey.(*ecdsa.PublicKey)
	assert.NotEmpty(t, ecPub)

	assert.NoError(t, cert.CheckSignatureFrom(cert))
}
// TestNewCert_Leaf checks that NewCert can issue a leaf certificate from a CA:
// the leaf's signature must verify against the CA certificate, and the CA
// certificate must verify against itself.
func TestNewCert_Leaf(t *testing.T) {
	caKey, err := NewKey()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	caTemplate, err := CATemplate()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	caCert, err := NewCert(caKey, nil, caTemplate, nil)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	leafKey, err := NewKey()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	leafTemplate, err := LeafTemplate()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	leafCert, err := NewCert(leafKey, caKey, leafTemplate, caCert)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Check the leaf's own key here; the CA key is covered by TestNewCert_CA.
	leafECKey, ok := leafKey.(*ecdsa.PrivateKey)
	assert.True(t, ok)
	assert.NotEmpty(t, leafECKey)
	assert.NotEmpty(t, leafCert)
	leafPub, ok := leafCert.PublicKey.(*ecdsa.PublicKey)
	assert.True(t, ok)
	assert.NotEmpty(t, leafPub)

	assert.NoError(t, caCert.CheckSignatureFrom(caCert))
	assert.NoError(t, leafCert.CheckSignatureFrom(caCert))
}
// TestVerifyPeerFunc checks that the function wrapped by VerifyPeerFunc
// receives the raw certificates unchanged and a parsed chain whose
// certificates match them.
func TestVerifyPeerFunc(t *testing.T) {
	_, chain, err := newCertChain(2)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	leafCert, caCert := chain[0], chain[1]

	// NB: the first parameter is named rawCerts so it does not shadow chain.
	testFunc := func(rawCerts [][]byte, parsedChains [][]*x509.Certificate) error {
		// Guard the indexing below so a short chain fails instead of panicking.
		if len(rawCerts) < 2 || len(parsedChains) < 1 || len(parsedChains[0]) < 2 {
			return errs.New("unexpected chain length")
		}
		leafPub := leafCert.PublicKey.(*ecdsa.PublicKey)
		parsedLeafPub, isEC := parsedChains[0][0].PublicKey.(*ecdsa.PublicKey)

		switch {
		case !bytes.Equal(rawCerts[1], caCert.Raw):
			return errs.New("CA cert doesn't match")
		case !bytes.Equal(rawCerts[0], leafCert.Raw):
			return errs.New("leaf cert doesn't match")
		case !isEC,
			leafPub.Curve != parsedLeafPub.Curve,
			leafPub.X.Cmp(parsedLeafPub.X) != 0,
			leafPub.Y.Cmp(parsedLeafPub.Y) != 0:
			return errs.New("leaf public key doesn't match")
		case !bytes.Equal(parsedChains[0][1].Raw, caCert.Raw):
			return errs.New("parsed CA cert doesn't match")
		case !bytes.Equal(parsedChains[0][0].Raw, leafCert.Raw):
			return errs.New("parsed leaf cert doesn't match")
		}
		return nil
	}

	err = VerifyPeerFunc(testFunc)([][]byte{leafCert.Raw, caCert.Raw}, nil)
	assert.NoError(t, err)
}
// TestVerifyPeerCertChains checks that VerifyPeerCertChains accepts a chain
// whose signatures are valid and rejects a leaf that was not signed by the CA
// presented next to it.
func TestVerifyPeerCertChains(t *testing.T) {
	keys, chain, err := newCertChain(2)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	// NB: keys are in reverse order relative to chain, so keys[1] is the leaf's.
	leafKey, leafCert, caCert := keys[1], chain[0], chain[1]

	err = VerifyPeerFunc(VerifyPeerCertChains)([][]byte{leafCert.Raw, caCert.Raw}, nil)
	assert.NoError(t, err)

	// Issue a replacement leaf from an unrelated CA. The wrong key gets its own
	// CA certificate because newer crypto/x509 versions refuse to sign when the
	// private key does not match the parent certificate's public key.
	wrongKey, err := NewKey()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	wrongCATemplate, err := CATemplate()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	wrongCACert, err := NewCert(wrongKey, nil, wrongCATemplate, nil)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	badLeafTemplate, err := LeafTemplate()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	badLeafCert, err := NewCert(leafKey, wrongKey, badLeafTemplate, wrongCACert)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Present the wrongly-signed leaf together with the original CA.
	err = VerifyPeerFunc(VerifyPeerCertChains)([][]byte{badLeafCert.Raw, caCert.Raw}, nil)
	assert.True(t, ErrVerifyPeerCert.Has(err))
	assert.True(t, ErrVerifyCertificateChain.Has(err))
}
func TestVerifyCAWhitelist(t *testing.T) {
_, chain2, err := newCertChain(2)
if !assert.NoError(t, err) {
t.FailNow()
}
leafCert, caCert := chain2[0], chain2[1]
t.Run("empty whitelist", func(t *testing.T) {
err = VerifyPeerFunc(VerifyCAWhitelist(nil))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.NoError(t, err)
})
t.Run("whitelist contains ca", func(t *testing.T) {
err = VerifyPeerFunc(VerifyCAWhitelist([]*x509.Certificate{caCert}))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.NoError(t, err)
})
_, unrelatedChain, err := newCertChain(1)
if !assert.NoError(t, err) {
t.FailNow()
}
unrelatedCert := unrelatedChain[0]
t.Run("no valid signed extension, non-empty whitelist", func(t *testing.T) {
err = VerifyPeerFunc(VerifyCAWhitelist([]*x509.Certificate{unrelatedCert}))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.True(t, ErrVerifyCAWhitelist.Has(err))
})
t.Run("last cert in whitelist is signer", func(t *testing.T) {
err = VerifyPeerFunc(VerifyCAWhitelist([]*x509.Certificate{unrelatedCert, caCert}))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.NoError(t, err)
})
t.Run("first cert in whitelist is signer", func(t *testing.T) {
err = VerifyPeerFunc(VerifyCAWhitelist([]*x509.Certificate{caCert, unrelatedCert}))([][]byte{leafCert.Raw, caCert.Raw}, nil)
assert.NoError(t, err)
})
_, chain3, err := newCertChain(3)
if !assert.NoError(t, err) {
t.FailNow()
}
leaf2Cert, ca2Cert, rootCert := chain3[0], chain3[1], chain3[2]
t.Run("length 3 chain - first cert in whitelist is signer", func(t *testing.T) {
err = VerifyPeerFunc(VerifyCAWhitelist([]*x509.Certificate{rootCert, unrelatedCert}))([][]byte{leaf2Cert.Raw, ca2Cert.Raw, unrelatedCert.Raw}, nil)
assert.NoError(t, err)
})
t.Run("length 3 chain - last cert in whitelist is signer", func(t *testing.T) {
err = VerifyPeerFunc(VerifyCAWhitelist([]*x509.Certificate{unrelatedCert, rootCert}))([][]byte{leaf2Cert.Raw, ca2Cert.Raw, unrelatedCert.Raw}, nil)
assert.NoError(t, err)
})
}
// TestAddExtension checks that AddExtension appends the given extension,
// unchanged, to the certificate's ExtraExtensions.
func TestAddExtension(t *testing.T) {
	_, certs, err := newCertChain(1)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	cert := certs[0]

	// NB: there's nothing special about length 32
	const valueLen = 32
	value := make([]byte, valueLen)
	n, err := rand.Read(value)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	assert.Equal(t, valueLen, n)

	ext := pkix.Extension{Id: asn1.ObjectIdentifier{2, 999}, Value: value}
	assert.NoError(t, AddExtension(cert, ext))
	assert.Len(t, cert.ExtraExtensions, 1)
	assert.Equal(t, ext, cert.ExtraExtensions[0])
}
// TestAddSignedCertExt checks that AddSignedCertExt adds exactly one extension,
// with the signed-cert extension ID, whose value verifies as a signature over
// the certificate's RawTBSCertificate under the supplied key.
func TestAddSignedCertExt(t *testing.T) {
	keys, certs, err := newCertChain(1)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	signer, cert := keys[0], certs[0]

	assert.NoError(t, AddSignedCertExt(signer, cert))
	assert.Len(t, cert.ExtraExtensions, 1)
	ext := cert.ExtraExtensions[0]
	assert.Equal(t, ExtensionIDs[SignedCertExtID], ext.Id)

	ecSigner, ok := signer.(*ecdsa.PrivateKey)
	if !assert.True(t, ok) {
		t.FailNow()
	}
	assert.NoError(t, VerifySignature(ext.Value, cert.RawTBSCertificate, &ecSigner.PublicKey))
}
// TestSignLeafExt checks that the CA key of a two-certificate chain can add a
// signed-cert extension to the leaf, and that the extension's value verifies
// against the CA public key.
func TestSignLeafExt(t *testing.T) {
	keys, certs, err := newCertChain(2)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	// Keys are reversed relative to certs: keys[0] belongs to the last cert.
	signingKey, leaf := keys[0], certs[0]

	err = AddSignedCertExt(signingKey, leaf)
	assert.NoError(t, err)
	assert.Equal(t, 1, len(leaf.ExtraExtensions))
	assert.True(t, ExtensionIDs[SignedCertExtID].Equal(leaf.ExtraExtensions[0].Id))

	ecKey, ok := signingKey.(*ecdsa.PrivateKey)
	if !assert.True(t, ok) {
		t.FailNow()
	}
	signature := leaf.ExtraExtensions[0].Value
	assert.NoError(t, VerifySignature(signature, leaf.RawTBSCertificate, &ecKey.PublicKey))
}
// TestRevocation_Sign checks that signing a revocation of the leaf with the
// CA key populates the revocation's Signature.
func TestRevocation_Sign(t *testing.T) {
	keys, certs, err := newCertChain(2)
	assert.NoError(t, err)
	leafCert, caKey := certs[0], keys[0]

	leafHash, err := hashBytes(leafCert.Raw)
	assert.NoError(t, err)

	// Give the revocation its own copy of the hash.
	certHash := make([]byte, len(leafHash))
	copy(certHash, leafHash)
	rev := Revocation{
		Timestamp: time.Now().Unix(),
		CertHash:  certHash,
	}

	assert.NoError(t, rev.Sign(caKey))
	assert.NotEmpty(t, rev.Signature)
}
// TestRevocation_Verify checks that a revocation signed with the CA key
// verifies against the CA certificate.
func TestRevocation_Verify(t *testing.T) {
	keys, certs, err := newCertChain(2)
	assert.NoError(t, err)
	leafCert, caCert, caKey := certs[0], certs[1], keys[0]

	leafHash, err := hashBytes(leafCert.Raw)
	assert.NoError(t, err)

	var rev Revocation
	rev.Timestamp = time.Now().Unix()
	// make+append yields an independent, non-nil copy of the hash.
	rev.CertHash = append(make([]byte, 0, len(leafHash)), leafHash...)

	assert.NoError(t, rev.Sign(caKey))
	assert.NotEmpty(t, rev.Signature)
	assert.NoError(t, rev.Verify(caCert))
}
// TestRevocation_Marshal checks that Revocation.Marshal output decodes with
// encoding/gob back into a value equal to the original signed revocation.
func TestRevocation_Marshal(t *testing.T) {
	keys, certs, err := newCertChain(2)
	assert.NoError(t, err)
	leafCert, caKey := certs[0], keys[0]

	leafHash, err := hashBytes(leafCert.Raw)
	assert.NoError(t, err)

	rev := Revocation{Timestamp: time.Now().Unix(), CertHash: make([]byte, len(leafHash))}
	copy(rev.CertHash, leafHash)
	assert.NoError(t, rev.Sign(caKey))
	assert.NotEmpty(t, rev.Signature)

	revBytes, err := rev.Marshal()
	assert.NoError(t, err)
	assert.NotEmpty(t, revBytes)

	// Decode with gob directly so the check does not depend on Unmarshal.
	var decoded Revocation
	assert.NoError(t, gob.NewDecoder(bytes.NewReader(revBytes)).Decode(&decoded))
	assert.Equal(t, rev, decoded)
}
// TestRevocation_Unmarshal checks that Revocation.Unmarshal decodes a
// gob-encoded, signed revocation back into a value equal to the original.
func TestRevocation_Unmarshal(t *testing.T) {
	keys, chain, err := newCertChain(2)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	leafCert, caKey := chain[0], keys[0]

	leafHash, err := hashBytes(leafCert.Raw)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	rev := Revocation{
		Timestamp: time.Now().Unix(),
		CertHash:  make([]byte, len(leafHash)),
	}
	copy(rev.CertHash, leafHash)
	err = rev.Sign(caKey)
	assert.NoError(t, err)
	assert.NotEmpty(t, rev.Signature)

	// Encode with gob directly (not rev.Marshal) so Unmarshal is tested on
	// its own.
	encodedRev := new(bytes.Buffer)
	err = gob.NewEncoder(encodedRev).Encode(rev)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	unmarshaledRev := new(Revocation)
	err = unmarshaledRev.Unmarshal(encodedRev.Bytes())
	assert.NoError(t, err)
	// Check the decoded value itself; rev is a struct and can never be nil.
	assert.NotEmpty(t, unmarshaledRev.Signature)
	assert.Equal(t, rev, *unmarshaledRev)
}
// TestNewRevocationExt checks that the extension built by NewRevocationExt
// holds a revocation that unmarshals and verifies against the CA certificate.
func TestNewRevocationExt(t *testing.T) {
	keys, certs, err := newCertChain(2)
	assert.NoError(t, err)
	caKey, leafCert := keys[0], certs[0]

	ext, err := NewRevocationExt(caKey, leafCert)
	assert.NoError(t, err)

	rev := new(Revocation)
	assert.NoError(t, rev.Unmarshal(ext.Value))

	caCert := certs[1]
	assert.NoError(t, rev.Verify(caCert))
}
func TestRevocationDB_Get(t *testing.T) {
tmp, err := ioutil.TempDir("", os.TempDir())
defer func() { _ = os.RemoveAll(tmp) }()
keys, chain, err := newCertChain(2)
if !assert.NoError(t, err) {
t.FailNow()
}
ext, err := NewRevocationExt(keys[0], chain[0])
if !assert.NoError(t, err) {
t.FailNow()
}
revDB, err := NewRevocationDBBolt(filepath.Join(tmp, "revocations.db"))
if !assert.NoError(t, err) {
t.FailNow()
}
var rev *Revocation
t.Run("missing key", func(t *testing.T) {
rev, err = revDB.Get(chain)
assert.NoError(t, err)
assert.Nil(t, rev)
})
caHash, err := hashBytes(chain[1].Raw)
if !assert.NoError(t, err) {
t.FailNow()
}
err = revDB.DB.Put(caHash, ext.Value)
if !assert.NoError(t, err) {
t.FailNow()
}
t.Run("existing key", func(t *testing.T) {
rev, err = revDB.Get(chain)
assert.NoError(t, err)
revBytes, err := rev.Marshal()
assert.NoError(t, err)
assert.True(t, bytes.Equal(ext.Value, revBytes))
})
}
func TestRevocationDB_Put(t *testing.T) {
tmp, err := ioutil.TempDir("", os.TempDir())
defer func() { _ = os.RemoveAll(tmp) }()
keys, chain, err := newCertChain(2)
if !assert.NoError(t, err) {
t.FailNow()
}
olderExt, err := NewRevocationExt(keys[0], chain[0])
assert.NoError(t, err)
time.Sleep(1 * time.Second)
ext, err := NewRevocationExt(keys[0], chain[0])
if !assert.NoError(t, err) {
t.FailNow()
}
time.Sleep(1 * time.Second)
newerExt, err := NewRevocationExt(keys[0], chain[0])
assert.NoError(t, err)
revDB, err := NewRevocationDBBolt(filepath.Join(tmp, "revocations.db"))
if !assert.NoError(t, err) {
t.FailNow()
}
cases := []struct {
testID string
ext pkix.Extension
errClass *errs.Class
err error
}{
{
"new key",
ext,
nil,
nil,
},
{
"existing key - older timestamp",
olderExt,
&ErrExtension,
ErrRevocationTimestamp,
},
{
"existing key - newer timestamp",
newerExt,
nil,
nil,
},
}
for _, c := range cases {
t.Run(c.testID, func(t2 *testing.T) {
if !assert.NotNil(t, c.ext) {
t2.Fail()
t.FailNow()
}
err = revDB.Put(chain, c.ext)
if c.errClass != nil {
assert.True(t, c.errClass.Has(err))
}
if c.err != nil {
assert.Equal(t, c.err, err)
}
if c.err == nil && c.errClass == nil {
if !assert.NoError(t2, err) {
t2.Fail()
t.FailNow()
}
func(t2 *testing.T, ext pkix.Extension) {
caHash, err := hashBytes(chain[1].Raw)
if !assert.NoError(t2, err) {
t2.FailNow()
}
revBytes, err := revDB.DB.Get(caHash)
if !assert.NoError(t2, err) {
t2.FailNow()
}
rev := new(Revocation)
err = rev.Unmarshal(revBytes)
assert.NoError(t2, err)
assert.True(t2, bytes.Equal(ext.Value, revBytes))
}(t2, c.ext)
}
})
}
}
// extensionHandlerMock records calls to verify through testify's mock.Mock,
// so tests can assert which extensions were checked and with which chains.
type extensionHandlerMock struct {
	mock.Mock
}

// verify reports the call to the embedded mock and returns the error that the
// test configured with On("verify", ...).Return(...).
func (m *extensionHandlerMock) verify(ext pkix.Extension, chain [][]*x509.Certificate) error {
	return m.Called(ext, chain).Error(0)
}
// TestExtensionHandlers_VerifyFunc checks that the function returned by
// ExtensionHandlers.VerifyFunc calls each handler's verify with the leaf's
// matching extension and the chains it was given.
func TestExtensionHandlers_VerifyFunc(t *testing.T) {
	keys, chain, err := newRevokedLeafChain()
	chains := [][]*x509.Certificate{chain}
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	if !assert.NoError(t, AddSignedCertExt(keys[0], chain[0])) {
		t.FailNow()
	}

	// Both handlers share one mock, so the expectations below tell the calls
	// apart only by the extension argument.
	extMock := new(extensionHandlerMock)
	handlers := ExtensionHandlers{
		{id: ExtensionIDs[RevocationExtID], verify: extMock.verify},
		{id: ExtensionIDs[SignedCertExtID], verify: extMock.verify},
	}

	leafExts := chains[0][LeafIndex].ExtraExtensions
	firstExt, secondExt := leafExts[0], leafExts[1]
	extMock.On("verify", firstExt, chains).Return(nil)
	extMock.On("verify", secondExt, chains).Return(nil)

	assert.NoError(t, handlers.VerifyFunc()(nil, chains))

	extMock.AssertCalled(t, "verify", firstExt, chains)
	extMock.AssertCalled(t, "verify", secondExt, chains)
	extMock.AssertExpectations(t)
	// TODO: test error scenario(s)
}
// TestParseExtensions checks that ParseExtensions builds the expected number
// of handlers for each TLSExtConfig, and that the resulting VerifyFunc accepts
// or rejects each chain as expected.
func TestParseExtensions(t *testing.T) {
	revokedLeafKeys, revokedLeafChain, err := newRevokedLeafChain()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	whitelistSignedKeys, whitelistSignedChain, err := newCertChain(3)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	err = AddSignedCertExt(whitelistSignedKeys[0], whitelistSignedChain[0])
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	_, unrelatedChain, err := newCertChain(1)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// NB: the TempDir pattern must not contain a path separator (os.TempDir()
	// does). Newer Go versions reject such patterns with an error.
	tmp, err := ioutil.TempDir("", "peertls")
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer func() { _ = os.RemoveAll(tmp) }()

	revDB, err := NewRevocationDBBolt(filepath.Join(tmp, "revocations.db"))
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// NB: the chain-building closures below run while this slice is being
	// constructed, i.e. before any subtest runs. Subtests then run in order
	// against the shared revDB.
	cases := []struct {
		testID    string
		config    TLSExtConfig
		extLen    int
		certChain []*x509.Certificate
		whitelist []*x509.Certificate
		errClass  *errs.Class
		err       error
	}{
		{
			"leaf whitelist signature - success",
			TLSExtConfig{WhitelistSignedLeaf: true},
			1,
			whitelistSignedChain,
			[]*x509.Certificate{whitelistSignedChain[2]},
			nil,
			nil,
		},
		{
			"leaf whitelist signature - failure (empty whitelist)",
			TLSExtConfig{WhitelistSignedLeaf: true},
			1,
			whitelistSignedChain,
			nil,
			&ErrVerifyCAWhitelist,
			nil,
		},
		{
			"leaf whitelist signature - failure",
			TLSExtConfig{WhitelistSignedLeaf: true},
			1,
			whitelistSignedChain,
			unrelatedChain,
			&ErrVerifyCAWhitelist,
			nil,
		},
		{
			"certificate revocation - single revocation ",
			TLSExtConfig{Revocation: true},
			1,
			revokedLeafChain,
			nil,
			nil,
			nil,
		},
		{
			"certificate revocation - serial revocations",
			TLSExtConfig{Revocation: true},
			1,
			func() []*x509.Certificate {
				rev := new(Revocation)
				// Sleep so this second revocation of the same CA gets a later
				// timestamp than the one in revokedLeafChain.
				time.Sleep(1 * time.Second)
				_, chain, err := revokeLeaf(revokedLeafKeys, revokedLeafChain)
				assert.NoError(t, err)
				err = rev.Unmarshal(chain[0].ExtraExtensions[0].Value)
				assert.NoError(t, err)
				return chain
			}(),
			nil,
			nil,
			nil,
		},
		{
			"certificate revocation - serial revocations error (older timestamp)",
			TLSExtConfig{Revocation: true},
			1,
			func() []*x509.Certificate {
				keys, chain, err := newRevokedLeafChain()
				assert.NoError(t, err)
				// Pre-store a revocation for this CA dated 300s later than the
				// chain's own revocation, so the chain's revocation is older.
				rev := new(Revocation)
				err = rev.Unmarshal(chain[0].ExtraExtensions[0].Value)
				assert.NoError(t, err)
				rev.Timestamp = rev.Timestamp + 300
				err = rev.Sign(keys[0])
				assert.NoError(t, err)
				revBytes, err := rev.Marshal()
				assert.NoError(t, err)
				err = revDB.Put(chain, pkix.Extension{
					Id:    ExtensionIDs[RevocationExtID],
					Value: revBytes,
				})
				assert.NoError(t, err)
				return chain
			}(),
			nil,
			&ErrExtension,
			ErrRevocationTimestamp,
		},
		{
			"certificate revocation and leaf whitelist signature",
			TLSExtConfig{Revocation: true, WhitelistSignedLeaf: true},
			2,
			func() []*x509.Certificate {
				_, chain, err := newRevokedLeafChain()
				assert.NoError(t, err)
				err = AddSignedCertExt(whitelistSignedKeys[0], chain[0])
				assert.NoError(t, err)
				return chain
			}(),
			[]*x509.Certificate{whitelistSignedChain[2]},
			nil,
			nil,
		},
	}

	for _, c := range cases {
		t.Run(c.testID, func(t *testing.T) {
			opts := ParseExtOptions{
				CAWhitelist: c.whitelist,
				RevDB:       revDB,
			}

			handlers := ParseExtensions(c.config, opts)
			assert.Equal(t, c.extLen, len(handlers))

			err := handlers.VerifyFunc()(nil, [][]*x509.Certificate{c.certChain})
			if c.errClass != nil {
				assert.True(t, c.errClass.Has(err))
			}
			if c.err != nil {
				assert.NotNil(t, err)
			}
			if c.errClass == nil && c.err == nil {
				assert.NoError(t, err)
			}
		})
	}
}
// newCertChain generates a certificate chain of the given length together with
// the private keys of its certificates.
//
// certs is ordered leaf first: certs[0] is the leaf and certs[length-1] is the
// self-signed root. Each other certificate is signed by the key of the
// certificate that follows it in certs.
//
// NB: keys are in the reverse order compared to certs (i.e. first key belongs
// to last cert)! In general keys[i] belongs to certs[length-1-i].
//
// The root and any intermediates are built from CATemplate and the leaf from
// LeafTemplate. A chain of length 1 is a single self-signed CA certificate.
func newCertChain(length int) (keys []crypto.PrivateKey, certs []*x509.Certificate, _ error) {
	for i := 0; i < length; i++ {
		key, err := NewKey()
		if err != nil {
			return nil, nil, err
		}
		keys = append(keys, key)

		// Only the last certificate generated (the leaf) of a multi-cert
		// chain uses LeafTemplate; every certificate that signs another one
		// uses CATemplate.
		var template *x509.Certificate
		if i == length-1 && length > 1 {
			template, err = LeafTemplate()
		} else {
			template, err = CATemplate()
		}
		if err != nil {
			return nil, nil, err
		}

		var cert *x509.Certificate
		if i == 0 {
			cert, err = NewCert(key, nil, template, nil)
		} else {
			// certs is built by prepending, so certs[0] is the certificate
			// generated in the previous iteration, the one keys[i-1] belongs
			// to. The parent certificate must match the signing key.
			cert, err = NewCert(key, keys[i-1], template, certs[0])
		}
		if err != nil {
			return nil, nil, err
		}

		certs = append([]*x509.Certificate{cert}, certs...)
	}
	return keys, certs, nil
}
func revokeLeaf(keys []crypto.PrivateKey, chain []*x509.Certificate) ([]crypto.PrivateKey, []*x509.Certificate, error) {
revokingKey, err := NewKey()
if err != nil {
return nil, nil, err
}
revokingTemplate, err := LeafTemplate()
if err != nil {
return nil, nil, err
}
revokingCert, err := NewCert(revokingKey, keys[0], revokingTemplate, chain[1])
if err != nil {
return nil, nil, err
}
err = AddRevocationExt(keys[0], chain[0], revokingCert)
if err != nil {
return nil, nil, err
}
return keys, append([]*x509.Certificate{revokingCert}, chain[1:]...), nil
}
// newRevokedLeafChain builds a fresh two-certificate chain with newCertChain
// and hands it to revokeLeaf. revokeLeaf swaps in a new leaf carrying a
// revocation of the original one.
func newRevokedLeafChain() ([]crypto.PrivateKey, []*x509.Certificate, error) {
	keys, certs, err := newCertChain(2)
	if err == nil {
		return revokeLeaf(keys, certs)
	}
	return nil, nil, err
}