// Copyright (C) 2022 Storj Labs, Inc. // See LICENSE for copying information. package oidc import ( "context" "time" "github.com/go-oauth2/oauth2/v4" "storj.io/common/uuid" ) // clientStore provides a simple adapter for the oauth implementation. type clientStore struct { clients OAuthClients } func (c *clientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) { uid, err := uuid.FromString(id) if err != nil { return nil, err } return c.clients.Get(ctx, uid) } // tokenStore provides a simple adapter for the oauth implementation. type tokenStore struct { codes OAuthCodes tokens OAuthTokens } func (t *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) (err error) { var code OAuthCode var access, refresh OAuthToken if r, ok := info.(*record); ok { code = r.code access = r.access refresh = r.refresh } else { clientID, err := uuid.FromString(info.GetClientID()) if err != nil { return err } userID, err := uuid.FromString(info.GetUserID()) if err != nil { return err } if c := info.GetCode(); c != "" { code.ClientID = clientID code.UserID = userID code.Scope = info.GetScope() code.RedirectURL = info.GetRedirectURI() code.Challenge = info.GetCodeChallenge() code.ChallengeMethod = string(info.GetCodeChallengeMethod()) code.Code = c code.CreatedAt = info.GetCodeCreateAt() code.ExpiresAt = code.CreatedAt.Add(info.GetCodeExpiresIn()) } if a := info.GetAccess(); a != "" { access.ClientID = clientID access.UserID = userID access.Scope = info.GetScope() access.Kind = KindAccessToken access.Token = a access.CreatedAt = info.GetAccessCreateAt() access.ExpiresAt = access.CreatedAt.Add(info.GetAccessExpiresIn()) } if r := info.GetRefresh(); r != "" { refresh.ClientID = clientID refresh.UserID = userID refresh.Scope = info.GetScope() refresh.Kind = KindRefreshToken refresh.Token = r refresh.CreatedAt = info.GetRefreshCreateAt() refresh.ExpiresAt = refresh.CreatedAt.Add(info.GetRefreshExpiresIn()) } } if code.Code != "" { err := t.codes.Create(ctx, code) if err != nil { return err } } if access.Token != "" { err := t.tokens.Create(ctx, access) if err != nil { return err } } if refresh.Token != "" { err := t.tokens.Create(ctx, refresh) if err != nil { return err } } return nil } func (t *tokenStore) RemoveByCode(ctx context.Context, code string) error { return t.codes.Claim(ctx, code) } func (t *tokenStore) RemoveByAccess(ctx context.Context, access string) error { return nil // unsupported by current configuration } func (t *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { return nil // unsupported by current configuration } func (t *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { oauthCode, err := t.codes.Get(ctx, code) if err != nil { return nil, err } return &record{code: oauthCode}, nil } func (t *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { oauthToken, err := t.tokens.Get(ctx, KindAccessToken, access) if err != nil { return nil, err } return &record{access: oauthToken}, nil } func (t *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { oauthToken, err := t.tokens.Get(ctx, KindRefreshToken, refresh) if err != nil { return nil, err } return &record{refresh: oauthToken}, nil } type record struct { code OAuthCode access OAuthToken refresh OAuthToken } func (r *record) New() oauth2.TokenInfo { return &record{} } func (r *record) GetClientID() string { switch { case !r.code.ClientID.IsZero(): return r.code.ClientID.String() case !r.access.ClientID.IsZero(): return r.access.ClientID.String() case !r.refresh.ClientID.IsZero(): return r.refresh.ClientID.String() } return "" } func (r *record) SetClientID(s string) { clientID, err := uuid.FromString(s) if err != nil { return } r.code.ClientID = clientID r.access.ClientID = clientID r.refresh.ClientID = clientID } func (r *record) GetUserID() string { switch { case !r.code.UserID.IsZero(): return r.code.UserID.String() case !r.access.UserID.IsZero(): return r.access.UserID.String() case !r.refresh.UserID.IsZero(): return r.refresh.UserID.String() } return "" } func (r *record) SetUserID(s string) { userID, err := uuid.FromString(s) if err != nil { return } r.code.ClientID = userID r.access.ClientID = userID r.refresh.ClientID = userID } func (r *record) GetScope() string { switch { case r.code.Scope != "": return r.code.Scope case r.access.Scope != "": return r.access.Scope case r.refresh.Scope != "": return r.refresh.Scope } return "" } func (r *record) SetScope(scope string) { r.code.Scope = scope r.access.Scope = scope r.refresh.Scope = scope } func (r *record) GetRedirectURI() string { return r.code.RedirectURL } func (r *record) SetRedirectURI(redirectURL string) { r.code.RedirectURL = redirectURL } func (r *record) GetCode() string { return r.code.Code } func (r *record) SetCode(code string) { r.code.Code = code } func (r *record) GetCodeCreateAt() time.Time { return r.code.CreatedAt } func (r *record) SetCodeCreateAt(time time.Time) { r.code.CreatedAt = time } func (r *record) GetCodeExpiresIn() time.Duration { return r.code.ExpiresAt.Sub(r.code.CreatedAt) } func (r *record) SetCodeExpiresIn(duration time.Duration) { r.code.ExpiresAt = r.code.CreatedAt.Add(duration) } func (r *record) GetCodeChallenge() string { return r.code.Challenge } func (r *record) SetCodeChallenge(challenge string) { r.code.Challenge = challenge } func (r *record) GetCodeChallengeMethod() oauth2.CodeChallengeMethod { if r.code.ChallengeMethod == string(oauth2.CodeChallengeS256) { return oauth2.CodeChallengeS256 } return oauth2.CodeChallengePlain } func (r *record) SetCodeChallengeMethod(method oauth2.CodeChallengeMethod) { r.code.ChallengeMethod = string(method) } func (r *record) GetAccess() string { return r.access.Token } func (r *record) SetAccess(token string) { r.access.Token = token } func (r *record) GetAccessCreateAt() time.Time { return r.access.CreatedAt } func (r *record) SetAccessCreateAt(time time.Time) { r.access.CreatedAt = time } func (r *record) GetAccessExpiresIn() time.Duration { return r.access.ExpiresAt.Sub(r.access.CreatedAt) } func (r *record) SetAccessExpiresIn(duration time.Duration) { r.access.ExpiresAt = r.access.CreatedAt.Add(duration) } func (r *record) GetRefresh() string { return r.refresh.Token } func (r *record) SetRefresh(token string) { r.refresh.Token = token } func (r *record) GetRefreshCreateAt() time.Time { return r.refresh.CreatedAt } func (r *record) SetRefreshCreateAt(time time.Time) { r.refresh.CreatedAt = time } func (r *record) GetRefreshExpiresIn() time.Duration { return r.refresh.ExpiresAt.Sub(r.refresh.CreatedAt) } func (r *record) SetRefreshExpiresIn(duration time.Duration) { r.refresh.ExpiresAt = r.refresh.CreatedAt.Add(duration) }