2022-01-19 18:25:31 +00:00
|
|
|
// Copyright (C) 2022 Storj Labs, Inc.
|
|
|
|
// See LICENSE for copying information.
|
|
|
|
|
|
|
|
package oidc
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"database/sql"
|
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"storj.io/common/uuid"
|
|
|
|
"storj.io/storj/satellite/satellitedb/dbx"
|
|
|
|
)
|
|
|
|
|
|
|
|
type clientsDBX struct {
|
|
|
|
db *dbx.DB
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get returns the OAuthClient associated with the provided id.
|
|
|
|
func (clients *clientsDBX) Get(ctx context.Context, id uuid.UUID) (OAuthClient, error) {
|
|
|
|
oauthClient, err := clients.db.Get_OauthClient_By_Id(ctx, dbx.OauthClient_Id(id.Bytes()))
|
|
|
|
if err != nil {
|
|
|
|
return OAuthClient{}, err
|
|
|
|
}
|
|
|
|
|
|
|
|
userID, err := uuid.FromBytes(oauthClient.UserId)
|
|
|
|
if err != nil {
|
|
|
|
return OAuthClient{}, err
|
|
|
|
}
|
|
|
|
|
|
|
|
client := OAuthClient{
|
|
|
|
ID: id,
|
|
|
|
Secret: oauthClient.EncryptedSecret,
|
|
|
|
UserID: userID,
|
|
|
|
RedirectURL: oauthClient.RedirectUrl,
|
|
|
|
AppName: oauthClient.AppName,
|
|
|
|
AppLogoURL: oauthClient.AppLogoUrl,
|
|
|
|
}
|
|
|
|
|
|
|
|
return client, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create creates a new OAuthClient.
|
|
|
|
func (clients *clientsDBX) Create(ctx context.Context, client OAuthClient) error {
|
|
|
|
_, err := clients.db.Create_OauthClient(ctx,
|
|
|
|
dbx.OauthClient_Id(client.ID.Bytes()), dbx.OauthClient_EncryptedSecret(client.Secret),
|
|
|
|
dbx.OauthClient_RedirectUrl(client.RedirectURL), dbx.OauthClient_UserId(client.UserID.Bytes()),
|
|
|
|
dbx.OauthClient_AppName(client.AppName), dbx.OauthClient_AppLogoUrl(client.AppLogoURL))
|
|
|
|
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Update modifies information for the provided OAuthClient.
|
|
|
|
func (clients *clientsDBX) Update(ctx context.Context, client OAuthClient) error {
|
|
|
|
if client.RedirectURL == "" && client.Secret == nil {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
update := dbx.OauthClient_Update_Fields{}
|
|
|
|
|
|
|
|
if client.RedirectURL != "" {
|
|
|
|
update.RedirectUrl = dbx.OauthClient_RedirectUrl(client.RedirectURL)
|
|
|
|
}
|
|
|
|
|
|
|
|
if client.Secret != nil {
|
|
|
|
update.EncryptedSecret = dbx.OauthClient_EncryptedSecret(client.Secret)
|
|
|
|
}
|
|
|
|
|
|
|
|
err := clients.db.UpdateNoReturn_OauthClient_By_Id(ctx, dbx.OauthClient_Id(client.ID.Bytes()), update)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (clients *clientsDBX) Delete(ctx context.Context, id uuid.UUID) error {
|
|
|
|
_, err := clients.db.Delete_OauthClient_By_Id(ctx, dbx.OauthClient_Id(id.Bytes()))
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
type codesDBX struct {
|
|
|
|
db *dbx.DB
|
|
|
|
}
|
|
|
|
|
|
|
|
func (o *codesDBX) Get(ctx context.Context, code string) (oauthCode OAuthCode, err error) {
|
|
|
|
dbCode, err := o.db.Get_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx, dbx.OauthCode_Code(code))
|
|
|
|
if err != nil {
|
|
|
|
return oauthCode, err
|
|
|
|
}
|
|
|
|
|
|
|
|
clientID, err := uuid.FromBytes(dbCode.ClientId)
|
|
|
|
if err != nil {
|
|
|
|
return oauthCode, err
|
|
|
|
}
|
|
|
|
|
|
|
|
userID, err := uuid.FromBytes(dbCode.UserId)
|
|
|
|
if err != nil {
|
|
|
|
return oauthCode, err
|
|
|
|
}
|
|
|
|
|
|
|
|
if time.Now().After(dbCode.ExpiresAt) {
|
|
|
|
return oauthCode, sql.ErrNoRows
|
|
|
|
}
|
|
|
|
|
|
|
|
oauthCode.ClientID = clientID
|
|
|
|
oauthCode.UserID = userID
|
|
|
|
oauthCode.Scope = dbCode.Scope
|
|
|
|
oauthCode.RedirectURL = dbCode.RedirectUrl
|
|
|
|
oauthCode.Challenge = dbCode.Challenge
|
|
|
|
oauthCode.ChallengeMethod = dbCode.ChallengeMethod
|
|
|
|
oauthCode.Code = dbCode.Code
|
|
|
|
oauthCode.CreatedAt = dbCode.CreatedAt
|
|
|
|
oauthCode.ExpiresAt = dbCode.ExpiresAt
|
|
|
|
oauthCode.ClaimedAt = dbCode.ClaimedAt
|
|
|
|
|
|
|
|
return oauthCode, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (o *codesDBX) Create(ctx context.Context, code OAuthCode) error {
|
|
|
|
_, err := o.db.Create_OauthCode(ctx, dbx.OauthCode_ClientId(code.ClientID.Bytes()),
|
|
|
|
dbx.OauthCode_UserId(code.UserID.Bytes()), dbx.OauthCode_Scope(code.Scope),
|
|
|
|
dbx.OauthCode_RedirectUrl(code.RedirectURL), dbx.OauthCode_Challenge(code.Challenge),
|
|
|
|
dbx.OauthCode_ChallengeMethod(code.ChallengeMethod), dbx.OauthCode_Code(code.Code),
|
|
|
|
dbx.OauthCode_CreatedAt(code.CreatedAt), dbx.OauthCode_ExpiresAt(code.ExpiresAt), dbx.OauthCode_Create_Fields{})
|
|
|
|
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (o *codesDBX) Claim(ctx context.Context, code string) error {
|
|
|
|
return o.db.UpdateNoReturn_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx, dbx.OauthCode_Code(code), dbx.OauthCode_Update_Fields{
|
|
|
|
ClaimedAt: dbx.OauthCode_ClaimedAt(time.Now()),
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
type tokensDBX struct {
|
|
|
|
db *dbx.DB
|
|
|
|
}
|
|
|
|
|
|
|
|
func (o *tokensDBX) Get(ctx context.Context, kind OAuthTokenKind, token string) (oauthToken OAuthToken, err error) {
|
|
|
|
dbToken, err := o.db.Get_OauthToken_By_Kind_And_Token(ctx, dbx.OauthToken_Kind(int(kind)),
|
|
|
|
dbx.OauthToken_Token([]byte(token)))
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return oauthToken, err
|
|
|
|
}
|
|
|
|
|
|
|
|
clientID, err := uuid.FromBytes(dbToken.ClientId)
|
|
|
|
if err != nil {
|
|
|
|
return oauthToken, err
|
|
|
|
}
|
|
|
|
|
|
|
|
userID, err := uuid.FromBytes(dbToken.UserId)
|
|
|
|
if err != nil {
|
|
|
|
return oauthToken, err
|
|
|
|
}
|
|
|
|
|
|
|
|
if time.Now().After(dbToken.ExpiresAt) {
|
|
|
|
return oauthToken, sql.ErrNoRows
|
|
|
|
}
|
|
|
|
|
|
|
|
oauthToken.ClientID = clientID
|
|
|
|
oauthToken.UserID = userID
|
|
|
|
oauthToken.Scope = dbToken.Scope
|
|
|
|
oauthToken.Kind = OAuthTokenKind(dbToken.Kind)
|
|
|
|
oauthToken.Token = token
|
|
|
|
oauthToken.CreatedAt = dbToken.CreatedAt
|
|
|
|
oauthToken.ExpiresAt = dbToken.ExpiresAt
|
|
|
|
|
|
|
|
return oauthToken, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (o *tokensDBX) Create(ctx context.Context, token OAuthToken) error {
|
|
|
|
_, err := o.db.Create_OauthToken(ctx, dbx.OauthToken_ClientId(token.ClientID.Bytes()),
|
|
|
|
dbx.OauthToken_UserId(token.UserID.Bytes()), dbx.OauthToken_Scope(token.Scope),
|
|
|
|
dbx.OauthToken_Kind(int(token.Kind)), dbx.OauthToken_Token([]byte(token.Token)),
|
|
|
|
dbx.OauthToken_CreatedAt(token.CreatedAt), dbx.OauthToken_ExpiresAt(token.ExpiresAt))
|
|
|
|
|
|
|
|
// ignore duplicate key errors as they're somewhat expected
|
|
|
|
if err != nil && strings.Contains(err.Error(), "duplicate key") {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
return err
|
|
|
|
}
|
2022-02-11 22:48:35 +00:00
|
|
|
|
2022-04-12 17:59:07 +01:00
|
|
|
// RevokeRESTTokenV0 revokes a v0 REST token by setting its expires_at time to zero.
|
|
|
|
func (o *tokensDBX) RevokeRESTTokenV0(ctx context.Context, token string) error {
|
2022-02-11 22:48:35 +00:00
|
|
|
return o.db.UpdateNoReturn_OauthToken_By_Token_And_Kind(ctx, dbx.OauthToken_Token([]byte(token)),
|
2022-04-12 17:59:07 +01:00
|
|
|
dbx.OauthToken_Kind(int(KindRESTTokenV0)),
|
2022-02-11 22:48:35 +00:00
|
|
|
dbx.OauthToken_Update_Fields{
|
|
|
|
ExpiresAt: dbx.OauthToken_ExpiresAt(time.Time{}),
|
|
|
|
})
|
|
|
|
}
|