storj/satellite/oidc/dbx_test.go

154 lines
3.5 KiB
Go
Raw Normal View History

// Copyright (C) 2022 Storj Labs, Inc.
// See LICENSE for copying information.
package oidc_test
import (
"database/sql"
"testing"
"time"
"github.com/stretchr/testify/require"
"storj.io/common/testcontext"
"storj.io/common/uuid"
"storj.io/storj/satellite"
"storj.io/storj/satellite/oidc"
"storj.io/storj/satellite/satellitedb/satellitedbtest"
)
// TestOAuthCodes exercises the OAuthCodes repository inside satellitedbtest.Run.
//
// It stores three codes (one whose expiry is already in the past, one valid,
// and one that is claimed right after creation) and then checks that Get only
// succeeds for the valid code; the expired and the claimed code are expected
// to yield sql.ErrNoRows.
func TestOAuthCodes(t *testing.T) {
	satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
		clientID, err := uuid.New()
		require.NoError(t, err)

		userID, err := uuid.New()
		require.NoError(t, err)

		// repositories
		codes := db.OIDC().OAuthCodes()

		start := time.Now()

		allCodes := []oidc.OAuthCode{
			{
				ClientID:  clientID,
				UserID:    userID,
				Code:      "expired",
				CreatedAt: start.Add(-2 * time.Hour),
				ExpiresAt: start.Add(-1 * time.Hour),
			},
			{
				ClientID:  clientID,
				UserID:    userID,
				Code:      "valid",
				CreatedAt: start,
				ExpiresAt: start.Add(time.Hour),
			},
			{
				ClientID:  clientID,
				UserID:    userID,
				Code:      "claimed",
				CreatedAt: start,
				ExpiresAt: start.Add(time.Hour),
			},
		}

		for _, code := range allCodes {
			err = codes.Create(ctx, code)
			require.NoError(t, err, "creating code %q", code.Code)
		}

		// claim this code ahead of time to test the token already claimed code path later on
		err = codes.Claim(ctx, "claimed")
		require.NoError(t, err)

		testCases := []struct {
			code string
			err  error
		}{
			{"expired", sql.ErrNoRows},
			{"valid", nil},
			{"claimed", sql.ErrNoRows}, // this should return an error since it was claimed above
		}

		for _, testCase := range testCases {
			_, err := codes.Get(ctx, testCase.code)

			if testCase.err == nil {
				require.NoError(t, err, "getting code %q", testCase.code)
				continue
			}

			// Match with errors.Is semantics so the check keeps working even if
			// the repository starts wrapping the sentinel error.
			require.ErrorIs(t, err, testCase.err, "getting code %q", testCase.code)
		}
	})
}
// TestOAuthTokens exercises the OAuthTokens repository inside satellitedbtest.Run.
//
// It stores an access token whose expiry is already in the past, a valid
// refresh token and a valid REST token, creates the refresh token a second
// time to confirm that a duplicate Create does not error, and then checks Get
// for each of them. For the REST token it additionally calls
// RevokeRESTTokenV0 and expects a subsequent Get to return sql.ErrNoRows
// together with a token value whose ExpiresAt is the zero time.
func TestOAuthTokens(t *testing.T) {
	satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
		clientID, err := uuid.New()
		require.NoError(t, err)

		userID, err := uuid.New()
		require.NoError(t, err)

		// repositories
		tokens := db.OIDC().OAuthTokens()

		start := time.Now()

		allTokens := []oidc.OAuthToken{
			{
				ClientID:  clientID,
				UserID:    userID,
				Kind:      oidc.KindAccessToken,
				Token:     "expired",
				CreatedAt: start.Add(-2 * time.Hour),
				ExpiresAt: start.Add(-1 * time.Hour),
			},
			{
				ClientID:  clientID,
				UserID:    userID,
				Kind:      oidc.KindRefreshToken,
				Token:     "valid",
				CreatedAt: start,
				ExpiresAt: start.Add(time.Hour),
			},
			{
				ClientID:  clientID,
				UserID:    userID,
				Kind:      oidc.KindRESTTokenV0,
				Token:     "testToken",
				CreatedAt: start,
				ExpiresAt: start.Add(time.Hour),
			},
		}

		for _, token := range allTokens {
			err = tokens.Create(ctx, token)
			require.NoError(t, err, "creating token %q", token.Token)
		}

		// ensure that creating an existing token doesn't cause an error
		err = tokens.Create(ctx, allTokens[1])
		require.NoError(t, err)

		testCases := []struct {
			kind  oidc.OAuthTokenKind
			token string
			err   error
		}{
			{oidc.KindAccessToken, "expired", sql.ErrNoRows},
			{oidc.KindRefreshToken, "valid", nil},
			{oidc.KindRESTTokenV0, "testToken", nil},
		}

		for _, testCase := range testCases {
			_, err := tokens.Get(ctx, testCase.kind, testCase.token)

			// Match with errors.Is semantics so the check keeps working even if
			// the repository starts wrapping the sentinel error.
			if testCase.err == nil {
				require.NoError(t, err, "getting token %q", testCase.token)
			} else {
				require.ErrorIs(t, err, testCase.err, "getting token %q", testCase.token)
			}

			if testCase.kind == oidc.KindRESTTokenV0 {
				err = tokens.RevokeRESTTokenV0(ctx, testCase.token)
				require.NoError(t, err, "revoking token %q", testCase.token)

				// once revoked, the token must no longer be retrievable
				token, err := tokens.Get(ctx, testCase.kind, testCase.token)
				require.ErrorIs(t, err, sql.ErrNoRows, "getting revoked token %q", testCase.token)
				require.True(t, token.ExpiresAt.IsZero())
			}
		}
	})
}