storj/pkg/certificates/certificates_test.go
2018-12-31 10:45:43 -05:00

378 lines
7.5 KiB
Go

// Copyright (C) 2018 Storj Labs, Inc.
// See LICENSE for copying information.
package certificates
import (
"bytes"
"encoding/gob"
"fmt"
"path/filepath"
"testing"
"time"
"github.com/btcsuite/btcutil/base58"
"github.com/stretchr/testify/assert"
"github.com/zeebo/errs"
"storj.io/storj/internal/testcontext"
"storj.io/storj/pkg/utils"
"storj.io/storj/storage"
)
var (
t1 = Token{
UserID: "user@example.com",
Data: [tokenDataLength]byte{1, 2, 3},
}
t2 = Token{
UserID: "user2@example.com",
Data: [tokenDataLength]byte{4, 5, 6},
}
)
func TestCertSignerConfig_NewAuthDB(t *testing.T) {
ctx := testcontext.New(t)
authDB, err := newTestAuthDB(ctx)
if !assert.NoError(t, err) {
t.Fatal(err)
}
defer func() {
_ = authDB.Close()
ctx.Cleanup()
}()
assert.NotNil(t, authDB)
assert.NotNil(t, authDB.DB)
}
func TestAuthorizationDB_Create(t *testing.T) {
ctx := testcontext.New(t)
authDB, err := newTestAuthDB(ctx)
if !assert.NoError(t, err) {
t.Fatal(err)
}
defer func() {
_ = authDB.Close()
ctx.Cleanup()
}()
cases := []struct {
testID,
email string
startCount,
incCount,
newCount,
endCount int
errClass *errs.Class
err error
}{
{
"first authorization",
"user1@example.com",
0, 1, 1, 1,
nil, nil,
},
{
"second authorization",
"user1@example.com",
1, 2, 2, 3,
nil, nil,
},
{
"large authorization",
"user2@example.com",
0, 5, 5, 5,
nil, nil,
},
{
"authorization error",
"user2@example.com",
5, -1, 0, 5,
&ErrAuthorizationDB, ErrAuthorizationCount,
},
}
for _, c := range cases {
t.Run(c.testID, func(t *testing.T) {
emailKey := storage.Key(c.email)
if c.startCount == 0 {
_, err = authDB.DB.Get(emailKey)
assert.Error(t, err)
} else {
v, err := authDB.DB.Get(emailKey)
assert.NoError(t, err)
assert.NotEmpty(t, v)
var existingAuths Authorizations
err = existingAuths.Unmarshal(v)
assert.NoError(t, err)
if !assert.Len(t, existingAuths, c.startCount) {
t.FailNow()
}
}
expectedAuths, err := authDB.Create(c.email, c.incCount)
if c.errClass != nil {
assert.True(t, c.errClass.Has(err))
}
if c.err != nil {
assert.Equal(t, c.err, err)
}
if c.errClass == nil && c.err == nil {
assert.NoError(t, err)
}
assert.Len(t, expectedAuths, c.newCount)
v, err := authDB.DB.Get(emailKey)
assert.NoError(t, err)
assert.NotEmpty(t, v)
var actualAuths Authorizations
err = actualAuths.Unmarshal(v)
assert.NoError(t, err)
assert.Len(t, actualAuths, c.endCount)
})
}
}
func TestAuthorizationDB_Get(t *testing.T) {
ctx := testcontext.New(t)
authDB, err := newTestAuthDB(ctx)
if !assert.NoError(t, err) {
t.Fatal(err)
}
defer func() {
_ = authDB.Close()
ctx.Cleanup()
}()
var expectedAuths Authorizations
for i := 0; i < 5; i++ {
expectedAuths = append(expectedAuths, &Authorization{
Token: t1,
})
}
authsBytes, err := expectedAuths.Marshal()
if !assert.NoError(t, err) {
t.Fatal(err)
}
err = authDB.DB.Put(storage.Key("user@example.com"), authsBytes)
if !assert.NoError(t, err) {
t.Fatal(err)
}
cases := []struct {
testID,
email string
result Authorizations
}{
{
"Non-existant email",
"nouser@example.com",
nil,
},
{
"Exising email",
"user@example.com",
expectedAuths,
},
}
for _, c := range cases {
t.Run(c.testID, func(t *testing.T) {
auths, err := authDB.Get(c.email)
assert.NoError(t, err)
if c.result != nil {
assert.NotEmpty(t, auths)
assert.Len(t, auths, len(c.result))
} else {
assert.Empty(t, auths)
}
})
}
}
func TestNewAuthorization(t *testing.T) {
userID := "user@example.com"
auth, err := NewAuthorization(userID)
assert.NoError(t, err)
if !assert.NotNil(t, auth) {
t.FailNow()
}
assert.NotZero(t, auth.Token)
assert.Equal(t, userID, auth.Token.UserID)
assert.NotEmpty(t, auth.Token.Data)
}
func TestAuthorizations_Marshal(t *testing.T) {
expectedAuths := Authorizations{
{Token: t1},
{Token: t2},
}
authsBytes, err := expectedAuths.Marshal()
assert.NoError(t, err)
assert.NotEmpty(t, authsBytes)
var actualAuths Authorizations
decoder := gob.NewDecoder(bytes.NewBuffer(authsBytes))
err = decoder.Decode(&actualAuths)
assert.NoError(t, err)
assert.NotNil(t, actualAuths)
assert.Equal(t, expectedAuths, actualAuths)
}
func TestAuthorizations_Unmarshal(t *testing.T) {
expectedAuths := Authorizations{
{Token: t1},
{Token: t2},
}
authsBytes, err := expectedAuths.Marshal()
assert.NoError(t, err)
assert.NotEmpty(t, authsBytes)
var actualAuths Authorizations
err = actualAuths.Unmarshal(authsBytes)
assert.NoError(t, err)
assert.NotNil(t, actualAuths)
assert.Equal(t, expectedAuths, actualAuths)
}
func TestAuthorizations_Group(t *testing.T) {
auths := make(Authorizations, 10)
for i := 0; i < 10; i++ {
if i%2 == 0 {
auths[i] = &Authorization{
Token: t1,
Claim: &Claim{
Timestamp: time.Now().Unix(),
},
}
} else {
auths[i] = &Authorization{
Token: t2,
}
}
}
claimed, open := auths.Group()
for _, a := range claimed {
assert.NotNil(t, a.Claim)
}
for _, a := range open {
assert.Nil(t, a.Claim)
}
}
func TestAuthorizationDB_Emails(t *testing.T) {
ctx := testcontext.New(t)
authDB, err := newTestAuthDB(ctx)
if !assert.NoError(t, err) {
t.Fatal(err)
}
defer func() {
_ = authDB.Close()
ctx.Cleanup()
}()
var authErrs utils.ErrorGroup
for i := 0; i < 5; i++ {
_, err := authDB.Create(fmt.Sprintf("user%d@example.com", i), 1)
if err != nil {
authErrs.Add(err)
}
}
err = authErrs.Finish()
if !assert.NoError(t, err) {
t.Fatal(err)
}
userIDs, err := authDB.UserIDs()
assert.NoError(t, err)
assert.NotEmpty(t, userIDs)
}
func TestParseToken_Valid(t *testing.T) {
userID := "user@example.com"
data := [tokenDataLength]byte{1, 2, 3}
cases := []struct {
testID string
userID string
}{
{
"valid token",
userID,
},
{
"multiple delimiters",
"us" + tokenDelimiter + "er@example.com",
},
}
for _, c := range cases {
t.Run(c.testID, func(t *testing.T) {
b58Data := base58.CheckEncode(data[:], tokenVersion)
tokenString := c.userID + tokenDelimiter + b58Data
token, err := ParseToken(tokenString)
assert.NoError(t, err)
if !assert.NotNil(t, token) {
t.FailNow()
}
assert.Equal(t, c.userID, token.UserID)
assert.Equal(t, data[:], token.Data[:])
})
}
}
func TestParseToken_Invalid(t *testing.T) {
userID := "user@example.com"
data := [tokenDataLength]byte{1, 2, 3}
cases := []struct {
testID string
tokenString string
}{
{
"no delimiter",
userID + base58.CheckEncode(data[:], tokenVersion),
},
{
"missing userID",
tokenDelimiter + base58.CheckEncode(data[:], tokenVersion),
},
{
"not enough data",
userID + tokenDelimiter + base58.CheckEncode(data[:len(data)-10], tokenVersion),
},
{
"too much data",
userID + tokenDelimiter + base58.CheckEncode(append(data[:], []byte{0, 0, 0}...), tokenVersion),
},
{
"data checksum/format error",
userID + tokenDelimiter + base58.CheckEncode(data[:], tokenVersion)[:len(base58.CheckEncode(data[:], tokenVersion))-4] + "0000",
},
}
for _, c := range cases {
t.Run(c.testID, func(t *testing.T) {
token, err := ParseToken(c.tokenString)
assert.Nil(t, token)
assert.True(t, ErrInvalidToken.Has(err))
})
}
}
func newTestAuthDB(ctx *testcontext.Context) (*AuthorizationDB, error) {
dbPath := "bolt://" + filepath.Join(ctx.Dir(), "authorizations.db")
config := CertSignerConfig{
AuthorizationDBURL: dbPath,
}
return config.NewAuthDB()
}