138 lines
3.0 KiB
Go
138 lines
3.0 KiB
Go
|
// Copyright (C) 2022 Storj Labs, Inc.
// See LICENSE for copying information.
|
||
|
|
||
|
package oidc_test
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/stretchr/testify/require"
|
||
|
|
||
|
"storj.io/common/testcontext"
|
||
|
"storj.io/common/uuid"
|
||
|
"storj.io/storj/satellite"
|
||
|
"storj.io/storj/satellite/oidc"
|
||
|
"storj.io/storj/satellite/satellitedb/satellitedbtest"
|
||
|
)
|
||
|
|
||
|
func TestOAuthCodes(t *testing.T) {
|
||
|
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
|
||
|
clientID, err := uuid.New()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
userID, err := uuid.New()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// repositories
|
||
|
codes := db.OIDC().OAuthCodes()
|
||
|
|
||
|
start := time.Now()
|
||
|
|
||
|
allCodes := []oidc.OAuthCode{
|
||
|
{
|
||
|
ClientID: clientID,
|
||
|
UserID: userID,
|
||
|
Code: "expired",
|
||
|
CreatedAt: start.Add(-2 * time.Hour),
|
||
|
ExpiresAt: start.Add(-1 * time.Hour),
|
||
|
},
|
||
|
{
|
||
|
ClientID: clientID,
|
||
|
UserID: userID,
|
||
|
Code: "valid",
|
||
|
CreatedAt: start,
|
||
|
ExpiresAt: start.Add(time.Hour),
|
||
|
},
|
||
|
{
|
||
|
ClientID: clientID,
|
||
|
UserID: userID,
|
||
|
Code: "claimed",
|
||
|
CreatedAt: start,
|
||
|
ExpiresAt: start.Add(time.Hour),
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, code := range allCodes {
|
||
|
err = codes.Create(ctx, code)
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
// claim this code ahead of time to test the token already claimed code path later on
|
||
|
err = codes.Claim(ctx, "claimed")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
testCases := []struct {
|
||
|
code string
|
||
|
err error
|
||
|
}{
|
||
|
{"expired", sql.ErrNoRows},
|
||
|
{"valid", nil},
|
||
|
{"claimed", sql.ErrNoRows}, // this should return an error since it was claimed above
|
||
|
}
|
||
|
|
||
|
for _, testCase := range testCases {
|
||
|
_, err := codes.Get(ctx, testCase.code)
|
||
|
require.Equal(t, testCase.err, err)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestOAuthTokens(t *testing.T) {
|
||
|
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
|
||
|
clientID, err := uuid.New()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
userID, err := uuid.New()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// repositories
|
||
|
tokens := db.OIDC().OAuthTokens()
|
||
|
|
||
|
start := time.Now()
|
||
|
|
||
|
allTokens := []oidc.OAuthToken{
|
||
|
{
|
||
|
ClientID: clientID,
|
||
|
UserID: userID,
|
||
|
Kind: oidc.KindAccessToken,
|
||
|
Token: "expired",
|
||
|
CreatedAt: start.Add(-2 * time.Hour),
|
||
|
ExpiresAt: start.Add(-1 * time.Hour),
|
||
|
},
|
||
|
{
|
||
|
ClientID: clientID,
|
||
|
UserID: userID,
|
||
|
Kind: oidc.KindRefreshToken,
|
||
|
Token: "valid",
|
||
|
CreatedAt: start,
|
||
|
ExpiresAt: start.Add(time.Hour),
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, token := range allTokens {
|
||
|
err = tokens.Create(ctx, token)
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
// ensure that creating an existing token doesn't cause an error
|
||
|
err = tokens.Create(ctx, allTokens[1])
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
testCases := []struct {
|
||
|
kind oidc.OAuthTokenKind
|
||
|
token string
|
||
|
err error
|
||
|
}{
|
||
|
{oidc.KindAccessToken, "expired", sql.ErrNoRows},
|
||
|
{oidc.KindRefreshToken, "valid", nil},
|
||
|
}
|
||
|
|
||
|
for _, testCase := range testCases {
|
||
|
_, err := tokens.Get(ctx, testCase.kind, testCase.token)
|
||
|
require.Equal(t, testCase.err, err)
|
||
|
}
|
||
|
})
|
||
|
}
|