satellite/oidc: move oidc into common package
Change-Id: I77702e0e46f15a09fee315b9076638e1412836f7
This commit is contained in:
parent
95921b8b39
commit
0164682c37
@ -10,11 +10,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"storj.io/common/uuid"
|
||||
"storj.io/storj/satellite/console"
|
||||
"storj.io/storj/satellite/oidc"
|
||||
)
|
||||
|
||||
func (server *Server) createOAuthClient(w http.ResponseWriter, r *http.Request) {
|
||||
oauthClient := console.OAuthClient{}
|
||||
oauthClient := oidc.OAuthClient{}
|
||||
err := json.NewDecoder(r.Body).Decode(&oauthClient)
|
||||
if err != nil {
|
||||
sendJSONError(w, "invalid json", err.Error(), http.StatusBadRequest)
|
||||
@ -31,7 +31,7 @@ func (server *Server) createOAuthClient(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
err = server.db.Console().OAuthClients().Create(r.Context(), oauthClient)
|
||||
err = server.db.OIDC().OAuthClients().Create(r.Context(), oauthClient)
|
||||
if err != nil {
|
||||
sendJSONError(w, "failed to create client", err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@ -47,7 +47,7 @@ func (server *Server) updateOAuthClient(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
oauthClient := console.OAuthClient{}
|
||||
oauthClient := oidc.OAuthClient{}
|
||||
err = json.NewDecoder(r.Body).Decode(&oauthClient)
|
||||
if err != nil {
|
||||
sendJSONError(w, "invalid json", err.Error(), http.StatusBadRequest)
|
||||
@ -56,7 +56,7 @@ func (server *Server) updateOAuthClient(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
oauthClient.ID = id
|
||||
|
||||
err = server.db.Console().OAuthClients().Update(r.Context(), oauthClient)
|
||||
err = server.db.OIDC().OAuthClients().Update(r.Context(), oauthClient)
|
||||
if err != nil {
|
||||
sendJSONError(w, "failed to update client", err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@ -72,7 +72,7 @@ func (server *Server) deleteOAuthClient(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
err = server.db.Console().OAuthClients().Delete(r.Context(), id)
|
||||
err = server.db.OIDC().OAuthClients().Delete(r.Context(), id)
|
||||
if err != nil {
|
||||
sendJSONError(w, "failed to delete client", err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
@ -16,7 +16,7 @@ import (
|
||||
"storj.io/common/uuid"
|
||||
"storj.io/storj/private/testplanet"
|
||||
"storj.io/storj/satellite"
|
||||
"storj.io/storj/satellite/console"
|
||||
"storj.io/storj/satellite/oidc"
|
||||
)
|
||||
|
||||
func TestAdminOAuthAPI(t *testing.T) {
|
||||
@ -41,8 +41,8 @@ func TestAdminOAuthAPI(t *testing.T) {
|
||||
address := sat.Admin.Admin.Listener.Addr()
|
||||
|
||||
baseURL := fmt.Sprintf("http://%s/api/oauth/clients", address)
|
||||
empty := console.OAuthClient{}
|
||||
client := console.OAuthClient{ID: id, Secret: []byte("badadmin"), UserID: userID, RedirectURL: "http://localhost:1234"}
|
||||
empty := oidc.OAuthClient{}
|
||||
client := oidc.OAuthClient{ID: id, Secret: []byte("badadmin"), UserID: userID, RedirectURL: "http://localhost:1234"}
|
||||
updated := client
|
||||
updated.RedirectURL = "http://localhost:1235"
|
||||
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"storj.io/storj/satellite/accounting"
|
||||
"storj.io/storj/satellite/buckets"
|
||||
"storj.io/storj/satellite/console"
|
||||
"storj.io/storj/satellite/oidc"
|
||||
"storj.io/storj/satellite/payments"
|
||||
"storj.io/storj/satellite/payments/stripecoinpayments"
|
||||
)
|
||||
@ -45,6 +46,8 @@ type DB interface {
|
||||
ProjectAccounting() accounting.ProjectAccounting
|
||||
// Console returns database for satellite console
|
||||
Console() console.DB
|
||||
// OIDC returns the database for OIDC and OAuth information.
|
||||
OIDC() oidc.DB
|
||||
// StripeCoinPayments returns database for satellite stripe coin payments
|
||||
StripeCoinPayments() stripecoinpayments.DB
|
||||
}
|
||||
|
@ -19,8 +19,6 @@ type DB interface {
|
||||
ProjectMembers() ProjectMembers
|
||||
// APIKeys is a getter for APIKeys repository.
|
||||
APIKeys() APIKeys
|
||||
// OAuthClients returns an API for the OAuthClients repository.
|
||||
OAuthClients() OAuthClients
|
||||
// RegistrationTokens is a getter for RegistrationTokens repository.
|
||||
RegistrationTokens() RegistrationTokens
|
||||
// ResetPasswordTokens is a getter for ResetPasswordTokens repository.
|
||||
|
@ -1,36 +0,0 @@
|
||||
// Copyright (C) 2022 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package console
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"storj.io/common/uuid"
|
||||
)
|
||||
|
||||
// OAuthClients defines an interface for creating, updating, and obtaining information about oauth clients known to our
|
||||
// system.
|
||||
type OAuthClients interface {
|
||||
// Get returns the OAuthClient associated with the provided id.
|
||||
Get(ctx context.Context, id uuid.UUID) (OAuthClient, error)
|
||||
|
||||
// Create creates a new OAuthClient.
|
||||
Create(ctx context.Context, client OAuthClient) error
|
||||
|
||||
// Update modifies information for the provided OAuthClient.
|
||||
Update(ctx context.Context, client OAuthClient) error
|
||||
|
||||
// Delete deletes the identified client from the database.
|
||||
Delete(ctx context.Context, id uuid.UUID) error
|
||||
}
|
||||
|
||||
// OAuthClient defines a concrete representation of an oauth client.
|
||||
type OAuthClient struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Secret []byte `json:"secret"`
|
||||
UserID uuid.UUID `json:"userID"`
|
||||
RedirectURL string `json:"redirectURL"`
|
||||
AppName string `json:"appName"`
|
||||
AppLogoURL string `json:"appLogoURL"`
|
||||
}
|
166
satellite/oidc/database.go
Normal file
166
satellite/oidc/database.go
Normal file
@ -0,0 +1,166 @@
|
||||
// Copyright (C) 2022 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"storj.io/common/uuid"
|
||||
"storj.io/storj/satellite/satellitedb/dbx"
|
||||
)
|
||||
|
||||
// DB defines a collection of resources that fall under the scope of OIDC and OAuth operations.
|
||||
//
|
||||
// architecture: Database
|
||||
type DB interface {
|
||||
// OAuthClients returns an API for the oauthclients repository.
|
||||
OAuthClients() OAuthClients
|
||||
// OAuthCodes returns an API for the oauthcodes repository.
|
||||
OAuthCodes() OAuthCodes
|
||||
// OAuthTokens returns an API for the oauthtokens repository.
|
||||
OAuthTokens() OAuthTokens
|
||||
}
|
||||
|
||||
// OAuthClients defines an interface for creating, updating, and obtaining information about oauth clients known to our
|
||||
// system.
|
||||
type OAuthClients interface {
|
||||
// Get returns the OAuthClient associated with the provided id.
|
||||
Get(ctx context.Context, id uuid.UUID) (OAuthClient, error)
|
||||
|
||||
// Create creates a new OAuthClient.
|
||||
Create(ctx context.Context, client OAuthClient) error
|
||||
|
||||
// Update modifies information for the provided OAuthClient.
|
||||
Update(ctx context.Context, client OAuthClient) error
|
||||
|
||||
// Delete deletes the identified client from the database.
|
||||
Delete(ctx context.Context, id uuid.UUID) error
|
||||
}
|
||||
|
||||
// OAuthClient defines a concrete representation of an oauth client.
|
||||
type OAuthClient struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Secret []byte `json:"secret"`
|
||||
UserID uuid.UUID `json:"userID"`
|
||||
RedirectURL string `json:"redirectURL"`
|
||||
AppName string `json:"appName"`
|
||||
AppLogoURL string `json:"appLogoURL"`
|
||||
}
|
||||
|
||||
// GetID returns the clients id.
|
||||
func (o OAuthClient) GetID() string {
|
||||
return o.ID.String()
|
||||
}
|
||||
|
||||
// GetSecret returns the clients secret.
|
||||
func (o OAuthClient) GetSecret() string {
|
||||
return string(o.Secret)
|
||||
}
|
||||
|
||||
// GetDomain returns the allowed redirect url associated with the client.
|
||||
func (o OAuthClient) GetDomain() string {
|
||||
return o.RedirectURL
|
||||
}
|
||||
|
||||
// GetUserID returns the owners' user id.
|
||||
func (o OAuthClient) GetUserID() string {
|
||||
return o.UserID.String()
|
||||
}
|
||||
|
||||
// OAuthCodes defines a set of operations allowed to be performed against oauth codes.
|
||||
type OAuthCodes interface {
|
||||
// Get retrieves the OAuthCode for the specified code. Implementations should only return unexpired, unclaimed
|
||||
// codes. Once a code has been claimed, it should be marked as such to prevent future calls from exchanging the
|
||||
// value for an access tokens.
|
||||
Get(ctx context.Context, code string) (OAuthCode, error)
|
||||
|
||||
// Create creates a new OAuthCode.
|
||||
Create(ctx context.Context, code OAuthCode) error
|
||||
|
||||
// Claim marks that the provided code has been claimed and should not be issued to another caller.
|
||||
Claim(ctx context.Context, code string) error
|
||||
}
|
||||
|
||||
// OAuthTokens defines a set of operations that ca be performed against oauth tokens.
|
||||
type OAuthTokens interface {
|
||||
// Get retrieves the OAuthToken for the specified kind and token value. This can be used to look up either refresh
|
||||
// or access tokens that have not expired.
|
||||
Get(ctx context.Context, kind OAuthTokenKind, token string) (OAuthToken, error)
|
||||
|
||||
// Create creates a new OAuthToken. If the token already exists, no value is modified and nil is returned.
|
||||
Create(ctx context.Context, token OAuthToken) error
|
||||
}
|
||||
|
||||
// OAuthTokenKind defines an enumeration of different types of supported tokens.
|
||||
type OAuthTokenKind int8
|
||||
|
||||
const (
|
||||
// KindUnknown is used to represent an entry for which we do not recognize the value.
|
||||
KindUnknown = 0
|
||||
// KindAccessToken represents an access token within the database.
|
||||
KindAccessToken = 1
|
||||
// KindRefreshToken represents a refresh token within the database.
|
||||
KindRefreshToken = 2
|
||||
)
|
||||
|
||||
// OAuthCode represents a code stored within our database.
|
||||
type OAuthCode struct {
|
||||
ClientID uuid.UUID
|
||||
UserID uuid.UUID
|
||||
Scope string
|
||||
RedirectURL string
|
||||
Challenge string
|
||||
ChallengeMethod string
|
||||
Code string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
ClaimedAt *time.Time
|
||||
}
|
||||
|
||||
// OAuthToken represents a token stored within our database (either access / refresh).
|
||||
type OAuthToken struct {
|
||||
ClientID uuid.UUID
|
||||
UserID uuid.UUID
|
||||
Scope string
|
||||
Kind OAuthTokenKind
|
||||
Token string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// NewDB constructs a database using the provided dbx db.
|
||||
func NewDB(dbxdb *dbx.DB) DB {
|
||||
return &db{
|
||||
clients: &clientsDBX{
|
||||
db: dbxdb,
|
||||
},
|
||||
codes: &codesDBX{
|
||||
db: dbxdb,
|
||||
},
|
||||
tokens: &tokensDBX{
|
||||
db: dbxdb,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type db struct {
|
||||
clients OAuthClients
|
||||
codes OAuthCodes
|
||||
tokens OAuthTokens
|
||||
}
|
||||
|
||||
func (d *db) OAuthClients() OAuthClients {
|
||||
return d.clients
|
||||
}
|
||||
|
||||
func (d *db) OAuthCodes() OAuthCodes {
|
||||
return d.codes
|
||||
}
|
||||
|
||||
func (d *db) OAuthTokens() OAuthTokens {
|
||||
return d.tokens
|
||||
}
|
||||
|
||||
var _ DB = &db{}
|
182
satellite/oidc/dbx.go
Normal file
182
satellite/oidc/dbx.go
Normal file
@ -0,0 +1,182 @@
|
||||
// Copyright (C) 2022 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"storj.io/common/uuid"
|
||||
"storj.io/storj/satellite/satellitedb/dbx"
|
||||
)
|
||||
|
||||
type clientsDBX struct {
|
||||
db *dbx.DB
|
||||
}
|
||||
|
||||
// Get returns the OAuthClient associated with the provided id.
|
||||
func (clients *clientsDBX) Get(ctx context.Context, id uuid.UUID) (OAuthClient, error) {
|
||||
oauthClient, err := clients.db.Get_OauthClient_By_Id(ctx, dbx.OauthClient_Id(id.Bytes()))
|
||||
if err != nil {
|
||||
return OAuthClient{}, err
|
||||
}
|
||||
|
||||
userID, err := uuid.FromBytes(oauthClient.UserId)
|
||||
if err != nil {
|
||||
return OAuthClient{}, err
|
||||
}
|
||||
|
||||
client := OAuthClient{
|
||||
ID: id,
|
||||
Secret: oauthClient.EncryptedSecret,
|
||||
UserID: userID,
|
||||
RedirectURL: oauthClient.RedirectUrl,
|
||||
AppName: oauthClient.AppName,
|
||||
AppLogoURL: oauthClient.AppLogoUrl,
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Create creates a new OAuthClient.
|
||||
func (clients *clientsDBX) Create(ctx context.Context, client OAuthClient) error {
|
||||
_, err := clients.db.Create_OauthClient(ctx,
|
||||
dbx.OauthClient_Id(client.ID.Bytes()), dbx.OauthClient_EncryptedSecret(client.Secret),
|
||||
dbx.OauthClient_RedirectUrl(client.RedirectURL), dbx.OauthClient_UserId(client.UserID.Bytes()),
|
||||
dbx.OauthClient_AppName(client.AppName), dbx.OauthClient_AppLogoUrl(client.AppLogoURL))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Update modifies information for the provided OAuthClient.
|
||||
func (clients *clientsDBX) Update(ctx context.Context, client OAuthClient) error {
|
||||
if client.RedirectURL == "" && client.Secret == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
update := dbx.OauthClient_Update_Fields{}
|
||||
|
||||
if client.RedirectURL != "" {
|
||||
update.RedirectUrl = dbx.OauthClient_RedirectUrl(client.RedirectURL)
|
||||
}
|
||||
|
||||
if client.Secret != nil {
|
||||
update.EncryptedSecret = dbx.OauthClient_EncryptedSecret(client.Secret)
|
||||
}
|
||||
|
||||
err := clients.db.UpdateNoReturn_OauthClient_By_Id(ctx, dbx.OauthClient_Id(client.ID.Bytes()), update)
|
||||
return err
|
||||
}
|
||||
|
||||
func (clients *clientsDBX) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := clients.db.Delete_OauthClient_By_Id(ctx, dbx.OauthClient_Id(id.Bytes()))
|
||||
return err
|
||||
}
|
||||
|
||||
type codesDBX struct {
|
||||
db *dbx.DB
|
||||
}
|
||||
|
||||
func (o *codesDBX) Get(ctx context.Context, code string) (oauthCode OAuthCode, err error) {
|
||||
dbCode, err := o.db.Get_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx, dbx.OauthCode_Code(code))
|
||||
if err != nil {
|
||||
return oauthCode, err
|
||||
}
|
||||
|
||||
clientID, err := uuid.FromBytes(dbCode.ClientId)
|
||||
if err != nil {
|
||||
return oauthCode, err
|
||||
}
|
||||
|
||||
userID, err := uuid.FromBytes(dbCode.UserId)
|
||||
if err != nil {
|
||||
return oauthCode, err
|
||||
}
|
||||
|
||||
if time.Now().After(dbCode.ExpiresAt) {
|
||||
return oauthCode, sql.ErrNoRows
|
||||
}
|
||||
|
||||
oauthCode.ClientID = clientID
|
||||
oauthCode.UserID = userID
|
||||
oauthCode.Scope = dbCode.Scope
|
||||
oauthCode.RedirectURL = dbCode.RedirectUrl
|
||||
oauthCode.Challenge = dbCode.Challenge
|
||||
oauthCode.ChallengeMethod = dbCode.ChallengeMethod
|
||||
oauthCode.Code = dbCode.Code
|
||||
oauthCode.CreatedAt = dbCode.CreatedAt
|
||||
oauthCode.ExpiresAt = dbCode.ExpiresAt
|
||||
oauthCode.ClaimedAt = dbCode.ClaimedAt
|
||||
|
||||
return oauthCode, nil
|
||||
}
|
||||
|
||||
func (o *codesDBX) Create(ctx context.Context, code OAuthCode) error {
|
||||
_, err := o.db.Create_OauthCode(ctx, dbx.OauthCode_ClientId(code.ClientID.Bytes()),
|
||||
dbx.OauthCode_UserId(code.UserID.Bytes()), dbx.OauthCode_Scope(code.Scope),
|
||||
dbx.OauthCode_RedirectUrl(code.RedirectURL), dbx.OauthCode_Challenge(code.Challenge),
|
||||
dbx.OauthCode_ChallengeMethod(code.ChallengeMethod), dbx.OauthCode_Code(code.Code),
|
||||
dbx.OauthCode_CreatedAt(code.CreatedAt), dbx.OauthCode_ExpiresAt(code.ExpiresAt), dbx.OauthCode_Create_Fields{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (o *codesDBX) Claim(ctx context.Context, code string) error {
|
||||
return o.db.UpdateNoReturn_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx, dbx.OauthCode_Code(code), dbx.OauthCode_Update_Fields{
|
||||
ClaimedAt: dbx.OauthCode_ClaimedAt(time.Now()),
|
||||
})
|
||||
}
|
||||
|
||||
type tokensDBX struct {
|
||||
db *dbx.DB
|
||||
}
|
||||
|
||||
func (o *tokensDBX) Get(ctx context.Context, kind OAuthTokenKind, token string) (oauthToken OAuthToken, err error) {
|
||||
dbToken, err := o.db.Get_OauthToken_By_Kind_And_Token(ctx, dbx.OauthToken_Kind(int(kind)),
|
||||
dbx.OauthToken_Token([]byte(token)))
|
||||
|
||||
if err != nil {
|
||||
return oauthToken, err
|
||||
}
|
||||
|
||||
clientID, err := uuid.FromBytes(dbToken.ClientId)
|
||||
if err != nil {
|
||||
return oauthToken, err
|
||||
}
|
||||
|
||||
userID, err := uuid.FromBytes(dbToken.UserId)
|
||||
if err != nil {
|
||||
return oauthToken, err
|
||||
}
|
||||
|
||||
if time.Now().After(dbToken.ExpiresAt) {
|
||||
return oauthToken, sql.ErrNoRows
|
||||
}
|
||||
|
||||
oauthToken.ClientID = clientID
|
||||
oauthToken.UserID = userID
|
||||
oauthToken.Scope = dbToken.Scope
|
||||
oauthToken.Kind = OAuthTokenKind(dbToken.Kind)
|
||||
oauthToken.Token = token
|
||||
oauthToken.CreatedAt = dbToken.CreatedAt
|
||||
oauthToken.ExpiresAt = dbToken.ExpiresAt
|
||||
|
||||
return oauthToken, nil
|
||||
}
|
||||
|
||||
func (o *tokensDBX) Create(ctx context.Context, token OAuthToken) error {
|
||||
_, err := o.db.Create_OauthToken(ctx, dbx.OauthToken_ClientId(token.ClientID.Bytes()),
|
||||
dbx.OauthToken_UserId(token.UserID.Bytes()), dbx.OauthToken_Scope(token.Scope),
|
||||
dbx.OauthToken_Kind(int(token.Kind)), dbx.OauthToken_Token([]byte(token.Token)),
|
||||
dbx.OauthToken_CreatedAt(token.CreatedAt), dbx.OauthToken_ExpiresAt(token.ExpiresAt))
|
||||
|
||||
// ignore duplicate key errors as they're somewhat expected
|
||||
if err != nil && strings.Contains(err.Error(), "duplicate key") {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
137
satellite/oidc/dbx_test.go
Normal file
137
satellite/oidc/dbx_test.go
Normal file
@ -0,0 +1,137 @@
|
||||
// Copyright (C) 2022 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/common/uuid"
|
||||
"storj.io/storj/satellite"
|
||||
"storj.io/storj/satellite/oidc"
|
||||
"storj.io/storj/satellite/satellitedb/satellitedbtest"
|
||||
)
|
||||
|
||||
func TestOAuthCodes(t *testing.T) {
|
||||
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
|
||||
clientID, err := uuid.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
userID, err := uuid.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
// repositories
|
||||
codes := db.OIDC().OAuthCodes()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
allCodes := []oidc.OAuthCode{
|
||||
{
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Code: "expired",
|
||||
CreatedAt: start.Add(-2 * time.Hour),
|
||||
ExpiresAt: start.Add(-1 * time.Hour),
|
||||
},
|
||||
{
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Code: "valid",
|
||||
CreatedAt: start,
|
||||
ExpiresAt: start.Add(time.Hour),
|
||||
},
|
||||
{
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Code: "claimed",
|
||||
CreatedAt: start,
|
||||
ExpiresAt: start.Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, code := range allCodes {
|
||||
err = codes.Create(ctx, code)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// claim this code ahead of time to test the token already claimed code path later on
|
||||
err = codes.Claim(ctx, "claimed")
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
code string
|
||||
err error
|
||||
}{
|
||||
{"expired", sql.ErrNoRows},
|
||||
{"valid", nil},
|
||||
{"claimed", sql.ErrNoRows}, // this should return an error since it was claimed above
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
_, err := codes.Get(ctx, testCase.code)
|
||||
require.Equal(t, testCase.err, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOAuthTokens(t *testing.T) {
|
||||
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
|
||||
clientID, err := uuid.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
userID, err := uuid.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
// repositories
|
||||
tokens := db.OIDC().OAuthTokens()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
allTokens := []oidc.OAuthToken{
|
||||
{
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Kind: oidc.KindAccessToken,
|
||||
Token: "expired",
|
||||
CreatedAt: start.Add(-2 * time.Hour),
|
||||
ExpiresAt: start.Add(-1 * time.Hour),
|
||||
},
|
||||
{
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Kind: oidc.KindRefreshToken,
|
||||
Token: "valid",
|
||||
CreatedAt: start,
|
||||
ExpiresAt: start.Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, token := range allTokens {
|
||||
err = tokens.Create(ctx, token)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// ensure that creating an existing token doesn't cause an error
|
||||
err = tokens.Create(ctx, allTokens[1])
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
kind oidc.OAuthTokenKind
|
||||
token string
|
||||
err error
|
||||
}{
|
||||
{oidc.KindAccessToken, "expired", sql.ErrNoRows},
|
||||
{oidc.KindRefreshToken, "valid", nil},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
_, err := tokens.Get(ctx, testCase.kind, testCase.token)
|
||||
require.Equal(t, testCase.err, err)
|
||||
}
|
||||
})
|
||||
}
|
@ -36,6 +36,7 @@ import (
|
||||
"storj.io/storj/satellite/metainfo/expireddeletion"
|
||||
"storj.io/storj/satellite/metrics"
|
||||
"storj.io/storj/satellite/nodeapiversion"
|
||||
"storj.io/storj/satellite/oidc"
|
||||
"storj.io/storj/satellite/orders"
|
||||
"storj.io/storj/satellite/overlay"
|
||||
"storj.io/storj/satellite/overlay/straynodes"
|
||||
@ -85,6 +86,8 @@ type DB interface {
|
||||
RepairQueue() queue.RepairQueue
|
||||
// Console returns database for satellite console
|
||||
Console() console.DB
|
||||
// OIDC returns the database for OIDC resources.
|
||||
OIDC() oidc.DB
|
||||
// Orders returns database for orders
|
||||
Orders() orders.DB
|
||||
// Containment returns database for containment
|
||||
|
@ -58,11 +58,6 @@ func (db *ConsoleDB) APIKeys() console.APIKeys {
|
||||
return db.apikeys
|
||||
}
|
||||
|
||||
// OAuthClients returns an API for the OAuthClients repository.
|
||||
func (db *ConsoleDB) OAuthClients() console.OAuthClients {
|
||||
return &oauthClients{methods: db.methods, db: db.db}
|
||||
}
|
||||
|
||||
// RegistrationTokens is a getter for RegistrationTokens repository.
|
||||
func (db *ConsoleDB) RegistrationTokens() console.RegistrationTokens {
|
||||
return ®istrationTokens{db.methods}
|
||||
|
@ -24,6 +24,7 @@ import (
|
||||
"storj.io/storj/satellite/console"
|
||||
"storj.io/storj/satellite/gracefulexit"
|
||||
"storj.io/storj/satellite/nodeapiversion"
|
||||
"storj.io/storj/satellite/oidc"
|
||||
"storj.io/storj/satellite/orders"
|
||||
"storj.io/storj/satellite/overlay"
|
||||
"storj.io/storj/satellite/payments/stripecoinpayments"
|
||||
@ -236,6 +237,12 @@ func (dbc *satelliteDBCollection) Console() console.DB {
|
||||
return db.consoleDB
|
||||
}
|
||||
|
||||
// OIDC returns the database for storing OAuth and OIDC information.
|
||||
func (dbc *satelliteDBCollection) OIDC() oidc.DB {
|
||||
db := dbc.getByName("oidc")
|
||||
return oidc.NewDB(db.DB)
|
||||
}
|
||||
|
||||
// Orders returns database for storing orders.
|
||||
func (dbc *satelliteDBCollection) Orders() orders.DB {
|
||||
db := dbc.getByName("orders")
|
||||
|
@ -1368,6 +1368,7 @@ create oauth_code ()
|
||||
read one (
|
||||
select oauth_code
|
||||
where oauth_code.code = ?
|
||||
where oauth_code.claimed_at = null
|
||||
)
|
||||
|
||||
update oauth_code (
|
||||
|
@ -14650,12 +14650,12 @@ func (obj *pgxImpl) Get_OauthClient_By_Id(ctx context.Context,
|
||||
|
||||
}
|
||||
|
||||
func (obj *pgxImpl) Get_OauthCode_By_Code(ctx context.Context,
|
||||
func (obj *pgxImpl) Get_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx context.Context,
|
||||
oauth_code_code OauthCode_Code_Field) (
|
||||
oauth_code *OauthCode, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
var __embed_stmt = __sqlbundle_Literal("SELECT oauth_codes.client_id, oauth_codes.user_id, oauth_codes.scope, oauth_codes.redirect_url, oauth_codes.challenge, oauth_codes.challenge_method, oauth_codes.code, oauth_codes.created_at, oauth_codes.expires_at, oauth_codes.claimed_at FROM oauth_codes WHERE oauth_codes.code = ?")
|
||||
var __embed_stmt = __sqlbundle_Literal("SELECT oauth_codes.client_id, oauth_codes.user_id, oauth_codes.scope, oauth_codes.redirect_url, oauth_codes.challenge, oauth_codes.challenge_method, oauth_codes.code, oauth_codes.created_at, oauth_codes.expires_at, oauth_codes.claimed_at FROM oauth_codes WHERE oauth_codes.code = ? AND oauth_codes.claimed_at is NULL")
|
||||
|
||||
var __values []interface{}
|
||||
__values = append(__values, oauth_code_code.value())
|
||||
@ -21019,12 +21019,12 @@ func (obj *pgxcockroachImpl) Get_OauthClient_By_Id(ctx context.Context,
|
||||
|
||||
}
|
||||
|
||||
func (obj *pgxcockroachImpl) Get_OauthCode_By_Code(ctx context.Context,
|
||||
func (obj *pgxcockroachImpl) Get_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx context.Context,
|
||||
oauth_code_code OauthCode_Code_Field) (
|
||||
oauth_code *OauthCode, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
var __embed_stmt = __sqlbundle_Literal("SELECT oauth_codes.client_id, oauth_codes.user_id, oauth_codes.scope, oauth_codes.redirect_url, oauth_codes.challenge, oauth_codes.challenge_method, oauth_codes.code, oauth_codes.created_at, oauth_codes.expires_at, oauth_codes.claimed_at FROM oauth_codes WHERE oauth_codes.code = ?")
|
||||
var __embed_stmt = __sqlbundle_Literal("SELECT oauth_codes.client_id, oauth_codes.user_id, oauth_codes.scope, oauth_codes.redirect_url, oauth_codes.challenge, oauth_codes.challenge_method, oauth_codes.code, oauth_codes.created_at, oauth_codes.expires_at, oauth_codes.claimed_at FROM oauth_codes WHERE oauth_codes.code = ? AND oauth_codes.claimed_at is NULL")
|
||||
|
||||
var __values []interface{}
|
||||
__values = append(__values, oauth_code_code.value())
|
||||
@ -24544,14 +24544,14 @@ func (rx *Rx) Get_OauthClient_By_Id(ctx context.Context,
|
||||
return tx.Get_OauthClient_By_Id(ctx, oauth_client_id)
|
||||
}
|
||||
|
||||
func (rx *Rx) Get_OauthCode_By_Code(ctx context.Context,
|
||||
func (rx *Rx) Get_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx context.Context,
|
||||
oauth_code_code OauthCode_Code_Field) (
|
||||
oauth_code *OauthCode, err error) {
|
||||
var tx *Tx
|
||||
if tx, err = rx.getTx(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
return tx.Get_OauthCode_By_Code(ctx, oauth_code_code)
|
||||
return tx.Get_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx, oauth_code_code)
|
||||
}
|
||||
|
||||
func (rx *Rx) Get_OauthToken_By_Kind_And_Token(ctx context.Context,
|
||||
@ -25690,7 +25690,7 @@ type Methods interface {
|
||||
oauth_client_id OauthClient_Id_Field) (
|
||||
oauth_client *OauthClient, err error)
|
||||
|
||||
Get_OauthCode_By_Code(ctx context.Context,
|
||||
Get_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx context.Context,
|
||||
oauth_code_code OauthCode_Code_Field) (
|
||||
oauth_code *OauthCode, err error)
|
||||
|
||||
|
@ -1,76 +0,0 @@
|
||||
// Copyright (C) 2022 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package satellitedb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"storj.io/common/uuid"
|
||||
"storj.io/storj/satellite/console"
|
||||
"storj.io/storj/satellite/satellitedb/dbx"
|
||||
)
|
||||
|
||||
type oauthClients struct {
|
||||
methods dbx.Methods
|
||||
db *satelliteDB
|
||||
}
|
||||
|
||||
// Get returns the OAuthClient associated with the provided id.
|
||||
func (clients *oauthClients) Get(ctx context.Context, id uuid.UUID) (console.OAuthClient, error) {
|
||||
oauthClient, err := clients.db.Get_OauthClient_By_Id(ctx, dbx.OauthClient_Id(id.Bytes()))
|
||||
if err != nil {
|
||||
return console.OAuthClient{}, err
|
||||
}
|
||||
|
||||
userID, err := uuid.FromBytes(oauthClient.UserId)
|
||||
if err != nil {
|
||||
return console.OAuthClient{}, err
|
||||
}
|
||||
|
||||
client := console.OAuthClient{
|
||||
ID: id,
|
||||
Secret: oauthClient.EncryptedSecret,
|
||||
UserID: userID,
|
||||
RedirectURL: oauthClient.RedirectUrl,
|
||||
AppName: oauthClient.AppName,
|
||||
AppLogoURL: oauthClient.AppLogoUrl,
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Create creates a new OAuthClient.
|
||||
func (clients *oauthClients) Create(ctx context.Context, client console.OAuthClient) error {
|
||||
_, err := clients.db.Create_OauthClient(ctx,
|
||||
dbx.OauthClient_Id(client.ID.Bytes()), dbx.OauthClient_EncryptedSecret(client.Secret),
|
||||
dbx.OauthClient_RedirectUrl(client.RedirectURL), dbx.OauthClient_UserId(client.UserID.Bytes()),
|
||||
dbx.OauthClient_AppName(client.AppName), dbx.OauthClient_AppLogoUrl(client.AppLogoURL))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Update modifies information for the provided OAuthClient.
|
||||
func (clients *oauthClients) Update(ctx context.Context, client console.OAuthClient) error {
|
||||
if client.RedirectURL == "" && client.Secret == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
update := dbx.OauthClient_Update_Fields{}
|
||||
|
||||
if client.RedirectURL != "" {
|
||||
update.RedirectUrl = dbx.OauthClient_RedirectUrl(client.RedirectURL)
|
||||
}
|
||||
|
||||
if client.Secret != nil {
|
||||
update.EncryptedSecret = dbx.OauthClient_EncryptedSecret(client.Secret)
|
||||
}
|
||||
|
||||
err := clients.db.UpdateNoReturn_OauthClient_By_Id(ctx, dbx.OauthClient_Id(client.ID.Bytes()), update)
|
||||
return err
|
||||
}
|
||||
|
||||
func (clients *oauthClients) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := clients.db.Delete_OauthClient_By_Id(ctx, dbx.OauthClient_Id(id.Bytes()))
|
||||
return err
|
||||
}
|
Loading…
Reference in New Issue
Block a user