storj/satellite/oidc/dbx.go

192 lines
5.6 KiB
Go
Raw Normal View History

// Copyright (C) 2022 Storj Labs, Inc.
// See LICENSE for copying information.
package oidc
import (
"context"
"database/sql"
"strings"
"time"
"storj.io/common/uuid"
"storj.io/storj/satellite/satellitedb/dbx"
)
type clientsDBX struct {
db *dbx.DB
}
// Get returns the OAuthClient associated with the provided id.
func (clients *clientsDBX) Get(ctx context.Context, id uuid.UUID) (OAuthClient, error) {
oauthClient, err := clients.db.Get_OauthClient_By_Id(ctx, dbx.OauthClient_Id(id.Bytes()))
if err != nil {
return OAuthClient{}, err
}
userID, err := uuid.FromBytes(oauthClient.UserId)
if err != nil {
return OAuthClient{}, err
}
client := OAuthClient{
ID: id,
Secret: oauthClient.EncryptedSecret,
UserID: userID,
RedirectURL: oauthClient.RedirectUrl,
AppName: oauthClient.AppName,
AppLogoURL: oauthClient.AppLogoUrl,
}
return client, nil
}
// Create creates a new OAuthClient.
func (clients *clientsDBX) Create(ctx context.Context, client OAuthClient) error {
_, err := clients.db.Create_OauthClient(ctx,
dbx.OauthClient_Id(client.ID.Bytes()), dbx.OauthClient_EncryptedSecret(client.Secret),
dbx.OauthClient_RedirectUrl(client.RedirectURL), dbx.OauthClient_UserId(client.UserID.Bytes()),
dbx.OauthClient_AppName(client.AppName), dbx.OauthClient_AppLogoUrl(client.AppLogoURL))
return err
}
// Update modifies information for the provided OAuthClient.
func (clients *clientsDBX) Update(ctx context.Context, client OAuthClient) error {
if client.RedirectURL == "" && client.Secret == nil {
return nil
}
update := dbx.OauthClient_Update_Fields{}
if client.RedirectURL != "" {
update.RedirectUrl = dbx.OauthClient_RedirectUrl(client.RedirectURL)
}
if client.Secret != nil {
update.EncryptedSecret = dbx.OauthClient_EncryptedSecret(client.Secret)
}
err := clients.db.UpdateNoReturn_OauthClient_By_Id(ctx, dbx.OauthClient_Id(client.ID.Bytes()), update)
return err
}
func (clients *clientsDBX) Delete(ctx context.Context, id uuid.UUID) error {
_, err := clients.db.Delete_OauthClient_By_Id(ctx, dbx.OauthClient_Id(id.Bytes()))
return err
}
type codesDBX struct {
db *dbx.DB
}
func (o *codesDBX) Get(ctx context.Context, code string) (oauthCode OAuthCode, err error) {
dbCode, err := o.db.Get_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx, dbx.OauthCode_Code(code))
if err != nil {
return oauthCode, err
}
clientID, err := uuid.FromBytes(dbCode.ClientId)
if err != nil {
return oauthCode, err
}
userID, err := uuid.FromBytes(dbCode.UserId)
if err != nil {
return oauthCode, err
}
if time.Now().After(dbCode.ExpiresAt) {
return oauthCode, sql.ErrNoRows
}
oauthCode.ClientID = clientID
oauthCode.UserID = userID
oauthCode.Scope = dbCode.Scope
oauthCode.RedirectURL = dbCode.RedirectUrl
oauthCode.Challenge = dbCode.Challenge
oauthCode.ChallengeMethod = dbCode.ChallengeMethod
oauthCode.Code = dbCode.Code
oauthCode.CreatedAt = dbCode.CreatedAt
oauthCode.ExpiresAt = dbCode.ExpiresAt
oauthCode.ClaimedAt = dbCode.ClaimedAt
return oauthCode, nil
}
func (o *codesDBX) Create(ctx context.Context, code OAuthCode) error {
_, err := o.db.Create_OauthCode(ctx, dbx.OauthCode_ClientId(code.ClientID.Bytes()),
dbx.OauthCode_UserId(code.UserID.Bytes()), dbx.OauthCode_Scope(code.Scope),
dbx.OauthCode_RedirectUrl(code.RedirectURL), dbx.OauthCode_Challenge(code.Challenge),
dbx.OauthCode_ChallengeMethod(code.ChallengeMethod), dbx.OauthCode_Code(code.Code),
dbx.OauthCode_CreatedAt(code.CreatedAt), dbx.OauthCode_ExpiresAt(code.ExpiresAt), dbx.OauthCode_Create_Fields{})
return err
}
func (o *codesDBX) Claim(ctx context.Context, code string) error {
return o.db.UpdateNoReturn_OauthCode_By_Code_And_ClaimedAt_Is_Null(ctx, dbx.OauthCode_Code(code), dbx.OauthCode_Update_Fields{
ClaimedAt: dbx.OauthCode_ClaimedAt(time.Now()),
})
}
type tokensDBX struct {
db *dbx.DB
}
func (o *tokensDBX) Get(ctx context.Context, kind OAuthTokenKind, token string) (oauthToken OAuthToken, err error) {
dbToken, err := o.db.Get_OauthToken_By_Kind_And_Token(ctx, dbx.OauthToken_Kind(int(kind)),
dbx.OauthToken_Token([]byte(token)))
if err != nil {
return oauthToken, err
}
clientID, err := uuid.FromBytes(dbToken.ClientId)
if err != nil {
return oauthToken, err
}
userID, err := uuid.FromBytes(dbToken.UserId)
if err != nil {
return oauthToken, err
}
if time.Now().After(dbToken.ExpiresAt) {
return oauthToken, sql.ErrNoRows
}
oauthToken.ClientID = clientID
oauthToken.UserID = userID
oauthToken.Scope = dbToken.Scope
oauthToken.Kind = OAuthTokenKind(dbToken.Kind)
oauthToken.Token = token
oauthToken.CreatedAt = dbToken.CreatedAt
oauthToken.ExpiresAt = dbToken.ExpiresAt
return oauthToken, nil
}
func (o *tokensDBX) Create(ctx context.Context, token OAuthToken) error {
_, err := o.db.Create_OauthToken(ctx, dbx.OauthToken_ClientId(token.ClientID.Bytes()),
dbx.OauthToken_UserId(token.UserID.Bytes()), dbx.OauthToken_Scope(token.Scope),
dbx.OauthToken_Kind(int(token.Kind)), dbx.OauthToken_Token([]byte(token.Token)),
dbx.OauthToken_CreatedAt(token.CreatedAt), dbx.OauthToken_ExpiresAt(token.ExpiresAt))
// ignore duplicate key errors as they're somewhat expected
if err != nil && strings.Contains(err.Error(), "duplicate key") {
return nil
}
return err
}
// RevokeRESTTokenV0 revokes a v0 REST token by setting its expires_at time to zero.
func (o *tokensDBX) RevokeRESTTokenV0(ctx context.Context, token string) error {
return o.db.UpdateNoReturn_OauthToken_By_Token_And_Kind(ctx, dbx.OauthToken_Token([]byte(token)),
dbx.OauthToken_Kind(int(KindRESTTokenV0)),
dbx.OauthToken_Update_Fields{
ExpiresAt: dbx.OauthToken_ExpiresAt(time.Time{}),
})
}