storj/satellite/oidc/oauth_stores.go
Mya 4a110b266e satellite/console: added oidc endpoints
This change adds endpoints for supporting OpenID Connect (OIDC) and
OAuth requests. This allows application developers to easily
develop apps with Storj using common mechanisms for authentication
and authorization.

Change-Id: I2a76d48bd1241367aa2d1e3309f6f65d6d6ea4dc
2022-03-16 12:01:26 +00:00

326 lines
7.0 KiB
Go

// Copyright (C) 2022 Storj Labs, Inc.
// See LICENSE for copying information.
package oidc
import (
"context"
"time"
"github.com/go-oauth2/oauth2/v4"
"storj.io/common/uuid"
)
// clientStore adapts OAuthClients for the go-oauth2 implementation, which
// refers to clients by string ID rather than by UUID.
type clientStore struct {
	clients OAuthClients // storage queried by GetByID
}
// GetByID looks up an OAuth client by its string ID.
//
// id must be a UUID in string form; a parse failure is returned unchanged.
// Whenever an error is returned the oauth2.ClientInfo result is an untyped
// nil, never an interface wrapping a zero or nil client value.
func (c *clientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) {
	uid, err := uuid.FromString(id)
	if err != nil {
		return nil, err
	}

	client, err := c.clients.Get(ctx, uid)
	if err != nil {
		// Boxing the client value into the interface on the error path would
		// make the result compare non-nil; return a literal nil instead.
		return nil, err
	}
	return client, nil
}
// tokenStore adapts OAuthCodes and OAuthTokens for the go-oauth2
// implementation. Authorization codes are kept in codes; access and refresh
// tokens share tokens and are told apart by their kind.
type tokenStore struct {
	codes  OAuthCodes  // authorization codes
	tokens OAuthTokens // access and refresh tokens
}
// Create persists the authorization code, access token and refresh token
// carried by info. Each part is stored only when its value is non-empty. The
// inserts are independent: if a later insert fails, earlier ones are not
// rolled back.
//
// info is either a *record (for example one loaded through GetByAccess or
// GetByRefresh and then updated through its setters) or any other
// oauth2.TokenInfo, which is first converted by recordFromTokenInfo.
func (t *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
	r, ok := info.(*record)
	if !ok {
		converted, err := recordFromTokenInfo(info)
		if err != nil {
			return err
		}
		r = converted
	}

	// Work on copies so the caller's record is not modified.
	code, access, refresh := r.code, r.access, r.refresh

	if code.Code != "" {
		if err := t.codes.Create(ctx, code); err != nil {
			return err
		}
	}

	if access.Token != "" {
		// A record loaded for a single token only carries owner data on that
		// part. A token attached to it afterwards would otherwise be stored
		// without client, user, scope or kind.
		access.Kind = KindAccessToken
		r.completeToken(&access)
		if err := t.tokens.Create(ctx, access); err != nil {
			return err
		}
	}

	if refresh.Token != "" {
		refresh.Kind = KindRefreshToken
		r.completeToken(&refresh)
		if err := t.tokens.Create(ctx, refresh); err != nil {
			return err
		}
	}

	return nil
}

// recordFromTokenInfo converts an arbitrary oauth2.TokenInfo into a record.
// Only the parts (code, access token, refresh token) whose value is
// non-empty are populated, and all of them share info's client ID, user ID
// and scope. It fails when the client ID or user ID is not a valid UUID.
func recordFromTokenInfo(info oauth2.TokenInfo) (*record, error) {
	clientID, err := uuid.FromString(info.GetClientID())
	if err != nil {
		return nil, err
	}
	userID, err := uuid.FromString(info.GetUserID())
	if err != nil {
		return nil, err
	}
	scope := info.GetScope()

	r := &record{}
	if c := info.GetCode(); c != "" {
		createdAt := info.GetCodeCreateAt()
		r.code = OAuthCode{
			ClientID:        clientID,
			UserID:          userID,
			Scope:           scope,
			RedirectURL:     info.GetRedirectURI(),
			Challenge:       info.GetCodeChallenge(),
			ChallengeMethod: string(info.GetCodeChallengeMethod()),
			Code:            c,
			CreatedAt:       createdAt,
			ExpiresAt:       createdAt.Add(info.GetCodeExpiresIn()),
		}
	}
	if a := info.GetAccess(); a != "" {
		createdAt := info.GetAccessCreateAt()
		r.access = OAuthToken{
			ClientID:  clientID,
			UserID:    userID,
			Scope:     scope,
			Kind:      KindAccessToken,
			Token:     a,
			CreatedAt: createdAt,
			ExpiresAt: createdAt.Add(info.GetAccessExpiresIn()),
		}
	}
	if rt := info.GetRefresh(); rt != "" {
		createdAt := info.GetRefreshCreateAt()
		r.refresh = OAuthToken{
			ClientID:  clientID,
			UserID:    userID,
			Scope:     scope,
			Kind:      KindRefreshToken,
			Token:     rt,
			CreatedAt: createdAt,
			ExpiresAt: createdAt.Add(info.GetRefreshExpiresIn()),
		}
	}
	return r, nil
}

// completeToken fills in the client ID, user ID and scope of tok from the
// rest of r when tok does not carry them itself. Values tok already has are
// kept.
func (r *record) completeToken(tok *OAuthToken) {
	// The getters return "" when no part of r has a value; anything that does
	// not parse as a UUID leaves the field untouched.
	if tok.ClientID.IsZero() {
		if id, err := uuid.FromString(r.GetClientID()); err == nil {
			tok.ClientID = id
		}
	}
	if tok.UserID.IsZero() {
		if id, err := uuid.FromString(r.GetUserID()); err == nil {
			tok.UserID = id
		}
	}
	if tok.Scope == "" {
		tok.Scope = r.GetScope()
	}
}
// RemoveByCode does not delete the code. It delegates to OAuthCodes.Claim
// (presumably marking the code as used so it cannot be redeemed twice;
// confirm against the OAuthCodes implementation).
func (t *tokenStore) RemoveByCode(ctx context.Context, code string) error {
	if err := t.codes.Claim(ctx, code); err != nil {
		return err
	}
	return nil
}
// RemoveByAccess is a deliberate no-op that always reports success: removing
// access tokens is unsupported by the current configuration.
func (t *tokenStore) RemoveByAccess(ctx context.Context, access string) error { return nil }
// RemoveByRefresh is a deliberate no-op that always reports success:
// removing refresh tokens is unsupported by the current configuration.
func (t *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { return nil }
// GetByCode loads an authorization code and wraps it in a record whose
// access and refresh parts are left empty.
func (t *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
	found, err := t.codes.Get(ctx, code)
	if err == nil {
		return &record{code: found}, nil
	}
	return nil, err
}
// GetByAccess loads an access token and wraps it in a record whose code and
// refresh parts are left empty.
func (t *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
	tok, err := t.tokens.Get(ctx, KindAccessToken, access)
	if err != nil {
		return nil, err
	}
	rec := &record{access: tok}
	return rec, nil
}
// GetByRefresh loads a refresh token and wraps it in a record whose code and
// access parts are left empty.
func (t *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
	tok, err := t.tokens.Get(ctx, KindRefreshToken, refresh)
	if err == nil {
		return &record{refresh: tok}, nil
	}
	return nil, err
}
// record is the oauth2.TokenInfo handed to the go-oauth2 implementation. It
// can carry an authorization code, an access token and a refresh token at
// once; unused parts stay at their zero value. The client/user/scope setters
// write all three parts.
type record struct {
	code    OAuthCode  // authorization code part
	access  OAuthToken // access token part
	refresh OAuthToken // refresh token part
}
// New returns a fresh, empty record.
func (r *record) New() oauth2.TokenInfo { return new(record) }
// GetClientID returns the client ID of the first part that has one,
// checking code, then access, then refresh; it returns "" if none does.
func (r *record) GetClientID() string {
	if id := r.code.ClientID; !id.IsZero() {
		return id.String()
	}
	if id := r.access.ClientID; !id.IsZero() {
		return id.String()
	}
	if id := r.refresh.ClientID; !id.IsZero() {
		return id.String()
	}
	return ""
}
// SetClientID parses s as a UUID and stores it as the client ID of the code,
// access and refresh parts. A malformed s is ignored and leaves the record
// unchanged.
func (r *record) SetClientID(s string) {
	clientID, err := uuid.FromString(s)
	if err == nil {
		r.code.ClientID, r.access.ClientID, r.refresh.ClientID = clientID, clientID, clientID
	}
}
// GetUserID returns the user ID of the first part that has one, checking
// code, then access, then refresh; it returns "" if none does.
func (r *record) GetUserID() string {
	if id := r.code.UserID; !id.IsZero() {
		return id.String()
	}
	if id := r.access.UserID; !id.IsZero() {
		return id.String()
	}
	if id := r.refresh.UserID; !id.IsZero() {
		return id.String()
	}
	return ""
}
// SetUserID parses s as a UUID and stores it as the user ID of the code,
// access and refresh parts. A malformed s is ignored and leaves the record
// unchanged, because the oauth2.TokenInfo setter has no way to report it.
func (r *record) SetUserID(s string) {
	userID, err := uuid.FromString(s)
	if err != nil {
		return
	}
	// Write UserID (not ClientID) so GetUserID sees the value and the client
	// ID set by SetClientID is preserved.
	r.code.UserID = userID
	r.access.UserID = userID
	r.refresh.UserID = userID
}
// GetScope returns the first non-empty scope among the code, access and
// refresh parts (checked in that order), or "" if none is set.
func (r *record) GetScope() string {
	for _, scope := range [...]string{r.code.Scope, r.access.Scope, r.refresh.Scope} {
		if scope != "" {
			return scope
		}
	}
	return ""
}
// SetScope applies scope to the code, access and refresh parts alike.
func (r *record) SetScope(scope string) {
	r.code.Scope, r.access.Scope, r.refresh.Scope = scope, scope, scope
}
// GetRedirectURI returns the redirect URL stored on the code part.
func (r *record) GetRedirectURI() string { return r.code.RedirectURL }
// SetRedirectURI stores redirectURL on the code part only.
func (r *record) SetRedirectURI(redirectURL string) { r.code.RedirectURL = redirectURL }
// GetCode returns the authorization code value, or "" if there is none.
func (r *record) GetCode() string { return r.code.Code }
// SetCode sets the authorization code value.
func (r *record) SetCode(code string) { r.code.Code = code }
// GetCodeCreateAt returns when the authorization code was issued.
func (r *record) GetCodeCreateAt() time.Time { return r.code.CreatedAt }
// SetCodeCreateAt sets when the authorization code was issued.
//
// The expiry is stored as an absolute time, so an expiry that is already set
// keeps its lifetime and moves along with the new creation time. For a
// non-zero lifetime this makes calling SetCodeExpiresIn before or after this
// setter produce the same expiry.
func (r *record) SetCodeCreateAt(createdAt time.Time) {
	if !r.code.ExpiresAt.IsZero() {
		lifetime := r.code.ExpiresAt.Sub(r.code.CreatedAt)
		r.code.ExpiresAt = createdAt.Add(lifetime)
	}
	r.code.CreatedAt = createdAt
}
// GetCodeExpiresIn returns the code's lifetime, derived from its stored
// creation and expiry times.
func (r *record) GetCodeExpiresIn() time.Duration { return r.code.ExpiresAt.Sub(r.code.CreatedAt) }
// SetCodeExpiresIn stores the code's expiry as an absolute time: the
// creation time held by the code at the moment of the call plus duration.
func (r *record) SetCodeExpiresIn(duration time.Duration) {
	expiresAt := r.code.CreatedAt.Add(duration)
	r.code.ExpiresAt = expiresAt
}
// GetCodeChallenge returns the code challenge stored with the code.
func (r *record) GetCodeChallenge() string { return r.code.Challenge }
// SetCodeChallenge stores the code challenge on the code part.
func (r *record) SetCodeChallenge(challenge string) { r.code.Challenge = challenge }
// GetCodeChallengeMethod reports S256 when that is the stored method and
// falls back to plain for anything else, including an empty value.
func (r *record) GetCodeChallengeMethod() oauth2.CodeChallengeMethod {
	switch r.code.ChallengeMethod {
	case string(oauth2.CodeChallengeS256):
		return oauth2.CodeChallengeS256
	default:
		return oauth2.CodeChallengePlain
	}
}
// SetCodeChallengeMethod stores method on the code part in its string form.
func (r *record) SetCodeChallengeMethod(method oauth2.CodeChallengeMethod) {
	asString := string(method)
	r.code.ChallengeMethod = asString
}
// GetAccess returns the access token value, or "" if there is none.
func (r *record) GetAccess() string { return r.access.Token }
// SetAccess sets the access token value.
func (r *record) SetAccess(token string) { r.access.Token = token }
// GetAccessCreateAt returns when the access token was issued.
func (r *record) GetAccessCreateAt() time.Time { return r.access.CreatedAt }
// SetAccessCreateAt sets when the access token was issued.
//
// The expiry is stored as an absolute time, so an expiry that is already set
// keeps its lifetime and moves along with the new creation time. For a
// non-zero lifetime this makes calling SetAccessExpiresIn before or after
// this setter produce the same expiry.
func (r *record) SetAccessCreateAt(createdAt time.Time) {
	if !r.access.ExpiresAt.IsZero() {
		lifetime := r.access.ExpiresAt.Sub(r.access.CreatedAt)
		r.access.ExpiresAt = createdAt.Add(lifetime)
	}
	r.access.CreatedAt = createdAt
}
// GetAccessExpiresIn returns the access token's lifetime, derived from its
// stored creation and expiry times.
func (r *record) GetAccessExpiresIn() time.Duration { return r.access.ExpiresAt.Sub(r.access.CreatedAt) }
// SetAccessExpiresIn stores the access token's expiry as an absolute time:
// the creation time held by the token at the moment of the call plus
// duration.
func (r *record) SetAccessExpiresIn(duration time.Duration) {
	expiresAt := r.access.CreatedAt.Add(duration)
	r.access.ExpiresAt = expiresAt
}
// GetRefresh returns the refresh token value, or "" if there is none.
func (r *record) GetRefresh() string { return r.refresh.Token }
// SetRefresh sets the refresh token value.
func (r *record) SetRefresh(token string) { r.refresh.Token = token }
// GetRefreshCreateAt returns when the refresh token was issued.
func (r *record) GetRefreshCreateAt() time.Time { return r.refresh.CreatedAt }
// SetRefreshCreateAt sets when the refresh token was issued.
//
// The expiry is stored as an absolute time, so an expiry that is already set
// keeps its lifetime and moves along with the new creation time. For a
// non-zero lifetime this makes calling SetRefreshExpiresIn before or after
// this setter produce the same expiry.
func (r *record) SetRefreshCreateAt(createdAt time.Time) {
	if !r.refresh.ExpiresAt.IsZero() {
		lifetime := r.refresh.ExpiresAt.Sub(r.refresh.CreatedAt)
		r.refresh.ExpiresAt = createdAt.Add(lifetime)
	}
	r.refresh.CreatedAt = createdAt
}
// GetRefreshExpiresIn returns the refresh token's lifetime, derived from its
// stored creation and expiry times.
func (r *record) GetRefreshExpiresIn() time.Duration { return r.refresh.ExpiresAt.Sub(r.refresh.CreatedAt) }
// SetRefreshExpiresIn stores the refresh token's expiry as an absolute time:
// the creation time held by the token at the moment of the call plus
// duration.
func (r *record) SetRefreshExpiresIn(duration time.Duration) {
	expiresAt := r.refresh.CreatedAt.Add(duration)
	r.refresh.ExpiresAt = expiresAt
}