diff --git a/private/api/authentication.go b/private/api/authentication.go index 9da68613b..b317bbe03 100644 --- a/private/api/authentication.go +++ b/private/api/authentication.go @@ -12,4 +12,6 @@ import ( type Auth interface { // IsAuthenticated checks if request is performed with all needed authorization credentials. IsAuthenticated(ctx context.Context, r *http.Request, isCookieAuth, isKeyAuth bool) (context.Context, error) + // RemoveAuthCookie indicates to the client that the authentication cookie should be removed. + RemoveAuthCookie(w http.ResponseWriter) } diff --git a/private/apigen/gogen.go b/private/apigen/gogen.go index 7773e8f56..449cda26d 100644 --- a/private/apigen/gogen.go +++ b/private/apigen/gogen.go @@ -178,6 +178,9 @@ func (a *API) generateGo() ([]byte, error) { p("ctx, err = h.auth.IsAuthenticated(ctx, r, true, false)") } p("if err != nil {") + if !endpoint.NoCookieAuth { + p("h.auth.RemoveAuthCookie(w)") + } p("api.ServeError(h.log, w, http.StatusUnauthorized, err)") p("return") p("}") diff --git a/private/testplanet/satellite.go b/private/testplanet/satellite.go index 474dfa424..1c64af20b 100644 --- a/private/testplanet/satellite.go +++ b/private/testplanet/satellite.go @@ -10,9 +10,7 @@ import ( "path/filepath" "runtime/pprof" "strconv" - "time" - "github.com/pquerna/otp/totp" "github.com/spf13/pflag" "github.com/zeebo/errs" "go.uber.org/zap" @@ -40,7 +38,6 @@ import ( "storj.io/storj/satellite/audit" "storj.io/storj/satellite/compensation" "storj.io/storj/satellite/console" - "storj.io/storj/satellite/console/consoleauth" "storj.io/storj/satellite/console/consoleweb" "storj.io/storj/satellite/contact" "storj.io/storj/satellite/gc" @@ -237,11 +234,11 @@ func (system *Satellite) AddUser(ctx context.Context, newUser console.CreateUser return nil, err } - authCtx, err := system.AuthenticatedContext(ctx, user.ID) + userCtx, err := system.UserContext(ctx, user.ID) if err != nil { return nil, err } - _, err = system.API.Console.Service.Payments().SetupAccount(authCtx) + _, err = system.API.Console.Service.Payments().SetupAccount(userCtx) if err != nil { return nil, err } @@ -253,11 +250,11 @@ func (system *Satellite) AddUser(ctx context.Context, newUser console.CreateUser func (system *Satellite) AddProject(ctx context.Context, ownerID uuid.UUID, name string) (_ *console.Project, err error) { defer mon.Task()(&ctx)(&err) - authCtx, err := system.AuthenticatedContext(ctx, ownerID) + ctx, err = system.UserContext(ctx, ownerID) if err != nil { return nil, err } - project, err := system.API.Console.Service.CreateProject(authCtx, console.ProjectInfo{ + project, err := system.API.Console.Service.CreateProject(ctx, console.ProjectInfo{ Name: name, }) if err != nil { @@ -266,8 +263,8 @@ func (system *Satellite) AddProject(ctx context.Context, ownerID uuid.UUID, name return project, nil } -// AuthenticatedContext creates context with authentication date for given user. -func (system *Satellite) AuthenticatedContext(ctx context.Context, userID uuid.UUID) (_ context.Context, err error) { +// UserContext creates context with user. +func (system *Satellite) UserContext(ctx context.Context, userID uuid.UUID) (_ context.Context, err error) { defer mon.Task()(&ctx)(&err) user, err := system.API.Console.Service.GetUser(ctx, userID) @@ -275,26 +272,7 @@ func (system *Satellite) AuthenticatedContext(ctx context.Context, userID uuid.U return nil, err } - // we are using full name as a password - request := console.AuthUser{Email: user.Email, Password: user.FullName} - if user.MFAEnabled { - code, err := totp.GenerateCode(user.MFASecretKey, time.Now()) - if err != nil { - return nil, err - } - request.MFAPasscode = code - } - token, err := system.API.Console.Service.Token(ctx, request) - if err != nil { - return nil, err - } - - auth, err := system.API.Console.Service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token))) - if err != nil { - return nil, err - } - - return console.WithAuth(ctx, auth), nil + return console.WithUser(ctx, user), nil } // Close closes all the subsystems in the Satellite system. diff --git a/private/testplanet/uplink.go b/private/testplanet/uplink.go index 737434811..063e72f4b 100644 --- a/private/testplanet/uplink.go +++ b/private/testplanet/uplink.go @@ -147,11 +147,11 @@ func (planet *Planet) newUplink(ctx context.Context, name string) (_ *Uplink, er return nil, err } - authCtx, err := satellite.AuthenticatedContext(ctx, user.ID) + userCtx, err := satellite.UserContext(ctx, user.ID) if err != nil { return nil, err } - _, apiKey, err := consoleAPI.Service.CreateAPIKey(authCtx, project.ID, "root") + _, apiKey, err := consoleAPI.Service.CreateAPIKey(userCtx, project.ID, "root") if err != nil { return nil, err } diff --git a/satellite/console/auth.go b/satellite/console/auth.go deleted file mode 100644 index 1391eceed..000000000 --- a/satellite/console/auth.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package console - -import ( - "context" - - "github.com/zeebo/errs" - - "storj.io/storj/satellite/console/consoleauth" -) - -// TODO: change to JWT or Macaroon based auth - -// key is a context value key type. -type key int - -// authKey is context key for Authorization. -const authKey key = 0 - -// requestKey is context key for Requests. -const requestKey key = 1 - -// ErrUnauthorized is error class for authorization related errors. -var ErrUnauthorized = errs.Class("unauthorized") - -// Authorization contains auth info of authorized User. -type Authorization struct { - User User - Claims consoleauth.Claims -} - -// WithAuth creates new context with Authorization. -func WithAuth(ctx context.Context, auth Authorization) context.Context { - return context.WithValue(ctx, authKey, auth) -} - -// WithAuthFailure creates new context with authorization failure. -func WithAuthFailure(ctx context.Context, err error) context.Context { - return context.WithValue(ctx, authKey, err) -} - -// GetAuth gets Authorization from context. -func GetAuth(ctx context.Context) (Authorization, error) { - value := ctx.Value(authKey) - - if auth, ok := value.(Authorization); ok { - return auth, nil - } - - if err, ok := value.(error); ok { - return Authorization{}, ErrUnauthorized.Wrap(err) - } - - return Authorization{}, ErrUnauthorized.New(unauthorizedErrMsg) -} diff --git a/satellite/console/consoleauth/service.go b/satellite/console/consoleauth/service.go index 1bd9ed360..c34c42067 100644 --- a/satellite/console/consoleauth/service.go +++ b/satellite/console/consoleauth/service.go @@ -5,6 +5,7 @@ package consoleauth import ( "context" + "crypto/subtle" "encoding/base64" "time" @@ -17,7 +18,7 @@ var mon = monkit.Package() // Config contains configuration parameters for console auth. type Config struct { - TokenExpirationTime time.Duration `help:"expiration time for auth tokens, account recovery tokens, and activation tokens" default:"24h"` + TokenExpirationTime time.Duration `help:"expiration time for account recovery and activation tokens" default:"24h"` } // Service handles creating, signing, and checking the expiration of auth tokens. @@ -63,25 +64,35 @@ func (s *Service) createToken(ctx context.Context, claims *Claims) (_ string, er } token := Token{Payload: json} - err = s.SignToken(&token) + signature, err := s.SignToken(token) if err != nil { return "", err } + token.Signature = signature return token.String(), nil } -// SignToken signs token. -func (s *Service) SignToken(token *Token) error { +// SignToken returns token signature. +func (s *Service) SignToken(token Token) ([]byte, error) { encoded := base64.URLEncoding.EncodeToString(token.Payload) signature, err := s.Signer.Sign([]byte(encoded)) if err != nil { - return err + return nil, err } - token.Signature = signature - return nil + return signature, nil +} + +// ValidateToken determines token validity using its signature. +func (s *Service) ValidateToken(token Token) (bool, error) { + signature, err := s.SignToken(token) + if err != nil { + return false, err + } + + return subtle.ConstantTimeCompare(signature, token.Signature) == 1, nil } // IsExpired returns whether token is expired. diff --git a/satellite/console/consoleweb/consoleapi/analytics.go b/satellite/console/consoleweb/consoleapi/analytics.go index 53fd7ce6b..66b67ecf0 100644 --- a/satellite/console/consoleweb/consoleapi/analytics.go +++ b/satellite/console/consoleweb/consoleapi/analytics.go @@ -59,15 +59,15 @@ func (a *Analytics) EventTriggered(w http.ResponseWriter, r *http.Request) { a.serveJSONError(w, http.StatusInternalServerError, err) } - auth, err := console.GetAuth(ctx) + user, err := console.GetUser(ctx) if err != nil { a.serveJSONError(w, http.StatusUnauthorized, err) return } if et.Link != "" { - a.analytics.TrackLinkEvent(et.EventName, auth.User.ID, auth.User.Email, et.Link) + a.analytics.TrackLinkEvent(et.EventName, user.ID, user.Email, et.Link) } else { - a.analytics.TrackEvent(et.EventName, auth.User.ID, auth.User.Email) + a.analytics.TrackEvent(et.EventName, user.ID, user.Email) } w.WriteHeader(http.StatusOK) } @@ -88,13 +88,13 @@ func (a *Analytics) PageEventTriggered(w http.ResponseWriter, r *http.Request) { a.serveJSONError(w, http.StatusInternalServerError, err) } - auth, err := console.GetAuth(ctx) + user, err := console.GetUser(ctx) if err != nil { a.serveJSONError(w, http.StatusUnauthorized, err) return } - a.analytics.PageVisitEvent(pv.PageName, auth.User.ID, auth.User.Email) + a.analytics.PageVisitEvent(pv.PageName, user.ID, user.Email) w.WriteHeader(http.StatusOK) } diff --git a/satellite/console/consoleweb/consoleapi/api.gen.go b/satellite/console/consoleweb/consoleapi/api.gen.go index ca649c0e0..138261463 100644 --- a/satellite/console/consoleweb/consoleapi/api.gen.go +++ b/satellite/console/consoleweb/consoleapi/api.gen.go @@ -165,6 +165,7 @@ func (h *ProjectManagementHandler) handleGenGetBucketUsageRollups(w http.Respons ctx, err = h.auth.IsAuthenticated(ctx, r, true, true) if err != nil { + h.auth.RemoveAuthCookie(w) api.ServeError(h.log, w, http.StatusUnauthorized, err) return } @@ -208,6 +209,7 @@ func (h *ProjectManagementHandler) handleGenCreateProject(w http.ResponseWriter, ctx, err = h.auth.IsAuthenticated(ctx, r, true, true) if err != nil { + h.auth.RemoveAuthCookie(w) api.ServeError(h.log, w, http.StatusUnauthorized, err) return } @@ -239,6 +241,7 @@ func (h *ProjectManagementHandler) handleGenUpdateProject(w http.ResponseWriter, ctx, err = h.auth.IsAuthenticated(ctx, r, true, true) if err != nil { + h.auth.RemoveAuthCookie(w) api.ServeError(h.log, w, http.StatusUnauthorized, err) return } @@ -282,6 +285,7 @@ func (h *ProjectManagementHandler) handleGenDeleteProject(w http.ResponseWriter, ctx, err = h.auth.IsAuthenticated(ctx, r, true, true) if err != nil { + h.auth.RemoveAuthCookie(w) api.ServeError(h.log, w, http.StatusUnauthorized, err) return } @@ -313,6 +317,7 @@ func (h *ProjectManagementHandler) handleGenGetUsersProjects(w http.ResponseWrit ctx, err = h.auth.IsAuthenticated(ctx, r, true, true) if err != nil { + h.auth.RemoveAuthCookie(w) api.ServeError(h.log, w, http.StatusUnauthorized, err) return } @@ -338,6 +343,7 @@ func (h *APIKeyManagementHandler) handleGenCreateAPIKey(w http.ResponseWriter, r ctx, err = h.auth.IsAuthenticated(ctx, r, true, true) if err != nil { + h.auth.RemoveAuthCookie(w) api.ServeError(h.log, w, http.StatusUnauthorized, err) return } @@ -369,6 +375,7 @@ func (h *UserManagementHandler) handleGenGetUser(w http.ResponseWriter, r *http. ctx, err = h.auth.IsAuthenticated(ctx, r, true, true) if err != nil { + h.auth.RemoveAuthCookie(w) api.ServeError(h.log, w, http.StatusUnauthorized, err) return } diff --git a/satellite/console/consoleweb/consoleapi/apikeys_test.go b/satellite/console/consoleweb/consoleapi/apikeys_test.go index e058d8912..311a32cf0 100644 --- a/satellite/console/consoleweb/consoleapi/apikeys_test.go +++ b/satellite/console/consoleweb/consoleapi/apikeys_test.go @@ -70,7 +70,7 @@ func Test_DeleteAPIKeyByNameAndProjectID(t *testing.T) { cookie := http.Cookie{ Name: "_tokenKey", Path: "/", - Value: token, + Value: token.String(), Expires: expire, } diff --git a/satellite/console/consoleweb/consoleapi/auth.go b/satellite/console/consoleweb/consoleapi/auth.go index bb8f79188..39b43e9e0 100644 --- a/satellite/console/consoleweb/consoleapi/auth.go +++ b/satellite/console/consoleweb/consoleapi/auth.go @@ -90,6 +90,13 @@ func (a *Auth) Token(w http.ResponseWriter, r *http.Request) { return } + tokenRequest.UserAgent = r.UserAgent() + tokenRequest.IP, err = web.GetRequestIP(r) + if err != nil { + a.serveJSONError(w, err) + return + } + token, err := a.service.Token(ctx, tokenRequest) if err != nil { if console.ErrMFAMissing.Has(err) { @@ -104,7 +111,7 @@ func (a *Auth) Token(w http.ResponseWriter, r *http.Request) { a.cookieAuth.SetTokenCookie(w, token) w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(token) + err = json.NewEncoder(w).Encode(token.String()) if err != nil { a.log.Error("token handler could not encode token response", zap.Error(ErrAuthAPI.Wrap(err))) return @@ -116,9 +123,21 @@ func (a *Auth) Logout(w http.ResponseWriter, r *http.Request) { ctx := r.Context() defer mon.Task()(&ctx)(nil) - a.cookieAuth.RemoveTokenCookie(w) - w.Header().Set("Content-Type", "application/json") + + token, err := a.cookieAuth.GetToken(r) + if err != nil { + a.serveJSONError(w, err) + return + } + + err = a.service.DeleteSessionByToken(ctx, token) + if err != nil { + a.serveJSONError(w, err) + return + } + + a.cookieAuth.RemoveTokenCookie(w) } // replaceURLCharacters replaces slash, colon, and dot characters in a string with a hyphen. @@ -401,27 +420,27 @@ func (a *Auth) GetAccount(w http.ResponseWriter, r *http.Request) { MFARecoveryCodeCount int `json:"mfaRecoveryCodeCount"` } - auth, err := console.GetAuth(ctx) + consoleUser, err := console.GetUser(ctx) if err != nil { a.serveJSONError(w, err) return } - user.ShortName = auth.User.ShortName - user.FullName = auth.User.FullName - user.Email = auth.User.Email - user.ID = auth.User.ID - user.PartnerID = auth.User.PartnerID - user.UserAgent = auth.User.UserAgent - user.ProjectLimit = auth.User.ProjectLimit - user.IsProfessional = auth.User.IsProfessional - user.CompanyName = auth.User.CompanyName - user.Position = auth.User.Position - user.EmployeeCount = auth.User.EmployeeCount - user.HaveSalesContact = auth.User.HaveSalesContact - user.PaidTier = auth.User.PaidTier - user.MFAEnabled = auth.User.MFAEnabled - user.MFARecoveryCodeCount = len(auth.User.MFARecoveryCodes) + user.ShortName = consoleUser.ShortName + user.FullName = consoleUser.FullName + user.Email = consoleUser.Email + user.ID = consoleUser.ID + user.PartnerID = consoleUser.PartnerID + user.UserAgent = consoleUser.UserAgent + user.ProjectLimit = consoleUser.ProjectLimit + user.IsProfessional = consoleUser.IsProfessional + user.CompanyName = consoleUser.CompanyName + user.Position = consoleUser.Position + user.EmployeeCount = consoleUser.EmployeeCount + user.HaveSalesContact = consoleUser.HaveSalesContact + user.PaidTier = consoleUser.PaidTier + user.MFAEnabled = consoleUser.MFAEnabled + user.MFARecoveryCodeCount = len(consoleUser.MFARecoveryCodes) w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(&user) diff --git a/satellite/console/consoleweb/consoleapi/auth_test.go b/satellite/console/consoleweb/consoleapi/auth_test.go index 91564d947..4072d4521 100644 --- a/satellite/console/consoleweb/consoleapi/auth_test.go +++ b/satellite/console/consoleweb/consoleapi/auth_test.go @@ -374,7 +374,7 @@ func TestMFAEndpoints(t *testing.T) { req.AddCookie(&http.Cookie{ Name: "_tokenKey", Path: "/", - Value: token, + Value: token.String(), Expires: time.Now().AddDate(0, 0, 1), }) @@ -599,21 +599,19 @@ func TestResetPasswordEndpoint(t *testing.T) { token = getNewResetToken() // Enable MFA. - getNewAuthContext := func() context.Context { - authCtx, err := sat.AuthenticatedContext(ctx, user.ID) - require.NoError(t, err) - return authCtx - } - authCtx := getNewAuthContext() - - key, err := service.ResetMFASecretKey(authCtx) + userCtx, err := sat.UserContext(ctx, user.ID) + require.NoError(t, err) + + key, err := service.ResetMFASecretKey(userCtx) + require.NoError(t, err) + + userCtx, err = sat.UserContext(ctx, user.ID) require.NoError(t, err) - authCtx = getNewAuthContext() passcode, err := console.NewMFAPasscode(key, token.CreatedAt) require.NoError(t, err) - err = service.EnableUserMFA(authCtx, passcode, token.CreatedAt) + err = service.EnableUserMFA(userCtx, passcode, token.CreatedAt) require.NoError(t, err) status, mfaError = tryPasswordReset(token.Secret.String(), newPass, "", "") diff --git a/satellite/console/consoleweb/consoleapi/buckets_test.go b/satellite/console/consoleweb/consoleapi/buckets_test.go index 5639845bc..c4ce492e2 100644 --- a/satellite/console/consoleweb/consoleapi/buckets_test.go +++ b/satellite/console/consoleweb/consoleapi/buckets_test.go @@ -76,7 +76,7 @@ func Test_AllBucketNames(t *testing.T) { cookie := http.Cookie{ Name: "_tokenKey", Path: "/", - Value: token, + Value: token.String(), Expires: expire, } diff --git a/satellite/console/consoleweb/consoleapi/usagelimits_test.go b/satellite/console/consoleweb/consoleapi/usagelimits_test.go index 55f51d75e..72e2d097a 100644 --- a/satellite/console/consoleweb/consoleapi/usagelimits_test.go +++ b/satellite/console/consoleweb/consoleapi/usagelimits_test.go @@ -89,7 +89,7 @@ func Test_TotalUsageLimits(t *testing.T) { cookie := http.Cookie{ Name: "_tokenKey", Path: "/", - Value: token, + Value: token.String(), Expires: expire, } @@ -203,7 +203,7 @@ func Test_DailyUsage(t *testing.T) { cookie := http.Cookie{ Name: "_tokenKey", Path: "/", - Value: token, + Value: token.String(), Expires: expire, } diff --git a/satellite/console/consoleweb/consoleql/mutation_test.go b/satellite/console/consoleweb/consoleql/mutation_test.go index 23ec4e319..885caf19d 100644 --- a/satellite/console/consoleweb/consoleql/mutation_test.go +++ b/satellite/console/consoleweb/consoleql/mutation_test.go @@ -113,6 +113,7 @@ func TestGraphqlMutation(t *testing.T) { console.Config{ PasswordCost: console.TestPasswordCost, DefaultProjectLimit: 5, + SessionDuration: time.Hour, }, ) require.NoError(t, err) @@ -164,15 +165,13 @@ func TestGraphqlMutation(t *testing.T) { token, err := service.Token(ctx, console.AuthUser{Email: createUser.Email, Password: createUser.Password}) require.NoError(t, err) - sauth, err := service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token))) + userCtx, err := service.TokenAuth(ctx, token, time.Now()) require.NoError(t, err) - authCtx := console.WithAuth(ctx, sauth) - testQuery := func(t *testing.T, query string) (interface{}, error) { result := graphql.Do(graphql.Params{ Schema: schema, - Context: authCtx, + Context: userCtx, RequestString: query, RootObject: rootObject, }) @@ -190,11 +189,9 @@ func TestGraphqlMutation(t *testing.T) { token, err = service.Token(ctx, console.AuthUser{Email: rootUser.Email, Password: createUser.Password}) require.NoError(t, err) - sauth, err = service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token))) + userCtx, err = service.TokenAuth(ctx, token, time.Now()) require.NoError(t, err) - authCtx = console.WithAuth(ctx, sauth) - var projectIDField string t.Run("Create project mutation", func(t *testing.T) { projectInfo := console.ProjectInfo{ @@ -223,14 +220,14 @@ func TestGraphqlMutation(t *testing.T) { projectID, err := uuid.FromString(projectIDField) require.NoError(t, err) - project, err := service.GetProject(authCtx, projectID) + project, err := service.GetProject(userCtx, projectID) require.NoError(t, err) require.Equal(t, rootUser.PartnerID, project.PartnerID) regTokenUser1, err := service.CreateRegToken(ctx, 1) require.NoError(t, err) - user1, err := service.CreateUser(authCtx, console.CreateUser{ + user1, err := service.CreateUser(userCtx, console.CreateUser{ FullName: "User1", Email: "u1@mail.test", Password: "123a123", @@ -254,7 +251,7 @@ func TestGraphqlMutation(t *testing.T) { regTokenUser2, err := service.CreateRegToken(ctx, 1) require.NoError(t, err) - user2, err := service.CreateUser(authCtx, console.CreateUser{ + user2, err := service.CreateUser(userCtx, console.CreateUser{ FullName: "User1", Email: "u2@mail.test", Password: "123a123", @@ -353,7 +350,7 @@ func TestGraphqlMutation(t *testing.T) { id, err := uuid.FromString(keyID) require.NoError(t, err) - info, err := service.GetAPIKeyInfo(authCtx, id) + info, err := service.GetAPIKeyInfo(userCtx, id) require.NoError(t, err) query := fmt.Sprintf( diff --git a/satellite/console/consoleweb/consoleql/project.go b/satellite/console/consoleweb/consoleql/project.go index e836e2b9a..4fd47f089 100644 --- a/satellite/console/consoleweb/consoleql/project.go +++ b/satellite/console/consoleweb/consoleql/project.go @@ -135,7 +135,7 @@ func graphqlProject(service *console.Service, types *TypeCreator) *graphql.Objec Resolve: func(p graphql.ResolveParams) (interface{}, error) { project, _ := p.Source.(*console.Project) - _, err := console.GetAuth(p.Context) + _, err := console.GetUser(p.Context) if err != nil { return nil, err } @@ -183,11 +183,6 @@ func graphqlProject(service *console.Service, types *TypeCreator) *graphql.Objec Resolve: func(p graphql.ResolveParams) (interface{}, error) { project, _ := p.Source.(*console.Project) - _, err := console.GetAuth(p.Context) - if err != nil { - return nil, err - } - cursor := cursorArgsToAPIKeysCursor(p.Args[CursorArg].(map[string]interface{})) page, err := service.GetAPIKeys(p.Context, project.ID, cursor) if err != nil { diff --git a/satellite/console/consoleweb/consoleql/query_test.go b/satellite/console/consoleweb/consoleql/query_test.go index 1af933bb1..e879fdd77 100644 --- a/satellite/console/consoleweb/consoleql/query_test.go +++ b/satellite/console/consoleweb/consoleql/query_test.go @@ -97,6 +97,7 @@ func TestGraphqlQuery(t *testing.T) { console.Config{ PasswordCost: console.TestPasswordCost, DefaultProjectLimit: 5, + SessionDuration: time.Hour, }, ) require.NoError(t, err) @@ -158,15 +159,13 @@ func TestGraphqlQuery(t *testing.T) { token, err := service.Token(ctx, console.AuthUser{Email: createUser.Email, Password: createUser.Password}) require.NoError(t, err) - sauth, err := service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token))) + userCtx, err := service.TokenAuth(ctx, token, time.Now()) require.NoError(t, err) - authCtx := console.WithAuth(ctx, sauth) - testQuery := func(t *testing.T, query string) interface{} { result := graphql.Do(graphql.Params{ Schema: schema, - Context: authCtx, + Context: userCtx, RequestString: query, RootObject: rootObject, }) @@ -179,7 +178,7 @@ func TestGraphqlQuery(t *testing.T) { return result.Data } - createdProject, err := service.CreateProject(authCtx, console.ProjectInfo{ + createdProject, err := service.CreateProject(userCtx, console.ProjectInfo{ Name: "TestProject", }) require.NoError(t, err) @@ -210,7 +209,7 @@ func TestGraphqlQuery(t *testing.T) { regTokenUser1, err := service.CreateRegToken(ctx, 2) require.NoError(t, err) - user1, err := service.CreateUser(authCtx, console.CreateUser{ + user1, err := service.CreateUser(userCtx, console.CreateUser{ FullName: "Mickey Last", ShortName: "Last", Password: "123a123", @@ -233,7 +232,7 @@ func TestGraphqlQuery(t *testing.T) { regTokenUser2, err := service.CreateRegToken(ctx, 2) require.NoError(t, err) - user2, err := service.CreateUser(authCtx, console.CreateUser{ + user2, err := service.CreateUser(userCtx, console.CreateUser{ FullName: "Dubas Name", ShortName: "Name", Email: "muu2@mail.test", @@ -253,7 +252,7 @@ func TestGraphqlQuery(t *testing.T) { user2.Email = "muu2@mail.test" }) - users, err := service.AddProjectMembers(authCtx, createdProject.ID, []string{ + users, err := service.AddProjectMembers(userCtx, createdProject.ID, []string{ user1.Email, user2.Email, }) @@ -316,10 +315,10 @@ func TestGraphqlQuery(t *testing.T) { assert.True(t, foundU2) }) - keyInfo1, _, err := service.CreateAPIKey(authCtx, createdProject.ID, "key1") + keyInfo1, _, err := service.CreateAPIKey(userCtx, createdProject.ID, "key1") require.NoError(t, err) - keyInfo2, _, err := service.CreateAPIKey(authCtx, createdProject.ID, "key2") + keyInfo2, _, err := service.CreateAPIKey(userCtx, createdProject.ID, "key2") require.NoError(t, err) t.Run("Project query api keys", func(t *testing.T) { @@ -372,7 +371,7 @@ func TestGraphqlQuery(t *testing.T) { assert.True(t, foundKey2) }) - project2, err := service.CreateProject(authCtx, console.ProjectInfo{ + project2, err := service.CreateProject(userCtx, console.ProjectInfo{ Name: "Project2", Description: "Test desc", }) diff --git a/satellite/console/consoleweb/consolewebauth/auth.go b/satellite/console/consoleweb/consolewebauth/auth.go index f9281aa88..7082546da 100644 --- a/satellite/console/consoleweb/consolewebauth/auth.go +++ b/satellite/console/consoleweb/consolewebauth/auth.go @@ -6,6 +6,8 @@ package consolewebauth import ( "net/http" "time" + + "storj.io/storj/satellite/console/consoleauth" ) // CookieSettings variable cookie settings. @@ -27,20 +29,25 @@ func NewCookieAuth(settings CookieSettings) *CookieAuth { } // GetToken retrieves token from request. -func (auth *CookieAuth) GetToken(r *http.Request) (string, error) { +func (auth *CookieAuth) GetToken(r *http.Request) (consoleauth.Token, error) { cookie, err := r.Cookie(auth.settings.Name) if err != nil { - return "", err + return consoleauth.Token{}, err } - return cookie.Value, nil + token, err := consoleauth.FromBase64URLString(cookie.Value) + if err != nil { + return consoleauth.Token{}, err + } + + return token, nil } // SetTokenCookie sets parametrized token cookie that is not accessible from js. -func (auth *CookieAuth) SetTokenCookie(w http.ResponseWriter, token string) { +func (auth *CookieAuth) SetTokenCookie(w http.ResponseWriter, token consoleauth.Token) { http.SetCookie(w, &http.Cookie{ Name: auth.settings.Name, - Value: token, + Value: token.String(), Path: auth.settings.Path, // TODO: get expiration from token Expires: time.Now().Add(time.Hour * 24), @@ -60,3 +67,8 @@ func (auth *CookieAuth) RemoveTokenCookie(w http.ResponseWriter) { SameSite: http.SameSiteStrictMode, }) } + +// GetTokenCookieName returns the name of the cookie storing the session token. +func (auth *CookieAuth) GetTokenCookieName() string { + return auth.settings.Name +} diff --git a/satellite/console/consoleweb/endpoints_test.go b/satellite/console/consoleweb/endpoints_test.go index 8b776846f..5e201fc2c 100644 --- a/satellite/console/consoleweb/endpoints_test.go +++ b/satellite/console/consoleweb/endpoints_test.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "net/http" "net/http/cookiejar" + "net/url" "strings" "testing" "time" @@ -69,10 +70,13 @@ func TestAuth(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) } + var oldCookies []*http.Cookie + { // Get_AccountInfo resp, body := test.request(http.MethodGet, "/auth/account", nil) require.Equal(test.t, http.StatusOK, resp.StatusCode) require.Contains(test.t, body, "fullName") + oldCookies = resp.Cookies() var userIdentifier struct{ ID string } require.NoError(test.t, json.Unmarshal([]byte(body), &userIdentifier)) @@ -95,6 +99,16 @@ func TestAuth(t *testing.T) { require.Equal(test.t, http.StatusUnauthorized, resp.StatusCode) } + { // Get_AccountInfo shouldn't succeed with reused session cookie + satURL, err := url.Parse(test.url("")) + require.NoError(t, err) + test.client.Jar.SetCookies(satURL, oldCookies) + + resp, body := test.request(http.MethodGet, "/auth/account", nil) + require.Contains(test.t, body, "error") + require.Equal(test.t, http.StatusUnauthorized, resp.StatusCode) + } + { // repeated login attempts should end in too many requests hitRateLimiter := false for i := 0; i < 30; i++ { @@ -821,7 +835,7 @@ func (test *test) defaultUser() registeredUser { func (test *test) defaultProjectID() string { return test.planet.Uplinks[0].Projects[0].ID.String() } -func (test *test) login(email, password string) { +func (test *test) login(email, password string) Response { resp, body := test.request( http.MethodPost, "/auth/token", test.toJSON(map[string]string{ @@ -835,6 +849,8 @@ func (test *test) login(email, password string) { require.NoError(test.t, json.Unmarshal([]byte(body), &rawToken)) require.Equal(test.t, http.StatusOK, resp.StatusCode) require.Equal(test.t, rawToken, cookie.Value) + + return resp } func (test *test) registerUser(email, password string) registeredUser { diff --git a/satellite/console/consoleweb/server.go b/satellite/console/consoleweb/server.go index 214c4a7f0..650657cc4 100644 --- a/satellite/console/consoleweb/server.go +++ b/satellite/console/consoleweb/server.go @@ -35,7 +35,6 @@ import ( "storj.io/storj/private/web" "storj.io/storj/satellite/analytics" "storj.io/storj/satellite/console" - "storj.io/storj/satellite/console/consoleauth" "storj.io/storj/satellite/console/consoleweb/consoleapi" "storj.io/storj/satellite/console/consoleweb/consoleql" "storj.io/storj/satellite/console/consoleweb/consolewebauth" @@ -180,6 +179,62 @@ type templates struct { usageReport *template.Template } +// apiAuth exposes methods to control authentication process for each generated API endpoint. +type apiAuth struct { + server *Server +} + +// IsAuthenticated checks if request is performed with all needed authorization credentials. +func (a *apiAuth) IsAuthenticated(ctx context.Context, r *http.Request, isCookieAuth, isKeyAuth bool) (_ context.Context, err error) { + if isCookieAuth && isKeyAuth { + ctx, err = a.cookieAuth(ctx, r) + if err != nil { + ctx, err = a.keyAuth(ctx, r) + if err != nil { + return nil, err + } + } + } else if isCookieAuth { + ctx, err = a.cookieAuth(ctx, r) + if err != nil { + return nil, err + } + } else if isKeyAuth { + ctx, err = a.keyAuth(ctx, r) + if err != nil { + return nil, err + } + } + + return ctx, nil +} + +// cookieAuth returns an authenticated context by session cookie. +func (a *apiAuth) cookieAuth(ctx context.Context, r *http.Request) (context.Context, error) { + token, err := a.server.cookieAuth.GetToken(r) + if err != nil { + return nil, err + } + + return a.server.service.TokenAuth(ctx, token, time.Now()) +} + +// cookieAuth returns an authenticated context by api key. +func (a *apiAuth) keyAuth(ctx context.Context, r *http.Request) (context.Context, error) { + authToken := r.Header.Get("Authorization") + split := strings.Split(authToken, "Bearer ") + if len(split) != 2 { + return ctx, errs.New("authorization key format is incorrect. Should be 'Bearer '") + } + + return a.server.service.KeyAuth(ctx, split[1], time.Now()) +} + +// RemoveAuthCookie indicates to the client that the authentication cookie should be removed. +func (a *apiAuth) RemoveAuthCookie(w http.ResponseWriter) { + a.server.cookieAuth.RemoveTokenCookie(w) +} + // NewServer creates new instance of console server. func NewServer(logger *zap.Logger, config Config, service *console.Service, oidcService *oidc.Service, mailService *mailservice.Service, partners *rewards.PartnersService, analytics *analytics.Service, listener net.Listener, stripePublicKey string, pricing paymentsconfig.PricingValues, nodeURL storj.NodeURL) *Server { server := Server{ @@ -219,9 +274,9 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc router := mux.NewRouter() if server.config.GeneratedAPIEnabled { - consoleapi.NewProjectManagement(logger, server.service, router, server.service) - consoleapi.NewAPIKeyManagement(logger, server.service, router, server.service) - consoleapi.NewUserManagement(logger, server.service, router, server.service) + consoleapi.NewProjectManagement(logger, server.service, router, &apiAuth{&server}) + consoleapi.NewAPIKeyManagement(logger, server.service, router, &apiAuth{&server}) + consoleapi.NewUserManagement(logger, server.service, router, &apiAuth{&server}) } router.HandleFunc("/registrationToken/", server.createRegistrationTokenHandler) @@ -486,17 +541,15 @@ func (server *Server) withAuth(handler http.Handler) http.Handler { ctxWithAuth := func(ctx context.Context) context.Context { token, err := server.cookieAuth.GetToken(r) if err != nil { - return console.WithAuthFailure(ctx, err) + return console.WithUserFailure(ctx, console.ErrUnauthorized.Wrap(err)) } - ctx = consoleauth.WithAPIKey(ctx, []byte(token)) - - auth, err := server.service.Authorize(ctx) + newCtx, err := server.service.TokenAuth(ctx, token, time.Now()) if err != nil { - return console.WithAuthFailure(ctx, err) + return console.WithUserFailure(ctx, err) } - return console.WithAuth(ctx, auth) + return newCtx } ctx = ctxWithAuth(r.Context()) @@ -524,14 +577,12 @@ func (server *Server) bucketUsageReportHandler(w http.ResponseWriter, r *http.Re return } - auth, err := server.service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token))) + ctx, err = server.service.TokenAuth(ctx, token, time.Now()) if err != nil { server.serveError(w, http.StatusUnauthorized) return } - ctx = console.WithAuth(ctx, auth) - // parse query params projectID, err := uuid.FromString(r.URL.Query().Get("projectID")) if err != nil { @@ -625,7 +676,7 @@ func (server *Server) accountActivationHandler(w http.ResponseWriter, r *http.Re defer mon.Task()(&ctx)(nil) activationToken := r.URL.Query().Get("token") - token, err := server.service.ActivateAccount(ctx, activationToken) + user, err := server.service.ActivateAccount(ctx, activationToken) if err != nil { server.log.Error("activation: failed to activate account", zap.String("token", activationToken), @@ -645,6 +696,18 @@ func (server *Server) accountActivationHandler(w http.ResponseWriter, r *http.Re return } + ip, err := web.GetRequestIP(r) + if err != nil { + server.serveError(w, http.StatusInternalServerError) + return + } + + token, err := server.service.GenerateSessionToken(ctx, user.ID, user.Email, ip, r.UserAgent()) + if err != nil { + server.serveError(w, http.StatusInternalServerError) + return + } + server.cookieAuth.SetTokenCookie(w, token) http.Redirect(w, r, server.config.ExternalAddress, http.StatusTemporaryRedirect) @@ -897,10 +960,10 @@ func (server *Server) parseTemplates() (_ *templates, err error) { // NewUserIDRateLimiter constructs a RateLimiter that limits based on user ID. func NewUserIDRateLimiter(config web.RateLimiterConfig) *web.RateLimiter { return web.NewRateLimiter(config, func(r *http.Request) (string, error) { - auth, err := console.GetAuth(r.Context()) + user, err := console.GetUser(r.Context()) if err != nil { return "", err } - return auth.User.ID.String(), nil + return user.ID.String(), nil }) } diff --git a/satellite/console/consoleweb/server_test.go b/satellite/console/consoleweb/server_test.go index 9431be77f..7197e0af6 100644 --- a/satellite/console/consoleweb/server_test.go +++ b/satellite/console/consoleweb/server_test.go @@ -124,17 +124,19 @@ func TestUserIDRateLimiter(t *testing.T) { token, err := sat.API.Console.Service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName}) require.NoError(t, err) + tokenStr := token.String() + if userNum == 1 { - firstToken = token + firstToken = tokenStr } // Expect burst number of successes. for burstNum := 0; burstNum < sat.Config.Console.RateLimit.Burst; burstNum++ { - require.NotEqual(t, http.StatusTooManyRequests, applyCouponStatus(token)) + require.NotEqual(t, http.StatusTooManyRequests, applyCouponStatus(tokenStr)) } // Expect failure. - require.Equal(t, http.StatusTooManyRequests, applyCouponStatus(token)) + require.Equal(t, http.StatusTooManyRequests, applyCouponStatus(tokenStr)) }) } diff --git a/satellite/console/mfa.go b/satellite/console/mfa.go index 18b9ac778..b51b1ebd5 100644 --- a/satellite/console/mfa.go +++ b/satellite/console/mfa.go @@ -86,12 +86,12 @@ func NewMFASecretKey() (string, error) { func (s *Service) EnableUserMFA(ctx context.Context, passcode string, t time.Time) (err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "enable MFA") + user, err := s.getUserAndAuditLog(ctx, "enable MFA") if err != nil { return Error.Wrap(err) } - valid, err := ValidateMFAPasscode(passcode, auth.User.MFASecretKey, t) + valid, err := ValidateMFAPasscode(passcode, user.MFASecretKey, t) if err != nil { return ErrValidation.Wrap(ErrMFAPasscode.Wrap(err)) } @@ -99,8 +99,8 @@ func (s *Service) EnableUserMFA(ctx context.Context, passcode string, t time.Tim return ErrValidation.Wrap(ErrMFAPasscode.New(mfaPasscodeInvalidErrMsg)) } - auth.User.MFAEnabled = true - err = s.store.Users().Update(ctx, &auth.User) + user.MFAEnabled = true + err = s.store.Users().Update(ctx, user) if err != nil { return Error.Wrap(err) } @@ -112,13 +112,11 @@ func (s *Service) EnableUserMFA(ctx context.Context, passcode string, t time.Tim func (s *Service) DisableUserMFA(ctx context.Context, passcode string, t time.Time, recoveryCode string) (err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "disable MFA") + user, err := s.getUserAndAuditLog(ctx, "disable MFA") if err != nil { return Error.Wrap(err) } - user := &auth.User - if !user.MFAEnabled { return nil } @@ -139,7 +137,7 @@ func (s *Service) DisableUserMFA(ctx context.Context, passcode string, t time.Ti return ErrUnauthorized.Wrap(ErrMFARecoveryCode.New(mfaRecoveryInvalidErrMsg)) } } else if passcode != "" { - valid, err := ValidateMFAPasscode(passcode, auth.User.MFASecretKey, t) + valid, err := ValidateMFAPasscode(passcode, user.MFASecretKey, t) if err != nil { return ErrValidation.Wrap(ErrMFAPasscode.Wrap(err)) } @@ -150,10 +148,10 @@ func (s *Service) DisableUserMFA(ctx context.Context, passcode string, t time.Ti return ErrMFAMissing.New(mfaRequiredErrMsg) } - auth.User.MFAEnabled = false - auth.User.MFASecretKey = "" - auth.User.MFARecoveryCodes = nil - err = s.store.Users().Update(ctx, &auth.User) + user.MFAEnabled = false + user.MFASecretKey = "" + user.MFARecoveryCodes = nil + err = s.store.Users().Update(ctx, user) if err != nil { return Error.Wrap(err) } @@ -185,7 +183,7 @@ func NewMFARecoveryCode() (string, error) { func (s *Service) ResetMFASecretKey(ctx context.Context) (key string, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "reset MFA secret key") + user, err := s.getUserAndAuditLog(ctx, "reset MFA secret key") if err != nil { return "", Error.Wrap(err) } @@ -195,8 +193,8 @@ func (s *Service) ResetMFASecretKey(ctx context.Context) (key string, err error) return "", Error.Wrap(err) } - auth.User.MFASecretKey = key - err = s.store.Users().Update(ctx, &auth.User) + user.MFASecretKey = key + err = s.store.Users().Update(ctx, user) if err != nil { return "", Error.Wrap(err) } @@ -208,12 +206,12 @@ func (s *Service) ResetMFASecretKey(ctx context.Context) (key string, err error) func (s *Service) ResetMFARecoveryCodes(ctx context.Context) (codes []string, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "reset MFA recovery codes") + user, err := s.getUserAndAuditLog(ctx, "reset MFA recovery codes") if err != nil { return nil, Error.Wrap(err) } - if !auth.User.MFAEnabled { + if !user.MFAEnabled { return nil, ErrUnauthorized.New(mfaRecoveryGenerationErrMsg) } @@ -225,9 +223,9 @@ func (s *Service) ResetMFARecoveryCodes(ctx context.Context) (codes []string, er } codes[i] = code } - auth.User.MFARecoveryCodes = codes + user.MFARecoveryCodes = codes - err = s.store.Users().Update(ctx, &auth.User) + err = s.store.Users().Update(ctx, user) if err != nil { return nil, Error.Wrap(err) } diff --git a/satellite/console/request.go b/satellite/console/request.go index a2cac0b0f..c9ce0d510 100644 --- a/satellite/console/request.go +++ b/satellite/console/request.go @@ -8,6 +8,9 @@ import ( "net/http" ) +// requestKey is context key for Requests. +const requestKey key = 1 + // WithRequest creates new context with *http.Request. func WithRequest(ctx context.Context, req *http.Request) context.Context { return context.WithValue(ctx, requestKey, req) diff --git a/satellite/console/service.go b/satellite/console/service.go index 89b996473..77124ace2 100644 --- a/satellite/console/service.go +++ b/satellite/console/service.go @@ -5,14 +5,12 @@ package console import ( "context" - "crypto/subtle" "fmt" "math" "math/big" "net/http" "net/mail" "sort" - "strings" "time" "github.com/spacemonkeygo/monkit/v3" @@ -71,6 +69,9 @@ var ( // Error describes internal console error. Error = errs.Class("console service") + // ErrUnauthorized is error class for authorization related errors. + ErrUnauthorized = errs.Class("unauthorized") + // ErrNoMembership is error type of not belonging to a specific project. ErrNoMembership = errs.Class("no membership") @@ -151,6 +152,7 @@ type Config struct { AsOfSystemTimeDuration time.Duration `help:"default duration for AS OF SYSTEM TIME" devDefault:"-5m" releaseDefault:"-5m" testDefault:"0"` LoginAttemptsWithoutPenalty int `help:"number of times user can try to login without penalty" default:"3"` FailedLoginPenalty float64 `help:"incremental duration of penalty for failed login attempts in minutes" default:"2.0"` + SessionDuration time.Duration `help:"duration a session is valid for" default:"168h"` UsageLimits UsageLimitsConfig Recaptcha RecaptchaConfig Hcaptcha HcaptchaConfig @@ -237,8 +239,8 @@ func (s *Service) auditLog(ctx context.Context, operation string, userID *uuid.U s.auditLogger.Info("console activity", fields...) } -func (s *Service) getAuthAndAuditLog(ctx context.Context, operation string, extra ...zap.Field) (Authorization, error) { - auth, err := GetAuth(ctx) +func (s *Service) getUserAndAuditLog(ctx context.Context, operation string, extra ...zap.Field) (*User, error) { + user, err := GetUser(ctx) if err != nil { sourceIP, forwardedForIP := getRequestingIP(ctx) s.auditLogger.Info("console activity unauthorized", @@ -249,10 +251,10 @@ func (s *Service) getAuthAndAuditLog(ctx context.Context, operation string, extr zap.String("source-ip", sourceIP), zap.String("forwarded-for-ip", forwardedForIP), ), extra...)...) - return Authorization{}, err + return nil, err } - s.auditLog(ctx, operation, &auth.User.ID, auth.User.Email, extra...) - return auth, nil + s.auditLog(ctx, operation, &user.ID, user.Email, extra...) + return user, nil } // Payments separates all payment related functionality. @@ -264,45 +266,45 @@ func (s *Service) Payments() Payments { func (payment Payments) SetupAccount(ctx context.Context) (_ payments.CouponType, err error) { defer mon.Task()(&ctx)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "setup payment account") + user, err := payment.service.getUserAndAuditLog(ctx, "setup payment account") if err != nil { return payments.NoCoupon, Error.Wrap(err) } - return payment.service.accounts.Setup(ctx, auth.User.ID, auth.User.Email, auth.User.SignupPromoCode) + return payment.service.accounts.Setup(ctx, user.ID, user.Email, user.SignupPromoCode) } // AccountBalance return account balance. func (payment Payments) AccountBalance(ctx context.Context) (balance payments.Balance, err error) { defer mon.Task()(&ctx)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "get account balance") + user, err := payment.service.getUserAndAuditLog(ctx, "get account balance") if err != nil { return payments.Balance{}, Error.Wrap(err) } - return payment.service.accounts.Balance(ctx, auth.User.ID) + return payment.service.accounts.Balance(ctx, user.ID) } // AddCreditCard is used to save new credit card and attach it to payment account. func (payment Payments) AddCreditCard(ctx context.Context, creditCardToken string) (err error) { defer mon.Task()(&ctx, creditCardToken)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "add credit card") + user, err := payment.service.getUserAndAuditLog(ctx, "add credit card") if err != nil { return Error.Wrap(err) } - err = payment.service.accounts.CreditCards().Add(ctx, auth.User.ID, creditCardToken) + err = payment.service.accounts.CreditCards().Add(ctx, user.ID, creditCardToken) if err != nil { return Error.Wrap(err) } - payment.service.analytics.TrackCreditCardAdded(auth.User.ID, auth.User.Email) + payment.service.analytics.TrackCreditCardAdded(user.ID, user.Email) - if !auth.User.PaidTier { + if !user.PaidTier { // put this user into the paid tier and convert projects to upgraded limits. - err = payment.service.store.Users().UpdatePaidTier(ctx, auth.User.ID, true, + err = payment.service.store.Users().UpdatePaidTier(ctx, user.ID, true, payment.service.config.UsageLimits.Bandwidth.Paid, payment.service.config.UsageLimits.Storage.Paid, payment.service.config.UsageLimits.Segment.Paid, @@ -312,7 +314,7 @@ func (payment Payments) AddCreditCard(ctx context.Context, creditCardToken strin return Error.Wrap(err) } - projects, err := payment.service.store.Projects().GetOwn(ctx, auth.User.ID) + projects, err := payment.service.store.Projects().GetOwn(ctx, user.ID) if err != nil { return Error.Wrap(err) } @@ -342,60 +344,60 @@ func (payment Payments) AddCreditCard(ctx context.Context, creditCardToken strin func (payment Payments) MakeCreditCardDefault(ctx context.Context, cardID string) (err error) { defer mon.Task()(&ctx, cardID)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "make credit card default") + user, err := payment.service.getUserAndAuditLog(ctx, "make credit card default") if err != nil { return Error.Wrap(err) } - return payment.service.accounts.CreditCards().MakeDefault(ctx, auth.User.ID, cardID) + return payment.service.accounts.CreditCards().MakeDefault(ctx, user.ID, cardID) } // ProjectsCharges returns how much money current user will be charged for each project which he owns. func (payment Payments) ProjectsCharges(ctx context.Context, since, before time.Time) (_ []payments.ProjectCharge, err error) { defer mon.Task()(&ctx)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "project charges") + user, err := payment.service.getUserAndAuditLog(ctx, "project charges") if err != nil { return nil, Error.Wrap(err) } - return payment.service.accounts.ProjectCharges(ctx, auth.User.ID, since, before) + return payment.service.accounts.ProjectCharges(ctx, user.ID, since, before) } // ListCreditCards returns a list of credit cards for a given payment account. func (payment Payments) ListCreditCards(ctx context.Context) (_ []payments.CreditCard, err error) { defer mon.Task()(&ctx)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "list credit cards") + user, err := payment.service.getUserAndAuditLog(ctx, "list credit cards") if err != nil { return nil, Error.Wrap(err) } - return payment.service.accounts.CreditCards().List(ctx, auth.User.ID) + return payment.service.accounts.CreditCards().List(ctx, user.ID) } // RemoveCreditCard is used to detach a credit card from payment account. func (payment Payments) RemoveCreditCard(ctx context.Context, cardID string) (err error) { defer mon.Task()(&ctx, cardID)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "remove credit card") + user, err := payment.service.getUserAndAuditLog(ctx, "remove credit card") if err != nil { return Error.Wrap(err) } - return payment.service.accounts.CreditCards().Remove(ctx, auth.User.ID, cardID) + return payment.service.accounts.CreditCards().Remove(ctx, user.ID, cardID) } // BillingHistory returns a list of billing history items for payment account. func (payment Payments) BillingHistory(ctx context.Context) (billingHistory []*BillingHistoryItem, err error) { defer mon.Task()(&ctx)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "get billing history") + user, err := payment.service.getUserAndAuditLog(ctx, "get billing history") if err != nil { return nil, Error.Wrap(err) } - invoices, couponUsages, err := payment.service.accounts.Invoices().ListWithDiscounts(ctx, auth.User.ID) + invoices, couponUsages, err := payment.service.accounts.Invoices().ListWithDiscounts(ctx, user.ID) if err != nil { return nil, Error.Wrap(err) } @@ -413,7 +415,7 @@ func (payment Payments) BillingHistory(ctx context.Context) (billingHistory []*B }) } - txsInfos, err := payment.service.accounts.StorjTokens().ListTransactionInfos(ctx, auth.User.ID) + txsInfos, err := payment.service.accounts.StorjTokens().ListTransactionInfos(ctx, user.ID) if err != nil { return nil, Error.Wrap(err) } @@ -432,7 +434,7 @@ func (payment Payments) BillingHistory(ctx context.Context) (billingHistory []*B }) } - charges, err := payment.service.accounts.Charges(ctx, auth.User.ID) + charges, err := payment.service.accounts.Charges(ctx, user.ID) if err != nil { return nil, Error.Wrap(err) } @@ -467,7 +469,7 @@ func (payment Payments) BillingHistory(ctx context.Context) (billingHistory []*B }) } - bonuses, err := payment.service.accounts.StorjTokens().ListDepositBonuses(ctx, auth.User.ID) + bonuses, err := payment.service.accounts.StorjTokens().ListDepositBonuses(ctx, user.ID) if err != nil { return nil, Error.Wrap(err) } @@ -497,12 +499,12 @@ func (payment Payments) BillingHistory(ctx context.Context) (billingHistory []*B func (payment Payments) TokenDeposit(ctx context.Context, amount int64) (_ *payments.Transaction, err error) { defer mon.Task()(&ctx)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "token deposit") + user, err := payment.service.getUserAndAuditLog(ctx, "token deposit") if err != nil { return nil, Error.Wrap(err) } - tx, err := payment.service.accounts.StorjTokens().Deposit(ctx, auth.User.ID, amount) + tx, err := payment.service.accounts.StorjTokens().Deposit(ctx, user.ID, amount) return tx, Error.Wrap(err) } @@ -511,12 +513,12 @@ func (payment Payments) TokenDeposit(ctx context.Context, amount int64) (_ *paym func (payment Payments) checkOutstandingInvoice(ctx context.Context) (err error) { defer mon.Task()(&ctx)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "get outstanding invoices") + user, err := payment.service.getUserAndAuditLog(ctx, "get outstanding invoices") if err != nil { return err } - invoices, err := payment.service.accounts.Invoices().List(ctx, auth.User.ID) + invoices, err := payment.service.accounts.Invoices().List(ctx, user.ID) if err != nil { return err } @@ -528,7 +530,7 @@ func (payment Payments) checkOutstandingInvoice(ctx context.Context) (err error) } } - hasItems, err := payment.service.accounts.Invoices().CheckPendingItems(ctx, auth.User.ID) + hasItems, err := payment.service.accounts.Invoices().CheckPendingItems(ctx, user.ID) if err != nil { return err } @@ -543,7 +545,7 @@ func (payment Payments) checkOutstandingInvoice(ctx context.Context) (err error) func (payment Payments) checkProjectInvoicingStatus(ctx context.Context, projectID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - _, err = payment.service.getAuthAndAuditLog(ctx, "project invoicing status") + _, err = payment.service.getUserAndAuditLog(ctx, "project invoicing status") if err != nil { return Error.Wrap(err) } @@ -555,7 +557,7 @@ func (payment Payments) checkProjectInvoicingStatus(ctx context.Context, project func (payment Payments) checkProjectUsageStatus(ctx context.Context, projectID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - _, err = payment.service.getAuthAndAuditLog(ctx, "project usage status") + _, err = payment.service.getUserAndAuditLog(ctx, "project usage status") if err != nil { return Error.Wrap(err) } @@ -568,12 +570,12 @@ func (payment Payments) checkProjectUsageStatus(ctx context.Context, projectID u func (payment Payments) ApplyCouponCode(ctx context.Context, couponCode string) (coupon *payments.Coupon, err error) { defer mon.Task()(&ctx)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "apply coupon code") + user, err := payment.service.getUserAndAuditLog(ctx, "apply coupon code") if err != nil { return nil, Error.Wrap(err) } - coupon, err = payment.service.accounts.Coupons().ApplyCouponCode(ctx, auth.User.ID, couponCode) + coupon, err = payment.service.accounts.Coupons().ApplyCouponCode(ctx, user.ID, couponCode) if err != nil { return nil, Error.Wrap(err) } @@ -585,12 +587,12 @@ func (payment Payments) ApplyCouponCode(ctx context.Context, couponCode string) func (payment Payments) GetCoupon(ctx context.Context) (coupon *payments.Coupon, err error) { defer mon.Task()(&ctx)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "get coupon") + user, err := payment.service.getUserAndAuditLog(ctx, "get coupon") if err != nil { return nil, Error.Wrap(err) } - coupon, err = payment.service.accounts.Coupons().GetByUserID(ctx, auth.User.ID) + coupon, err = payment.service.accounts.Coupons().GetByUserID(ctx, user.ID) if err != nil { return nil, Error.Wrap(err) } @@ -765,53 +767,79 @@ func (s *Service) GeneratePasswordRecoveryToken(ctx context.Context, id uuid.UUI return resetPasswordToken.Secret.String(), nil } +// GenerateSessionToken creates a new session and returns the string representation of its token. +func (s *Service) GenerateSessionToken(ctx context.Context, userID uuid.UUID, email, ip, userAgent string) (consoleauth.Token, error) { + sessionID, err := uuid.New() + if err != nil { + return consoleauth.Token{}, Error.Wrap(err) + } + + _, err = s.store.WebappSessions().Create(ctx, sessionID, userID, ip, userAgent, time.Now().Add(s.config.SessionDuration)) + if err != nil { + return consoleauth.Token{}, err + } + + token := consoleauth.Token{Payload: sessionID.Bytes()} + + signature, err := s.tokens.SignToken(token) + if err != nil { + return consoleauth.Token{}, err + } + token.Signature = signature + + s.auditLog(ctx, "login", &userID, email) + + s.analytics.TrackSignedIn(userID, email) + + return token, nil +} + // ActivateAccount - is a method for activating user account after registration. -func (s *Service) ActivateAccount(ctx context.Context, activationToken string) (token string, err error) { +func (s *Service) ActivateAccount(ctx context.Context, activationToken string) (user *User, err error) { defer mon.Task()(&ctx)(&err) parsedActivationToken, err := consoleauth.FromBase64URLString(activationToken) if err != nil { - return "", Error.Wrap(err) + return nil, Error.Wrap(err) } - claims, err := s.authenticate(ctx, parsedActivationToken) + valid, err := s.tokens.ValidateToken(parsedActivationToken) if err != nil { - return "", err + return nil, Error.Wrap(err) + } + if !valid { + return nil, Error.New("incorrect signature") + } + + claims, err := consoleauth.FromJSON(parsedActivationToken.Payload) + if err != nil { + return nil, Error.Wrap(err) } if time.Now().After(claims.Expiration) { - return "", ErrTokenExpiration.New(activationTokenExpiredErrMsg) + return nil, ErrTokenExpiration.New(activationTokenExpiredErrMsg) } _, err = s.store.Users().GetByEmail(ctx, claims.Email) if err == nil { - return "", ErrEmailUsed.New(emailUsedErrMsg) + return nil, ErrEmailUsed.New(emailUsedErrMsg) } - user, err := s.store.Users().Get(ctx, claims.ID) + user, err = s.store.Users().Get(ctx, claims.ID) if err != nil { - return "", Error.Wrap(err) + return nil, Error.Wrap(err) } user.Status = Active err = s.store.Users().Update(ctx, user) if err != nil { - return "", Error.Wrap(err) + return nil, Error.Wrap(err) } s.auditLog(ctx, "activate account", &user.ID, user.Email) s.analytics.TrackAccountVerified(user.ID, user.Email) - // now that the account is activated, create a token to be stored in a cookie to log the user in. - token, err = s.tokens.CreateToken(ctx, user.ID, "") - if err != nil { - return "", err - } - s.auditLog(ctx, "login", &user.ID, user.Email) - - s.analytics.TrackSignedIn(user.ID, user.Email) - - return token, nil + return user, nil } // ResetPassword - is a method for resetting user password. @@ -901,8 +929,8 @@ func (s *Service) RevokeResetPasswordToken(ctx context.Context, resetPasswordTok return s.store.ResetPasswordTokens().Delete(ctx, secret) } -// Token authenticates User by credentials and returns auth token. -func (s *Service) Token(ctx context.Context, request AuthUser) (token string, err error) { +// Token authenticates User by credentials and returns session token. +func (s *Service) Token(ctx context.Context, request AuthUser) (token consoleauth.Token, err error) { defer mon.Task()(&ctx)(&err) mon.Counter("login_attempt").Inc(1) //mon:locked @@ -914,14 +942,14 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token string, er } else { mon.Counter("login_email_invalid").Inc(1) //mon:locked } - return "", ErrLoginCredentials.New(credentialsErrMsg) + return consoleauth.Token{}, ErrLoginCredentials.New(credentialsErrMsg) } now := time.Now() if user.LoginLockoutExpiration.After(now) { mon.Counter("login_locked_out").Inc(1) //mon:locked - return "", ErrLockedAccount.New(lockedAccountErrMsg) + return consoleauth.Token{}, ErrLockedAccount.New(lockedAccountErrMsg) } lockoutExpDate := now.Add(time.Duration(math.Pow(s.config.FailedLoginPenalty, float64(user.FailedLoginCount-1))) * time.Minute) @@ -951,16 +979,16 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token string, er if err != nil { err = handleLockAccount() if err != nil { - return "", err + return consoleauth.Token{}, err } mon.Counter("login_invalid_password").Inc(1) //mon:locked - return "", ErrLoginPassword.New(credentialsErrMsg) + return consoleauth.Token{}, ErrLoginPassword.New(credentialsErrMsg) } if user.MFAEnabled { if request.MFARecoveryCode != "" && request.MFAPasscode != "" { mon.Counter("login_mfa_conflict").Inc(1) //mon:locked - return "", ErrMFAConflict.New(mfaConflictErrMsg) + return consoleauth.Token{}, ErrMFAConflict.New(mfaConflictErrMsg) } if request.MFARecoveryCode != "" { @@ -976,10 +1004,10 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token string, er if !found { err = handleLockAccount() if err != nil { - return "", err + return consoleauth.Token{}, err } mon.Counter("login_mfa_recovery_failure").Inc(1) //mon:locked - return "", ErrMFARecoveryCode.New(mfaRecoveryInvalidErrMsg) + return consoleauth.Token{}, ErrMFARecoveryCode.New(mfaRecoveryInvalidErrMsg) } mon.Counter("login_mfa_recovery_success").Inc(1) //mon:locked @@ -988,30 +1016,30 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token string, er err = s.store.Users().Update(ctx, user) if err != nil { - return "", err + return consoleauth.Token{}, err } } else if request.MFAPasscode != "" { valid, err := ValidateMFAPasscode(request.MFAPasscode, user.MFASecretKey, time.Now()) if err != nil { err = handleLockAccount() if err != nil { - return "", err + return consoleauth.Token{}, err } - return "", ErrMFAPasscode.Wrap(err) + return consoleauth.Token{}, ErrMFAPasscode.Wrap(err) } if !valid { err = handleLockAccount() if err != nil { - return "", err + return consoleauth.Token{}, err } mon.Counter("login_mfa_passcode_failure").Inc(1) //mon:locked - return "", ErrMFAPasscode.New(mfaPasscodeInvalidErrMsg) + return consoleauth.Token{}, ErrMFAPasscode.New(mfaPasscodeInvalidErrMsg) } mon.Counter("login_mfa_passcode_success").Inc(1) //mon:locked } else { mon.Counter("login_mfa_missing").Inc(1) //mon:locked - return "", ErrMFAMissing.New(mfaRequiredErrMsg) + return consoleauth.Token{}, ErrMFAMissing.New(mfaRequiredErrMsg) } } @@ -1020,17 +1048,14 @@ func (s *Service) Token(ctx context.Context, request AuthUser) (token string, er user.LoginLockoutExpiration = time.Time{} err = s.store.Users().Update(ctx, user) if err != nil { - return "", err + return consoleauth.Token{}, err } } - token, err = s.tokens.CreateToken(ctx, user.ID, "") + token, err = s.GenerateSessionToken(ctx, user.ID, user.Email, request.IP, request.UserAgent) if err != nil { - return "", err + return consoleauth.Token{}, err } - s.auditLog(ctx, "login", &user.ID, user.Email) - - s.analytics.TrackSignedIn(user.ID, user.Email) mon.Counter("login_success").Inc(1) //mon:locked @@ -1064,7 +1089,7 @@ func (s *Service) GenGetUser(ctx context.Context) (*ResponseUser, api.HTTPError) var err error defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get user") + user, err := s.getUserAndAuditLog(ctx, "get user") if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -1072,36 +1097,36 @@ func (s *Service) GenGetUser(ctx context.Context) (*ResponseUser, api.HTTPError) } } - user := &ResponseUser{ - ID: auth.User.ID, - FullName: auth.User.FullName, - ShortName: auth.User.ShortName, - Email: auth.User.Email, - PartnerID: auth.User.PartnerID, - UserAgent: auth.User.UserAgent, - ProjectLimit: auth.User.ProjectLimit, - IsProfessional: auth.User.IsProfessional, - Position: auth.User.Position, - CompanyName: auth.User.CompanyName, - EmployeeCount: auth.User.EmployeeCount, - HaveSalesContact: auth.User.HaveSalesContact, - PaidTier: auth.User.PaidTier, - MFAEnabled: auth.User.MFAEnabled, - MFARecoveryCodeCount: len(auth.User.MFARecoveryCodes), + respUser := &ResponseUser{ + ID: user.ID, + FullName: user.FullName, + ShortName: user.ShortName, + Email: user.Email, + PartnerID: user.PartnerID, + UserAgent: user.UserAgent, + ProjectLimit: user.ProjectLimit, + IsProfessional: user.IsProfessional, + Position: user.Position, + CompanyName: user.CompanyName, + EmployeeCount: user.EmployeeCount, + HaveSalesContact: user.HaveSalesContact, + PaidTier: user.PaidTier, + MFAEnabled: user.MFAEnabled, + MFARecoveryCodeCount: len(user.MFARecoveryCodes), } - return user, api.HTTPError{} + return respUser, api.HTTPError{} } // GetUserID returns the User ID from the session. func (s *Service) GetUserID(ctx context.Context) (id uuid.UUID, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get user ID") + user, err := s.getUserAndAuditLog(ctx, "get user ID") if err != nil { return uuid.UUID{}, Error.Wrap(err) } - return auth.User.ID, nil + return user.ID, nil } // GetUserByEmailWithUnverified returns Users by email. @@ -1123,7 +1148,7 @@ func (s *Service) GetUserByEmailWithUnverified(ctx context.Context, email string // UpdateAccount updates User. func (s *Service) UpdateAccount(ctx context.Context, fullName string, shortName string) (err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "update account") + user, err := s.getUserAndAuditLog(ctx, "update account") if err != nil { return Error.Wrap(err) } @@ -1134,9 +1159,9 @@ func (s *Service) UpdateAccount(ctx context.Context, fullName string, shortName return ErrValidation.Wrap(err) } - auth.User.FullName = fullName - auth.User.ShortName = shortName - err = s.store.Users().Update(ctx, &auth.User) + user.FullName = fullName + user.ShortName = shortName + err = s.store.Users().Update(ctx, user) if err != nil { return Error.Wrap(err) } @@ -1147,7 +1172,7 @@ func (s *Service) UpdateAccount(ctx context.Context, fullName string, shortName // ChangeEmail updates email for a given user. func (s *Service) ChangeEmail(ctx context.Context, newEmail string) (err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "change email") + user, err := s.getUserAndAuditLog(ctx, "change email") if err != nil { return Error.Wrap(err) } @@ -1164,8 +1189,8 @@ func (s *Service) ChangeEmail(ctx context.Context, newEmail string) (err error) return ErrEmailUsed.New(emailUsedErrMsg) } - auth.User.Email = newEmail - err = s.store.Users().Update(ctx, &auth.User) + user.Email = newEmail + err = s.store.Users().Update(ctx, user) if err != nil { return Error.Wrap(err) } @@ -1176,12 +1201,12 @@ func (s *Service) ChangeEmail(ctx context.Context, newEmail string) (err error) // ChangePassword updates password for a given user. func (s *Service) ChangePassword(ctx context.Context, pass, newPass string) (err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "change password") + user, err := s.getUserAndAuditLog(ctx, "change password") if err != nil { return Error.Wrap(err) } - err = bcrypt.CompareHashAndPassword(auth.User.PasswordHash, []byte(pass)) + err = bcrypt.CompareHashAndPassword(user.PasswordHash, []byte(pass)) if err != nil { return ErrUnauthorized.New(credentialsErrMsg) } @@ -1195,8 +1220,8 @@ func (s *Service) ChangePassword(ctx context.Context, pass, newPass string) (err return Error.Wrap(err) } - auth.User.PasswordHash = hash - err = s.store.Users().Update(ctx, &auth.User) + user.PasswordHash = hash + err = s.store.Users().Update(ctx, user) if err != nil { return Error.Wrap(err) } @@ -1207,12 +1232,12 @@ func (s *Service) ChangePassword(ctx context.Context, pass, newPass string) (err // DeleteAccount deletes User. func (s *Service) DeleteAccount(ctx context.Context, password string) (err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "delete account") + user, err := s.getUserAndAuditLog(ctx, "delete account") if err != nil { return Error.Wrap(err) } - err = bcrypt.CompareHashAndPassword(auth.User.PasswordHash, []byte(password)) + err = bcrypt.CompareHashAndPassword(user.PasswordHash, []byte(password)) if err != nil { return ErrUnauthorized.New(credentialsErrMsg) } @@ -1222,7 +1247,7 @@ func (s *Service) DeleteAccount(ctx context.Context, password string) (err error return Error.Wrap(err) } - err = s.store.Users().Delete(ctx, auth.User.ID) + err = s.store.Users().Delete(ctx, user.ID) if err != nil { return Error.Wrap(err) } @@ -1233,12 +1258,12 @@ func (s *Service) DeleteAccount(ctx context.Context, password string) (err error // GetProject is a method for querying project by id. func (s *Service) GetProject(ctx context.Context, projectID uuid.UUID) (p *Project, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get project", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "get project", zap.String("projectID", projectID.String())) if err != nil { return nil, Error.Wrap(err) } - if _, err = s.isProjectMember(ctx, auth.User.ID, projectID); err != nil { + if _, err = s.isProjectMember(ctx, user.ID, projectID); err != nil { return nil, Error.Wrap(err) } @@ -1253,12 +1278,12 @@ func (s *Service) GetProject(ctx context.Context, projectID uuid.UUID) (p *Proje // GetUsersProjects is a method for querying all projects. func (s *Service) GetUsersProjects(ctx context.Context) (ps []Project, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get users projects") + user, err := s.getUserAndAuditLog(ctx, "get users projects") if err != nil { return nil, Error.Wrap(err) } - ps, err = s.store.Projects().GetByUserID(ctx, auth.User.ID) + ps, err = s.store.Projects().GetByUserID(ctx, user.ID) if err != nil { return nil, Error.Wrap(err) } @@ -1271,7 +1296,7 @@ func (s *Service) GenGetUsersProjects(ctx context.Context) (ps []Project, httpEr var err error defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get users projects") + user, err := s.getUserAndAuditLog(ctx, "get users projects") if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -1279,7 +1304,7 @@ func (s *Service) GenGetUsersProjects(ctx context.Context) (ps []Project, httpEr } } - ps, err = s.store.Projects().GetByUserID(ctx, auth.User.ID) + ps, err = s.store.Projects().GetByUserID(ctx, user.ID) if err != nil { return nil, api.HTTPError{ Status: http.StatusInternalServerError, @@ -1293,12 +1318,12 @@ func (s *Service) GenGetUsersProjects(ctx context.Context) (ps []Project, httpEr // GetUsersOwnedProjectsPage is a method for querying paged projects. func (s *Service) GetUsersOwnedProjectsPage(ctx context.Context, cursor ProjectsCursor) (_ ProjectsPage, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get user's owned projects page") + user, err := s.getUserAndAuditLog(ctx, "get user's owned projects page") if err != nil { return ProjectsPage{}, Error.Wrap(err) } - projects, err := s.store.Projects().ListByOwnerID(ctx, auth.User.ID, cursor) + projects, err := s.store.Projects().ListByOwnerID(ctx, user.ID, cursor) if err != nil { return ProjectsPage{}, Error.Wrap(err) } @@ -1309,17 +1334,17 @@ func (s *Service) GetUsersOwnedProjectsPage(ctx context.Context, cursor Projects // CreateProject is a method for creating new project. func (s *Service) CreateProject(ctx context.Context, projectInfo ProjectInfo) (p *Project, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "create project") + user, err := s.getUserAndAuditLog(ctx, "create project") if err != nil { return nil, Error.Wrap(err) } - currentProjectCount, err := s.checkProjectLimit(ctx, auth.User.ID) + currentProjectCount, err := s.checkProjectLimit(ctx, user.ID) if err != nil { return nil, ErrProjLimit.Wrap(err) } - newProjectLimits, err := s.getUserProjectLimits(ctx, auth.User.ID) + newProjectLimits, err := s.getUserProjectLimits(ctx, user.ID) if err != nil { return nil, ErrProjLimit.Wrap(err) } @@ -1330,9 +1355,9 @@ func (s *Service) CreateProject(ctx context.Context, projectInfo ProjectInfo) (p &Project{ Description: projectInfo.Description, Name: projectInfo.Name, - OwnerID: auth.User.ID, - PartnerID: auth.User.PartnerID, - UserAgent: auth.User.UserAgent, + OwnerID: user.ID, + PartnerID: user.PartnerID, + UserAgent: user.UserAgent, StorageLimit: &newProjectLimits.StorageLimit, BandwidthLimit: &newProjectLimits.BandwidthLimit, SegmentLimit: &newProjectLimits.SegmentLimit, @@ -1342,7 +1367,7 @@ func (s *Service) CreateProject(ctx context.Context, projectInfo ProjectInfo) (p return Error.Wrap(err) } - _, err = tx.ProjectMembers().Insert(ctx, auth.User.ID, p.ID) + _, err = tx.ProjectMembers().Insert(ctx, user.ID, p.ID) if err != nil { return Error.Wrap(err) } @@ -1356,7 +1381,7 @@ func (s *Service) CreateProject(ctx context.Context, projectInfo ProjectInfo) (p return nil, Error.Wrap(err) } - s.analytics.TrackProjectCreated(auth.User.ID, auth.User.Email, projectID, currentProjectCount+1) + s.analytics.TrackProjectCreated(user.ID, user.Email, projectID, currentProjectCount+1) return p, nil } @@ -1366,7 +1391,7 @@ func (s *Service) GenCreateProject(ctx context.Context, projectInfo ProjectInfo) var err error defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "create project") + user, err := s.getUserAndAuditLog(ctx, "create project") if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -1374,7 +1399,7 @@ func (s *Service) GenCreateProject(ctx context.Context, projectInfo ProjectInfo) } } - currentProjectCount, err := s.checkProjectLimit(ctx, auth.User.ID) + currentProjectCount, err := s.checkProjectLimit(ctx, user.ID) if err != nil { return nil, api.HTTPError{ Status: http.StatusInternalServerError, @@ -1382,7 +1407,7 @@ func (s *Service) GenCreateProject(ctx context.Context, projectInfo ProjectInfo) } } - newProjectLimits, err := s.getUserProjectLimits(ctx, auth.User.ID) + newProjectLimits, err := s.getUserProjectLimits(ctx, user.ID) if err != nil { return nil, api.HTTPError{ Status: http.StatusInternalServerError, @@ -1396,9 +1421,9 @@ func (s *Service) GenCreateProject(ctx context.Context, projectInfo ProjectInfo) &Project{ Description: projectInfo.Description, Name: projectInfo.Name, - OwnerID: auth.User.ID, - PartnerID: auth.User.PartnerID, - UserAgent: auth.User.UserAgent, + OwnerID: user.ID, + PartnerID: user.PartnerID, + UserAgent: user.UserAgent, StorageLimit: &newProjectLimits.StorageLimit, BandwidthLimit: &newProjectLimits.BandwidthLimit, SegmentLimit: &newProjectLimits.SegmentLimit, @@ -1408,7 +1433,7 @@ func (s *Service) GenCreateProject(ctx context.Context, projectInfo ProjectInfo) return Error.Wrap(err) } - _, err = tx.ProjectMembers().Insert(ctx, auth.User.ID, p.ID) + _, err = tx.ProjectMembers().Insert(ctx, user.ID, p.ID) if err != nil { return Error.Wrap(err) } @@ -1425,7 +1450,7 @@ func (s *Service) GenCreateProject(ctx context.Context, projectInfo ProjectInfo) } } - s.analytics.TrackProjectCreated(auth.User.ID, auth.User.Email, projectID, currentProjectCount+1) + s.analytics.TrackProjectCreated(user.ID, user.Email, projectID, currentProjectCount+1) return p, httpError } @@ -1434,17 +1459,17 @@ func (s *Service) GenCreateProject(ctx context.Context, projectInfo ProjectInfo) func (s *Service) DeleteProject(ctx context.Context, projectID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "delete project", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "delete project", zap.String("projectID", projectID.String())) if err != nil { return Error.Wrap(err) } - _, err = s.isProjectOwner(ctx, auth.User.ID, projectID) + _, err = s.isProjectOwner(ctx, user.ID, projectID) if err != nil { return Error.Wrap(err) } - err = s.checkProjectCanBeDeleted(ctx, auth.User, projectID) + err = s.checkProjectCanBeDeleted(ctx, user, projectID) if err != nil { return Error.Wrap(err) } @@ -1462,7 +1487,7 @@ func (s *Service) GenDeleteProject(ctx context.Context, projectID uuid.UUID) (ht var err error defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "delete project", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "delete project", zap.String("projectID", projectID.String())) if err != nil { return api.HTTPError{ Status: http.StatusUnauthorized, @@ -1470,7 +1495,7 @@ func (s *Service) GenDeleteProject(ctx context.Context, projectID uuid.UUID) (ht } } - _, err = s.isProjectOwner(ctx, auth.User.ID, projectID) + _, err = s.isProjectOwner(ctx, user.ID, projectID) if err != nil { return api.HTTPError{ Status: http.StatusUnauthorized, @@ -1478,7 +1503,7 @@ func (s *Service) GenDeleteProject(ctx context.Context, projectID uuid.UUID) (ht } } - err = s.checkProjectCanBeDeleted(ctx, auth.User, projectID) + err = s.checkProjectCanBeDeleted(ctx, user, projectID) if err != nil { return api.HTTPError{ Status: http.StatusConflict, @@ -1501,7 +1526,7 @@ func (s *Service) GenDeleteProject(ctx context.Context, projectID uuid.UUID) (ht func (s *Service) UpdateProject(ctx context.Context, projectID uuid.UUID, projectInfo ProjectInfo) (p *Project, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "update project name and description", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "update project name and description", zap.String("projectID", projectID.String())) if err != nil { return nil, Error.Wrap(err) } @@ -1511,7 +1536,7 @@ func (s *Service) UpdateProject(ctx context.Context, projectID uuid.UUID, projec return nil, Error.Wrap(err) } - isMember, err := s.isProjectMember(ctx, auth.User.ID, projectID) + isMember, err := s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, Error.Wrap(err) } @@ -1519,7 +1544,7 @@ func (s *Service) UpdateProject(ctx context.Context, projectID uuid.UUID, projec project.Name = projectInfo.Name project.Description = projectInfo.Description - if auth.User.PaidTier { + if user.PaidTier { if project.BandwidthLimit != nil && *project.BandwidthLimit == 0 { return nil, Error.New("current bandwidth limit for project is set to 0 (updating disabled)") } @@ -1573,7 +1598,7 @@ func (s *Service) GenUpdateProject(ctx context.Context, projectID uuid.UUID, pro var err error defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "update project name and description", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "update project name and description", zap.String("projectID", projectID.String())) if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -1589,7 +1614,7 @@ func (s *Service) GenUpdateProject(ctx context.Context, projectID uuid.UUID, pro } } - isMember, err := s.isProjectMember(ctx, auth.User.ID, projectID) + isMember, err := s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -1600,7 +1625,7 @@ func (s *Service) GenUpdateProject(ctx context.Context, projectID uuid.UUID, pro project.Name = projectInfo.Name project.Description = projectInfo.Description - if auth.User.PaidTier { + if user.PaidTier { if project.BandwidthLimit != nil && *project.BandwidthLimit == 0 { return nil, api.HTTPError{ Status: http.StatusInternalServerError, @@ -1682,12 +1707,12 @@ func (s *Service) GenUpdateProject(ctx context.Context, projectID uuid.UUID, pro // AddProjectMembers adds users by email to given project. func (s *Service) AddProjectMembers(ctx context.Context, projectID uuid.UUID, emails []string) (users []*User, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "add project members", zap.String("projectID", projectID.String()), zap.Strings("emails", emails)) + user, err := s.getUserAndAuditLog(ctx, "add project members", zap.String("projectID", projectID.String()), zap.Strings("emails", emails)) if err != nil { return nil, Error.Wrap(err) } - if _, err = s.isProjectMember(ctx, auth.User.ID, projectID); err != nil { + if _, err = s.isProjectMember(ctx, user.ID, projectID); err != nil { return nil, Error.Wrap(err) } @@ -1727,12 +1752,12 @@ func (s *Service) AddProjectMembers(ctx context.Context, projectID uuid.UUID, em // DeleteProjectMembers removes users by email from given project. func (s *Service) DeleteProjectMembers(ctx context.Context, projectID uuid.UUID, emails []string) (err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "delete project members", zap.String("projectID", projectID.String()), zap.Strings("emails", emails)) + user, err := s.getUserAndAuditLog(ctx, "delete project members", zap.String("projectID", projectID.String()), zap.Strings("emails", emails)) if err != nil { return Error.Wrap(err) } - if _, err = s.isProjectMember(ctx, auth.User.ID, projectID); err != nil { + if _, err = s.isProjectMember(ctx, user.ID, projectID); err != nil { return Error.Wrap(err) } @@ -1779,12 +1804,12 @@ func (s *Service) DeleteProjectMembers(ctx context.Context, projectID uuid.UUID, func (s *Service) GetProjectMembers(ctx context.Context, projectID uuid.UUID, cursor ProjectMembersCursor) (pmp *ProjectMembersPage, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get project members", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "get project members", zap.String("projectID", projectID.String())) if err != nil { return nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, Error.Wrap(err) } @@ -1805,12 +1830,12 @@ func (s *Service) GetProjectMembers(ctx context.Context, projectID uuid.UUID, cu func (s *Service) CreateAPIKey(ctx context.Context, projectID uuid.UUID, name string) (_ *APIKeyInfo, _ *macaroon.APIKey, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "create api key", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "create api key", zap.String("projectID", projectID.String())) if err != nil { return nil, nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, nil, Error.Wrap(err) } @@ -1834,8 +1859,8 @@ func (s *Service) CreateAPIKey(ctx context.Context, projectID uuid.UUID, name st Name: name, ProjectID: projectID, Secret: secret, - PartnerID: auth.User.PartnerID, - UserAgent: auth.User.UserAgent, + PartnerID: user.PartnerID, + UserAgent: user.UserAgent, } info, err := s.store.APIKeys().Create(ctx, key.Head(), apikey) @@ -1843,7 +1868,7 @@ func (s *Service) CreateAPIKey(ctx context.Context, projectID uuid.UUID, name st return nil, nil, Error.Wrap(err) } - s.analytics.TrackAccessGrantCreated(auth.User.ID, auth.User.Email) + s.analytics.TrackAccessGrantCreated(user.ID, user.Email) return info, key, nil } @@ -1853,7 +1878,7 @@ func (s *Service) GenCreateAPIKey(ctx context.Context, requestInfo CreateAPIKeyR var err error defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "create api key", zap.String("projectID", requestInfo.ProjectID)) + user, err := s.getUserAndAuditLog(ctx, "create api key", zap.String("projectID", requestInfo.ProjectID)) if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -1869,7 +1894,7 @@ func (s *Service) GenCreateAPIKey(ctx context.Context, requestInfo CreateAPIKeyR } } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -1905,8 +1930,8 @@ func (s *Service) GenCreateAPIKey(ctx context.Context, requestInfo CreateAPIKeyR Name: requestInfo.Name, ProjectID: projectID, Secret: secret, - PartnerID: auth.User.PartnerID, - UserAgent: auth.User.UserAgent, + PartnerID: user.PartnerID, + UserAgent: user.UserAgent, } info, err := s.store.APIKeys().Create(ctx, key.Head(), apikey) @@ -1917,7 +1942,7 @@ func (s *Service) GenCreateAPIKey(ctx context.Context, requestInfo CreateAPIKeyR } } - s.analytics.TrackAccessGrantCreated(auth.User.ID, auth.User.Email) + s.analytics.TrackAccessGrantCreated(user.ID, user.Email) return &CreateAPIKeyResponse{ Key: key.Serialize(), @@ -1929,7 +1954,7 @@ func (s *Service) GenCreateAPIKey(ctx context.Context, requestInfo CreateAPIKeyR func (s *Service) GetAPIKeyInfoByName(ctx context.Context, projectID uuid.UUID, name string) (_ *APIKeyInfo, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get api key info", + user, err := s.getUserAndAuditLog(ctx, "get api key info", zap.String("projectID", projectID.String()), zap.String("name", name)) @@ -1942,7 +1967,7 @@ func (s *Service) GetAPIKeyInfoByName(ctx context.Context, projectID uuid.UUID, return nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, key.ProjectID) + _, err = s.isProjectMember(ctx, user.ID, key.ProjectID) if err != nil { return nil, Error.Wrap(err) } @@ -1954,7 +1979,7 @@ func (s *Service) GetAPIKeyInfoByName(ctx context.Context, projectID uuid.UUID, func (s *Service) GetAPIKeyInfo(ctx context.Context, id uuid.UUID) (_ *APIKeyInfo, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get api key info", zap.String("apiKeyID", id.String())) + user, err := s.getUserAndAuditLog(ctx, "get api key info", zap.String("apiKeyID", id.String())) if err != nil { return nil, err } @@ -1964,7 +1989,7 @@ func (s *Service) GetAPIKeyInfo(ctx context.Context, id uuid.UUID) (_ *APIKeyInf return nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, key.ProjectID) + _, err = s.isProjectMember(ctx, user.ID, key.ProjectID) if err != nil { return nil, Error.Wrap(err) } @@ -1981,7 +2006,7 @@ func (s *Service) DeleteAPIKeys(ctx context.Context, ids []uuid.UUID) (err error idStrings = append(idStrings, id.String()) } - auth, err := s.getAuthAndAuditLog(ctx, "delete api keys", zap.Strings("apiKeyIDs", idStrings)) + user, err := s.getUserAndAuditLog(ctx, "delete api keys", zap.Strings("apiKeyIDs", idStrings)) if err != nil { return Error.Wrap(err) } @@ -1995,7 +2020,7 @@ func (s *Service) DeleteAPIKeys(ctx context.Context, ids []uuid.UUID) (err error continue } - _, err = s.isProjectMember(ctx, auth.User.ID, key.ProjectID) + _, err = s.isProjectMember(ctx, user.ID, key.ProjectID) if err != nil { keysErr.Add(ErrUnauthorized.Wrap(err)) continue @@ -2023,12 +2048,12 @@ func (s *Service) DeleteAPIKeys(ctx context.Context, ids []uuid.UUID) (err error func (s *Service) DeleteAPIKeyByNameAndProjectID(ctx context.Context, name string, projectID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "delete api key by name and project ID", zap.String("apiKeyName", name), zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "delete api key by name and project ID", zap.String("apiKeyName", name), zap.String("projectID", projectID.String())) if err != nil { return Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return Error.Wrap(err) } @@ -2050,12 +2075,12 @@ func (s *Service) DeleteAPIKeyByNameAndProjectID(ctx context.Context, name strin func (s *Service) GetAPIKeys(ctx context.Context, projectID uuid.UUID, cursor APIKeyCursor) (page *APIKeyPage, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get api keys", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "get api keys", zap.String("projectID", projectID.String())) if err != nil { return nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, Error.Wrap(err) } @@ -2076,12 +2101,12 @@ func (s *Service) GetAPIKeys(ctx context.Context, projectID uuid.UUID, cursor AP func (s *Service) CreateRESTKey(ctx context.Context, expiration time.Duration) (apiKey string, expiresAt time.Time, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "create rest key") + user, err := s.getUserAndAuditLog(ctx, "create rest key") if err != nil { return "", time.Time{}, Error.Wrap(err) } - apiKey, expiresAt, err = s.restKeys.Create(ctx, auth.User.ID, expiration) + apiKey, expiresAt, err = s.restKeys.Create(ctx, user.ID, expiration) if err != nil { return "", time.Time{}, Error.Wrap(err) } @@ -2092,7 +2117,7 @@ func (s *Service) CreateRESTKey(ctx context.Context, expiration time.Duration) ( func (s *Service) RevokeRESTKey(ctx context.Context, apiKey string) (err error) { defer mon.Task()(&ctx)(&err) - _, err = s.getAuthAndAuditLog(ctx, "revoke rest key") + _, err = s.getUserAndAuditLog(ctx, "revoke rest key") if err != nil { return Error.Wrap(err) } @@ -2108,12 +2133,12 @@ func (s *Service) RevokeRESTKey(ctx context.Context, apiKey string) (err error) func (s *Service) GetProjectUsage(ctx context.Context, projectID uuid.UUID, since, before time.Time) (_ *accounting.ProjectUsage, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get project usage", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "get project usage", zap.String("projectID", projectID.String())) if err != nil { return nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, Error.Wrap(err) } @@ -2130,12 +2155,12 @@ func (s *Service) GetProjectUsage(ctx context.Context, projectID uuid.UUID, sinc func (s *Service) GetBucketTotals(ctx context.Context, projectID uuid.UUID, cursor accounting.BucketUsageCursor, before time.Time) (_ *accounting.BucketUsagePage, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get bucket totals", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "get bucket totals", zap.String("projectID", projectID.String())) if err != nil { return nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, Error.Wrap(err) } @@ -2152,12 +2177,12 @@ func (s *Service) GetBucketTotals(ctx context.Context, projectID uuid.UUID, curs func (s *Service) GetAllBucketNames(ctx context.Context, projectID uuid.UUID) (_ []string, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get all bucket names", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "get all bucket names", zap.String("projectID", projectID.String())) if err != nil { return nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, Error.Wrap(err) } @@ -2187,12 +2212,12 @@ func (s *Service) GetAllBucketNames(ctx context.Context, projectID uuid.UUID) (_ func (s *Service) GetBucketUsageRollups(ctx context.Context, projectID uuid.UUID, since, before time.Time) (_ []accounting.BucketUsageRollup, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get bucket usage rollups", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "get bucket usage rollups", zap.String("projectID", projectID.String())) if err != nil { return nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, Error.Wrap(err) } @@ -2210,7 +2235,7 @@ func (s *Service) GenGetBucketUsageRollups(ctx context.Context, projectID uuid.U var err error defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get bucket usage rollups", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "get bucket usage rollups", zap.String("projectID", projectID.String())) if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -2218,7 +2243,7 @@ func (s *Service) GenGetBucketUsageRollups(ctx context.Context, projectID uuid.U } } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -2242,7 +2267,7 @@ func (s *Service) GenGetSingleBucketUsageRollup(ctx context.Context, projectID u var err error defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get single bucket usage rollup", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "get single bucket usage rollup", zap.String("projectID", projectID.String())) if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -2250,7 +2275,7 @@ func (s *Service) GenGetSingleBucketUsageRollup(ctx context.Context, projectID u } } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, api.HTTPError{ Status: http.StatusUnauthorized, @@ -2273,12 +2298,12 @@ func (s *Service) GenGetSingleBucketUsageRollup(ctx context.Context, projectID u func (s *Service) GetDailyProjectUsage(ctx context.Context, projectID uuid.UUID, from, to time.Time) (_ *accounting.ProjectDailyUsage, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get daily usage by project ID") + user, err := s.getUserAndAuditLog(ctx, "get daily usage by project ID") if err != nil { return nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, Error.Wrap(err) } @@ -2298,12 +2323,12 @@ func (s *Service) GetDailyProjectUsage(ctx context.Context, projectID uuid.UUID, func (s *Service) GetProjectUsageLimits(ctx context.Context, projectID uuid.UUID) (_ *ProjectUsageLimits, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get project usage limits", zap.String("projectID", projectID.String())) + user, err := s.getUserAndAuditLog(ctx, "get project usage limits", zap.String("projectID", projectID.String())) if err != nil { return nil, Error.Wrap(err) } - _, err = s.isProjectMember(ctx, auth.User.ID, projectID) + _, err = s.isProjectMember(ctx, user.ID, projectID) if err != nil { return nil, Error.Wrap(err) } @@ -2332,12 +2357,12 @@ func (s *Service) GetProjectUsageLimits(ctx context.Context, projectID uuid.UUID func (s *Service) GetTotalUsageLimits(ctx context.Context) (_ *ProjectUsageLimits, err error) { defer mon.Task()(&ctx)(&err) - auth, err := s.getAuthAndAuditLog(ctx, "get total usage and limits for all the projects") + user, err := s.getUserAndAuditLog(ctx, "get total usage and limits for all the projects") if err != nil { return nil, Error.Wrap(err) } - projects, err := s.store.Projects().GetOwn(ctx, auth.User.ID) + projects, err := s.store.Projects().GetOwn(ctx, user.ID) if err != nil { return nil, Error.Wrap(err) } @@ -2396,86 +2421,43 @@ func (s *Service) getProjectUsageLimits(ctx context.Context, projectID uuid.UUID }, nil } -// Authorize validates token from context and returns authorized Authorization. -func (s *Service) Authorize(ctx context.Context) (a Authorization, err error) { +// TokenAuth returns an authenticated context by session token. +func (s *Service) TokenAuth(ctx context.Context, token consoleauth.Token, authTime time.Time) (_ context.Context, err error) { defer mon.Task()(&ctx)(&err) - tokenS, ok := consoleauth.GetAPIKey(ctx) - if !ok { - return Authorization{}, ErrUnauthorized.New("no api key was provided") - } - token, err := consoleauth.FromBase64URLString(string(tokenS)) + valid, err := s.tokens.ValidateToken(token) if err != nil { - return Authorization{}, ErrUnauthorized.Wrap(err) + return nil, Error.Wrap(err) + } + if !valid { + return nil, Error.New("incorrect signature") } - claims, err := s.authenticate(ctx, token) + sessionID, err := uuid.FromBytes(token.Payload) if err != nil { - return Authorization{}, ErrUnauthorized.Wrap(err) + return nil, Error.Wrap(err) } - user, err := s.authorize(ctx, claims) + session, err := s.store.WebappSessions().GetBySessionID(ctx, sessionID) if err != nil { - return Authorization{}, ErrUnauthorized.Wrap(err) + return nil, Error.Wrap(err) } - return Authorization{ - User: *user, - Claims: *claims, - }, nil -} - -// IsAuthenticated checks if request has authorization credentials. -func (s *Service) IsAuthenticated(ctx context.Context, r *http.Request, isCookieAuth, isKeyAuth bool) (context.Context, error) { - var err error - - if isCookieAuth && isKeyAuth { - ctx, err = s.cookieAuth(ctx, r) + ctx, err = s.authorize(ctx, session.UserID, session.ExpiresAt, authTime) + if err != nil { + err := errs.Combine(err, s.store.WebappSessions().DeleteBySessionID(ctx, sessionID)) if err != nil { - ctx, err = s.keyAuth(ctx, r) - if err != nil { - return nil, err - } - } - } else if isCookieAuth { - ctx, err = s.cookieAuth(ctx, r) - if err != nil { - return nil, err - } - } else if isKeyAuth { - ctx, err = s.keyAuth(ctx, r) - if err != nil { - return nil, err + return nil, Error.Wrap(err) } + return nil, err } return ctx, nil } -// cookieAuth checks if request has an authorization cookie. -func (s *Service) cookieAuth(ctx context.Context, r *http.Request) (context.Context, error) { - cookie, err := r.Cookie("_tokenKey") - if err != nil { - return ctx, err - } - - auth, err := s.Authorize(consoleauth.WithAPIKey(ctx, []byte(cookie.Value))) - if err != nil { - return ctx, err - } - - return WithAuth(ctx, auth), nil -} - -// keyAuth checks if request has an authorization api key. -func (s *Service) keyAuth(ctx context.Context, r *http.Request) (context.Context, error) { - authToken := r.Header.Get("Authorization") - split := strings.Split(authToken, "Bearer ") - if len(split) != 2 { - return nil, errs.New("authorization key format is incorrect. Should be 'Bearer '") - } - - apikey := split[1] +// KeyAuth returns an authenticated context by api key. +func (s *Service) KeyAuth(ctx context.Context, apikey string, authTime time.Time) (_ context.Context, err error) { + defer mon.Task()(&ctx)(&err) ctx = consoleauth.WithAPIKey(ctx, []byte(apikey)) @@ -2484,28 +2466,17 @@ func (s *Service) keyAuth(ctx context.Context, r *http.Request) (context.Context return nil, err } - claims := &consoleauth.Claims{ - ID: userID, - Email: "", - Expiration: exp, - } - - user, err := s.authorize(ctx, claims) + ctx, err = s.authorize(ctx, userID, exp, authTime) if err != nil { return nil, err } - auth := Authorization{ - User: *user, - Claims: *claims, - } - - return WithAuth(ctx, auth), nil + return ctx, nil } // checkProjectCanBeDeleted ensures that all data, api-keys and buckets are deleted and usage has been accounted. // no error means the project status is clean. -func (s *Service) checkProjectCanBeDeleted(ctx context.Context, user User, projectID uuid.UUID) (err error) { +func (s *Service) checkProjectCanBeDeleted(ctx context.Context, user *User, projectID uuid.UUID) (err error) { defer mon.Task()(&ctx)(&err) buckets, err := s.buckets.CountBuckets(ctx, projectID) @@ -2587,44 +2558,22 @@ func (s *Service) CreateRegToken(ctx context.Context, projLimit int) (_ *Registr return result, nil } -// authenticate validates token signature and returns authenticated *satelliteauth.Authorization. -func (s *Service) authenticate(ctx context.Context, token consoleauth.Token) (_ *consoleauth.Claims, err error) { +// authorize returns an authorized context by user ID. +func (s *Service) authorize(ctx context.Context, userID uuid.UUID, expiration time.Time, authTime time.Time) (_ context.Context, err error) { defer mon.Task()(&ctx)(&err) - signature := token.Signature + if !expiration.IsZero() && expiration.Before(authTime) { + return nil, ErrTokenExpiration.New("authorization failed. expiration reached.") + } - err = s.tokens.SignToken(&token) + user, err := s.store.Users().Get(ctx, userID) if err != nil { - return nil, Error.Wrap(err) - } - - if subtle.ConstantTimeCompare(signature, token.Signature) != 1 { - return nil, Error.New("incorrect signature") - } - - claims, err := consoleauth.FromJSON(token.Payload) - if err != nil { - return nil, Error.Wrap(err) - } - - return claims, nil -} - -// authorize checks claims and returns authorized User. -func (s *Service) authorize(ctx context.Context, claims *consoleauth.Claims) (_ *User, err error) { - defer mon.Task()(&ctx)(&err) - if !claims.Expiration.IsZero() && claims.Expiration.Before(time.Now()) { - return nil, ErrTokenExpiration.New("") - } - - user, err := s.store.Users().Get(ctx, claims.ID) - if err != nil { - return nil, ErrValidation.New("authorization failed. no user with id: %s", claims.ID.String()) + return nil, Error.New("authorization failed. no user with id: %s", userID.String()) } if user.Status != Active { - return nil, ErrValidation.New("authorization failed. no active user with id: %s", claims.ID.String()) + return nil, Error.New("authorization failed. no active user with id: %s", userID.String()) } - return user, nil + return WithUser(ctx, user), nil } // isProjectMember is return type of isProjectMember service method. @@ -2701,11 +2650,11 @@ var ErrWalletNotClaimed = errs.Class("wallet is not claimed") func (payment Payments) ClaimWallet(ctx context.Context) (_ WalletInfo, err error) { defer mon.Task()(&ctx)(&err) - auth, err := payment.service.getAuthAndAuditLog(ctx, "claim wallet") + user, err := payment.service.getUserAndAuditLog(ctx, "claim wallet") if err != nil { return WalletInfo{}, Error.Wrap(err) } - address, err := payment.service.depositWallets.Claim(ctx, auth.User.ID) + address, err := payment.service.depositWallets.Claim(ctx, user.ID) if err != nil { return WalletInfo{}, Error.Wrap(err) } @@ -2719,11 +2668,11 @@ func (payment Payments) ClaimWallet(ctx context.Context) (_ WalletInfo, err erro func (payment Payments) GetWallet(ctx context.Context) (_ WalletInfo, err error) { defer mon.Task()(&ctx)(&err) - auth, err := GetAuth(ctx) + user, err := GetUser(ctx) if err != nil { return WalletInfo{}, Error.Wrap(err) } - address, err := payment.service.depositWallets.Get(ctx, auth.User.ID) + address, err := payment.service.depositWallets.Get(ctx, user.ID) if err != nil { return WalletInfo{}, Error.Wrap(err) } @@ -2746,3 +2695,23 @@ func findMembershipByProjectID(memberships []ProjectMember, projectID uuid.UUID) } return ProjectMember{}, false } + +// DeleteSessionByToken removes the session corresponding to the given token from the database. +func (s *Service) DeleteSessionByToken(ctx context.Context, token consoleauth.Token) (err error) { + defer mon.Task()(&ctx)(&err) + + valid, err := s.tokens.ValidateToken(token) + if err != nil { + return err + } + if !valid { + return ErrValidation.New("Invalid session token.") + } + + id, err := uuid.FromBytes(token.Payload) + if err != nil { + return err + } + + return s.store.WebappSessions().DeleteBySessionID(ctx, id) +} diff --git a/satellite/console/service_test.go b/satellite/console/service_test.go index 1be9b1da2..7d56ffa2b 100644 --- a/satellite/console/service_test.go +++ b/satellite/console/service_test.go @@ -5,6 +5,7 @@ package console_test import ( "context" + "database/sql" "encoding/json" "math" "math/big" @@ -24,7 +25,6 @@ import ( "storj.io/storj/private/testplanet" "storj.io/storj/satellite" "storj.io/storj/satellite/console" - "storj.io/storj/satellite/console/consoleauth" ) func TestService(t *testing.T) { @@ -46,20 +46,20 @@ func TestService(t *testing.T) { require.NotEqual(t, up1Pro1.ID, up2Pro1.ID) require.NotEqual(t, up1Pro1.OwnerID, up2Pro1.OwnerID) - authCtx1, err := sat.AuthenticatedContext(ctx, up1Pro1.OwnerID) + userCtx1, err := sat.UserContext(ctx, up1Pro1.OwnerID) require.NoError(t, err) - authCtx2, err := sat.AuthenticatedContext(ctx, up2Pro1.OwnerID) + userCtx2, err := sat.UserContext(ctx, up2Pro1.OwnerID) require.NoError(t, err) t.Run("TestGetProject", func(t *testing.T) { // Getting own project details should work - project, err := service.GetProject(authCtx1, up1Pro1.ID) + project, err := service.GetProject(userCtx1, up1Pro1.ID) require.NoError(t, err) require.Equal(t, up1Pro1.ID, project.ID) // Getting someone else project details should not work - project, err = service.GetProject(authCtx1, up2Pro1.ID) + project, err = service.GetProject(userCtx1, up2Pro1.ID) require.Error(t, err) require.Nil(t, project) }) @@ -75,17 +75,17 @@ func TestService(t *testing.T) { require.NoError(t, err) require.False(t, user.PaidTier) // get context - authCtx1, err := sat.AuthenticatedContext(ctx, user.ID) + userCtx1, err := sat.UserContext(ctx, user.ID) require.NoError(t, err) // add a credit card to put the user in the paid tier - err = service.Payments().AddCreditCard(authCtx1, "test-cc-token") + err = service.Payments().AddCreditCard(userCtx1, "test-cc-token") require.NoError(t, err) // update auth ctx - authCtx1, err = sat.AuthenticatedContext(ctx, user.ID) + userCtx1, err = sat.UserContext(ctx, user.ID) require.NoError(t, err) // Updating own project should work - updatedProject, err := service.UpdateProject(authCtx1, up1Pro1.ID, console.ProjectInfo{ + updatedProject, err := service.UpdateProject(userCtx1, up1Pro1.ID, console.ProjectInfo{ Name: updatedName, Description: updatedDescription, StorageLimit: updatedStorageLimit, @@ -102,7 +102,7 @@ func TestService(t *testing.T) { require.Equal(t, updatedBandwidthLimit, *updatedProject.BandwidthLimit) // Updating someone else project details should not work - updatedProject, err = service.UpdateProject(authCtx1, up2Pro1.ID, console.ProjectInfo{ + updatedProject, err = service.UpdateProject(userCtx1, up2Pro1.ID, console.ProjectInfo{ Name: "newName", Description: "TestUpdate", StorageLimit: memory.Size(100), @@ -127,7 +127,7 @@ func TestService(t *testing.T) { StorageLimit: memory.Size(123), BandwidthLimit: memory.Size(123), } - updatedProject, err = service.UpdateProject(authCtx1, up1Pro1.ID, updateInfo) + updatedProject, err = service.UpdateProject(userCtx1, up1Pro1.ID, updateInfo) require.Error(t, err) require.Nil(t, updatedProject) @@ -137,7 +137,7 @@ func TestService(t *testing.T) { err = sat.DB.Console().Projects().Update(ctx, up1Pro1) require.NoError(t, err) - updatedProject, err = service.UpdateProject(authCtx1, up1Pro1.ID, updateInfo) + updatedProject, err = service.UpdateProject(userCtx1, up1Pro1.ID, updateInfo) require.Error(t, err) require.Nil(t, updatedProject) @@ -146,7 +146,7 @@ func TestService(t *testing.T) { err = sat.DB.Console().Projects().Update(ctx, up1Pro1) require.NoError(t, err) - updatedProject, err = service.UpdateProject(authCtx1, up1Pro1.ID, updateInfo) + updatedProject, err = service.UpdateProject(userCtx1, up1Pro1.ID, updateInfo) require.NoError(t, err) require.Equal(t, updateInfo.Name, updatedProject.Name) require.Equal(t, updateInfo.Description, updatedProject.Description) @@ -155,7 +155,7 @@ func TestService(t *testing.T) { require.Equal(t, updateInfo.StorageLimit, *updatedProject.StorageLimit) require.Equal(t, updateInfo.BandwidthLimit, *updatedProject.BandwidthLimit) - project, err := service.GetProject(authCtx1, up1Pro1.ID) + project, err := service.GetProject(userCtx1, up1Pro1.ID) require.NoError(t, err) require.Equal(t, updateInfo.StorageLimit, *project.StorageLimit) require.Equal(t, updateInfo.BandwidthLimit, *project.BandwidthLimit) @@ -163,69 +163,69 @@ func TestService(t *testing.T) { t.Run("TestAddProjectMembers", func(t *testing.T) { // Adding members to own project should work - addedUsers, err := service.AddProjectMembers(authCtx1, up1Pro1.ID, []string{up2User.Email}) + addedUsers, err := service.AddProjectMembers(userCtx1, up1Pro1.ID, []string{up2User.Email}) require.NoError(t, err) require.Len(t, addedUsers, 1) require.Contains(t, addedUsers, up2User) // Adding members to someone else project should not work - addedUsers, err = service.AddProjectMembers(authCtx1, up2Pro1.ID, []string{up2User.Email}) + addedUsers, err = service.AddProjectMembers(userCtx1, up2Pro1.ID, []string{up2User.Email}) require.Error(t, err) require.Nil(t, addedUsers) }) t.Run("TestGetProjectMembers", func(t *testing.T) { // Getting the project members of an own project that one is a part of should work - userPage, err := service.GetProjectMembers(authCtx1, up1Pro1.ID, console.ProjectMembersCursor{Page: 1, Limit: 10}) + userPage, err := service.GetProjectMembers(userCtx1, up1Pro1.ID, console.ProjectMembersCursor{Page: 1, Limit: 10}) require.NoError(t, err) require.Len(t, userPage.ProjectMembers, 2) // Getting the project members of a foreign project that one is a part of should work - userPage, err = service.GetProjectMembers(authCtx2, up1Pro1.ID, console.ProjectMembersCursor{Page: 1, Limit: 10}) + userPage, err = service.GetProjectMembers(userCtx2, up1Pro1.ID, console.ProjectMembersCursor{Page: 1, Limit: 10}) require.NoError(t, err) require.Len(t, userPage.ProjectMembers, 2) // Getting the project members of a foreign project that one is not a part of should not work - userPage, err = service.GetProjectMembers(authCtx1, up2Pro1.ID, console.ProjectMembersCursor{Page: 1, Limit: 10}) + userPage, err = service.GetProjectMembers(userCtx1, up2Pro1.ID, console.ProjectMembersCursor{Page: 1, Limit: 10}) require.Error(t, err) require.Nil(t, userPage) }) t.Run("TestDeleteProjectMembers", func(t *testing.T) { // Deleting project members of an own project should work - err := service.DeleteProjectMembers(authCtx1, up1Pro1.ID, []string{up2User.Email}) + err := service.DeleteProjectMembers(userCtx1, up1Pro1.ID, []string{up2User.Email}) require.NoError(t, err) // Deleting Project members of someone else project should not work - err = service.DeleteProjectMembers(authCtx1, up2Pro1.ID, []string{up2User.Email}) + err = service.DeleteProjectMembers(userCtx1, up2Pro1.ID, []string{up2User.Email}) require.Error(t, err) }) t.Run("TestDeleteProject", func(t *testing.T) { // Deleting the own project should not work before deleting the API-Key - err := service.DeleteProject(authCtx1, up1Pro1.ID) + err := service.DeleteProject(userCtx1, up1Pro1.ID) require.Error(t, err) - keys, err := service.GetAPIKeys(authCtx1, up1Pro1.ID, console.APIKeyCursor{Page: 1, Limit: 10}) + keys, err := service.GetAPIKeys(userCtx1, up1Pro1.ID, console.APIKeyCursor{Page: 1, Limit: 10}) require.NoError(t, err) require.Len(t, keys.APIKeys, 1) - err = service.DeleteAPIKeys(authCtx1, []uuid.UUID{keys.APIKeys[0].ID}) + err = service.DeleteAPIKeys(userCtx1, []uuid.UUID{keys.APIKeys[0].ID}) require.NoError(t, err) // Deleting the own project should now work - err = service.DeleteProject(authCtx1, up1Pro1.ID) + err = service.DeleteProject(userCtx1, up1Pro1.ID) require.NoError(t, err) // Deleting someone else project should not work - err = service.DeleteProject(authCtx1, up2Pro1.ID) + err = service.DeleteProject(userCtx1, up2Pro1.ID) require.Error(t, err) err = planet.Uplinks[1].CreateBucket(ctx, sat, "testbucket") require.NoError(t, err) // deleting a project with a bucket should fail - err = service.DeleteProject(authCtx2, up2Pro1.ID) + err = service.DeleteProject(userCtx2, up2Pro1.ID) require.Error(t, err) require.Equal(t, "console service: project usage: some buckets still exist", err.Error()) }) @@ -233,14 +233,14 @@ func TestService(t *testing.T) { t.Run("TestChangeEmail", func(t *testing.T) { const newEmail = "newEmail@example.com" - err = service.ChangeEmail(authCtx2, newEmail) + err = service.ChangeEmail(userCtx2, newEmail) require.NoError(t, err) - user, _, err := service.GetUserByEmailWithUnverified(authCtx2, newEmail) + user, _, err := service.GetUserByEmailWithUnverified(userCtx2, newEmail) require.NoError(t, err) require.Equal(t, newEmail, user.Email) - err = service.ChangeEmail(authCtx2, newEmail) + err = service.ChangeEmail(userCtx2, newEmail) require.Error(t, err) }) @@ -257,19 +257,19 @@ func TestService(t *testing.T) { ProjectID: up2Pro1.ID, } - _, err := sat.API.Buckets.Service.CreateBucket(authCtx2, bucket1) + _, err := sat.API.Buckets.Service.CreateBucket(userCtx2, bucket1) require.NoError(t, err) - _, err = sat.API.Buckets.Service.CreateBucket(authCtx2, bucket2) + _, err = sat.API.Buckets.Service.CreateBucket(userCtx2, bucket2) require.NoError(t, err) - bucketNames, err := service.GetAllBucketNames(authCtx2, up2Pro1.ID) + bucketNames, err := service.GetAllBucketNames(userCtx2, up2Pro1.ID) require.NoError(t, err) require.Equal(t, bucket1.Name, bucketNames[0]) require.Equal(t, bucket2.Name, bucketNames[1]) // Getting someone else buckets should not work - bucketsForUnauthorizedUser, err := service.GetAllBucketNames(authCtx1, up2Pro1.ID) + bucketsForUnauthorizedUser, err := service.GetAllBucketNames(userCtx1, up2Pro1.ID) require.Error(t, err) require.Nil(t, bucketsForUnauthorizedUser) }) @@ -295,10 +295,10 @@ func TestService(t *testing.T) { require.NotNil(t, info) // Deleting someone else api keys should not work - err = service.DeleteAPIKeyByNameAndProjectID(authCtx1, apikey.Name, up2Pro1.ID) + err = service.DeleteAPIKeyByNameAndProjectID(userCtx1, apikey.Name, up2Pro1.ID) require.Error(t, err) - err = service.DeleteAPIKeyByNameAndProjectID(authCtx2, apikey.Name, up2Pro1.ID) + err = service.DeleteAPIKeyByNameAndProjectID(userCtx2, apikey.Name, up2Pro1.ID) require.NoError(t, err) info, err = sat.DB.Console().APIKeys().Get(ctx, createdKey.ID) @@ -351,11 +351,11 @@ func TestPaidTier(t *testing.T) { require.NoError(t, err) require.False(t, user.PaidTier) - authCtx, err := sat.AuthenticatedContext(ctx, user.ID) + userCtx, err := sat.UserContext(ctx, user.ID) require.NoError(t, err) // add a credit card to the user - err = service.Payments().AddCreditCard(authCtx, "test-cc-token") + err = service.Payments().AddCreditCard(userCtx, "test-cc-token") require.NoError(t, err) // expect user to be in paid tier @@ -365,18 +365,18 @@ func TestPaidTier(t *testing.T) { require.Equal(t, usageConfig.Project.Paid, user.ProjectLimit) // update auth ctx - authCtx, err = sat.AuthenticatedContext(ctx, user.ID) + userCtx, err = sat.UserContext(ctx, user.ID) require.NoError(t, err) // expect project to be migrated to paid tier usage limits - proj1, err = service.GetProject(authCtx, proj1.ID) + proj1, err = service.GetProject(userCtx, proj1.ID) require.NoError(t, err) require.Equal(t, usageConfig.Storage.Paid, *proj1.StorageLimit) require.Equal(t, usageConfig.Bandwidth.Paid, *proj1.BandwidthLimit) require.Equal(t, usageConfig.Segment.Paid, *proj1.SegmentLimit) // expect new project to be created with paid tier usage limits - proj2, err := service.CreateProject(authCtx, console.ProjectInfo{Name: "Project 2"}) + proj2, err := service.CreateProject(userCtx, console.ProjectInfo{Name: "Project 2"}) require.NoError(t, err) require.Equal(t, usageConfig.Storage.Paid, *proj2.StorageLimit) }) @@ -395,22 +395,22 @@ func TestMFA(t *testing.T) { }, 1) require.NoError(t, err) - getNewAuthorization := func() (context.Context, console.Authorization) { - authCtx, err := sat.AuthenticatedContext(ctx, user.ID) + updateContext := func() (context.Context, *console.User) { + userCtx, err := sat.UserContext(ctx, user.ID) require.NoError(t, err) - auth, err := console.GetAuth(authCtx) + user, err := console.GetUser(userCtx) require.NoError(t, err) - return authCtx, auth + return userCtx, user } - authCtx, auth := getNewAuthorization() + userCtx, user := updateContext() var key string t.Run("TestResetMFASecretKey", func(t *testing.T) { - key, err = service.ResetMFASecretKey(authCtx) + key, err = service.ResetMFASecretKey(userCtx) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - require.NotEmpty(t, auth.User.MFASecretKey) + _, user := updateContext() + require.NotEmpty(t, user.MFASecretKey) }) t.Run("TestEnableUserMFABadPasscode", func(t *testing.T) { @@ -418,15 +418,15 @@ func TestMFA(t *testing.T) { badCode, err := console.NewMFAPasscode(key, time.Time{}.Add(time.Hour)) require.NoError(t, err) - err = service.EnableUserMFA(authCtx, badCode, time.Time{}) + err = service.EnableUserMFA(userCtx, badCode, time.Time{}) require.True(t, console.ErrValidation.Has(err)) - authCtx, auth = getNewAuthorization() - _, err = service.ResetMFARecoveryCodes(authCtx) + userCtx, _ = updateContext() + _, err = service.ResetMFARecoveryCodes(userCtx) require.True(t, console.ErrUnauthorized.Has(err)) - authCtx, auth = getNewAuthorization() - require.False(t, auth.User.MFAEnabled) + _, user = updateContext() + require.False(t, user.MFAEnabled) }) t.Run("TestEnableUserMFAGoodPasscode", func(t *testing.T) { @@ -434,13 +434,13 @@ func TestMFA(t *testing.T) { goodCode, err := console.NewMFAPasscode(key, time.Time{}) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - err = service.EnableUserMFA(authCtx, goodCode, time.Time{}) + userCtx, _ = updateContext() + err = service.EnableUserMFA(userCtx, goodCode, time.Time{}) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - require.True(t, auth.User.MFAEnabled) - require.Equal(t, auth.User.MFASecretKey, key) + _, user = updateContext() + require.True(t, user.MFAEnabled) + require.Equal(t, user.MFASecretKey, key) }) t.Run("TestMFAGetToken", func(t *testing.T) { @@ -471,13 +471,13 @@ func TestMFA(t *testing.T) { }) t.Run("TestMFARecoveryCodes", func(t *testing.T) { - _, err = service.ResetMFARecoveryCodes(authCtx) + _, err = service.ResetMFARecoveryCodes(userCtx) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - require.Len(t, auth.User.MFARecoveryCodes, console.MFARecoveryCodeCount) + _, user = updateContext() + require.Len(t, user.MFARecoveryCodes, console.MFARecoveryCodeCount) - for _, code := range auth.User.MFARecoveryCodes { + for _, code := range user.MFARecoveryCodes { // Ensure code is of the form XXXX-XXXX-XXXX where X is A-Z or 0-9. require.Regexp(t, "^([A-Z0-9]{4})((-[A-Z0-9]{4})){2}$", code) @@ -492,10 +492,11 @@ func TestMFA(t *testing.T) { require.True(t, console.ErrMFARecoveryCode.Has(err)) require.Empty(t, token) - authCtx, auth = getNewAuthorization() + _, user = updateContext() } - _, err = service.ResetMFARecoveryCodes(authCtx) + userCtx, _ = updateContext() + _, err = service.ResetMFARecoveryCodes(userCtx) require.NoError(t, err) }) @@ -504,14 +505,14 @@ func TestMFA(t *testing.T) { badCode, err := console.NewMFAPasscode(key, time.Time{}.Add(time.Hour)) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - err = service.DisableUserMFA(authCtx, badCode, time.Time{}, "") + userCtx, _ = updateContext() + err = service.DisableUserMFA(userCtx, badCode, time.Time{}, "") require.True(t, console.ErrValidation.Has(err)) - authCtx, auth = getNewAuthorization() - require.True(t, auth.User.MFAEnabled) - require.NotEmpty(t, auth.User.MFASecretKey) - require.NotEmpty(t, auth.User.MFARecoveryCodes) + _, user = updateContext() + require.True(t, user.MFAEnabled) + require.NotEmpty(t, user.MFASecretKey) + require.NotEmpty(t, user.MFARecoveryCodes) }) t.Run("TestDisableUserMFAConflict", func(t *testing.T) { @@ -519,14 +520,14 @@ func TestMFA(t *testing.T) { goodCode, err := console.NewMFAPasscode(key, time.Time{}) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - err = service.DisableUserMFA(authCtx, goodCode, time.Time{}, auth.User.MFARecoveryCodes[0]) + userCtx, user = updateContext() + err = service.DisableUserMFA(userCtx, goodCode, time.Time{}, user.MFARecoveryCodes[0]) require.True(t, console.ErrMFAConflict.Has(err)) - authCtx, auth = getNewAuthorization() - require.True(t, auth.User.MFAEnabled) - require.NotEmpty(t, auth.User.MFASecretKey) - require.NotEmpty(t, auth.User.MFARecoveryCodes) + _, user = updateContext() + require.True(t, user.MFAEnabled) + require.NotEmpty(t, user.MFASecretKey) + require.NotEmpty(t, user.MFARecoveryCodes) }) t.Run("TestDisableUserMFAGoodPasscode", func(t *testing.T) { @@ -534,46 +535,46 @@ func TestMFA(t *testing.T) { goodCode, err := console.NewMFAPasscode(key, time.Time{}) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - err = service.DisableUserMFA(authCtx, goodCode, time.Time{}, "") + userCtx, _ = updateContext() + err = service.DisableUserMFA(userCtx, goodCode, time.Time{}, "") require.NoError(t, err) - authCtx, auth = getNewAuthorization() - require.False(t, auth.User.MFAEnabled) - require.Empty(t, auth.User.MFASecretKey) - require.Empty(t, auth.User.MFARecoveryCodes) + userCtx, user = updateContext() + require.False(t, user.MFAEnabled) + require.Empty(t, user.MFASecretKey) + require.Empty(t, user.MFARecoveryCodes) }) t.Run("TestDisableUserMFAGoodRecoveryCode", func(t *testing.T) { // Expect MFA-disabling attempt to succeed when providing valid recovery code. // Enable MFA - key, err = service.ResetMFASecretKey(authCtx) + key, err = service.ResetMFASecretKey(userCtx) require.NoError(t, err) goodCode, err := console.NewMFAPasscode(key, time.Time{}) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - err = service.EnableUserMFA(authCtx, goodCode, time.Time{}) + userCtx, _ = updateContext() + err = service.EnableUserMFA(userCtx, goodCode, time.Time{}) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - _, err = service.ResetMFARecoveryCodes(authCtx) + userCtx, _ = updateContext() + _, err = service.ResetMFARecoveryCodes(userCtx) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - require.True(t, auth.User.MFAEnabled) - require.NotEmpty(t, auth.User.MFASecretKey) - require.NotEmpty(t, auth.User.MFARecoveryCodes) + userCtx, user = updateContext() + require.True(t, user.MFAEnabled) + require.NotEmpty(t, user.MFASecretKey) + require.NotEmpty(t, user.MFARecoveryCodes) // Disable MFA - err = service.DisableUserMFA(authCtx, "", time.Time{}, auth.User.MFARecoveryCodes[0]) + err = service.DisableUserMFA(userCtx, "", time.Time{}, user.MFARecoveryCodes[0]) require.NoError(t, err) - authCtx, auth = getNewAuthorization() - require.False(t, auth.User.MFAEnabled) - require.Empty(t, auth.User.MFASecretKey) - require.Empty(t, auth.User.MFARecoveryCodes) + _, user = updateContext() + require.False(t, user.MFAEnabled) + require.Empty(t, user.MFASecretKey) + require.Empty(t, user.MFARecoveryCodes) }) }) } @@ -620,23 +621,18 @@ func TestResetPassword(t *testing.T) { token = getNewResetToken() // Enable MFA. - getNewAuthorization := func() (context.Context, console.Authorization) { - authCtx, err := sat.AuthenticatedContext(ctx, user.ID) - require.NoError(t, err) - auth, err := console.GetAuth(authCtx) - require.NoError(t, err) - return authCtx, auth - } - authCtx, _ := getNewAuthorization() - - key, err := service.ResetMFASecretKey(authCtx) + userCtx, err := sat.UserContext(ctx, user.ID) + require.NoError(t, err) + + key, err := service.ResetMFASecretKey(userCtx) + require.NoError(t, err) + userCtx, err = sat.UserContext(ctx, user.ID) require.NoError(t, err) - authCtx, auth := getNewAuthorization() passcode, err := console.NewMFAPasscode(key, token.CreatedAt) require.NoError(t, err) - err = service.EnableUserMFA(authCtx, passcode, token.CreatedAt) + err = service.EnableUserMFA(userCtx, passcode, token.CreatedAt) require.NoError(t, err) // Expect error when providing bad passcode. @@ -645,7 +641,7 @@ func TestResetPassword(t *testing.T) { err = service.ResetPassword(ctx, token.Secret.String(), newPass, badPasscode, "", token.CreatedAt) require.True(t, console.ErrMFAPasscode.Has(err)) - for _, recoveryCode := range auth.User.MFARecoveryCodes { + for _, recoveryCode := range user.MFARecoveryCodes { // Expect success when providing bad passcode and good recovery code. err = service.ResetPassword(ctx, token.Secret.String(), newPass, badPasscode, recoveryCode, token.CreatedAt) require.NoError(t, err) @@ -662,40 +658,6 @@ func TestResetPassword(t *testing.T) { }) } -// TestActivateAccountToken ensures that the token returned after activating an account can be used to authorize user activity. -// i.e. a user does not need to acquire an authorization separate from the account activation step. -func TestActivateAccountToken(t *testing.T) { - testplanet.Run(t, testplanet.Config{ - SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, - }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { - sat := planet.Satellites[0] - service := sat.API.Console.Service - - createUser := console.CreateUser{ - FullName: "Alice", - ShortName: "Alice", - Email: "alice@mail.test", - Password: "123a123", - } - - regToken, err := service.CreateRegToken(ctx, 1) - require.NoError(t, err) - - rootUser, err := service.CreateUser(ctx, createUser, regToken.Secret) - require.NoError(t, err) - - activationToken, err := service.GenerateActivationToken(ctx, rootUser.ID, rootUser.Email) - require.NoError(t, err) - - authToken, err := service.ActivateAccount(ctx, activationToken) - require.NoError(t, err) - - _, err = service.Authorize(consoleauth.WithAPIKey(ctx, []byte(authToken))) - require.NoError(t, err) - - }) -} - func TestRESTKeys(t *testing.T) { testplanet.Run(t, testplanet.Config{ SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 1, @@ -709,23 +671,23 @@ func TestRESTKeys(t *testing.T) { user, err := service.GetUser(ctx, proj1.OwnerID) require.NoError(t, err) - authCtx, err := sat.AuthenticatedContext(ctx, user.ID) + userCtx, err := sat.UserContext(ctx, user.ID) require.NoError(t, err) now := time.Now() expires := 5 * time.Hour - apiKey, expiresAt, err := service.CreateRESTKey(authCtx, expires) + apiKey, expiresAt, err := service.CreateRESTKey(userCtx, expires) require.NoError(t, err) require.NotEmpty(t, apiKey) require.True(t, expiresAt.After(now)) require.True(t, expiresAt.Before(now.Add(expires+time.Hour))) // test revocation - require.NoError(t, service.RevokeRESTKey(authCtx, apiKey)) + require.NoError(t, service.RevokeRESTKey(userCtx, apiKey)) // test revoke non existent key nonexistent := testrand.UUID() - err = service.RevokeRESTKey(authCtx, nonexistent.String()) + err = service.RevokeRESTKey(userCtx, nonexistent.String()) require.Error(t, err) }) } @@ -748,23 +710,23 @@ func TestLockAccount(t *testing.T) { user, err := sat.AddUser(ctx, newUser, 1) require.NoError(t, err) - getNewAuthorization := func() (context.Context, console.Authorization) { - authCtx, err := sat.AuthenticatedContext(ctx, user.ID) + updateContext := func() (context.Context, *console.User) { + userCtx, err := sat.UserContext(ctx, user.ID) require.NoError(t, err) - auth, err := console.GetAuth(authCtx) + user, err := console.GetUser(userCtx) require.NoError(t, err) - return authCtx, auth + return userCtx, user } - authCtx, _ := getNewAuthorization() - secret, err := service.ResetMFASecretKey(authCtx) + userCtx, _ := updateContext() + secret, err := service.ResetMFASecretKey(userCtx) require.NoError(t, err) goodCode0, err := console.NewMFAPasscode(secret, time.Time{}) require.NoError(t, err) - authCtx, _ = getNewAuthorization() - err = service.EnableUserMFA(authCtx, goodCode0, time.Time{}) + userCtx, _ = updateContext() + err = service.EnableUserMFA(userCtx, goodCode0, time.Time{}) require.NoError(t, err) now := time.Now() @@ -795,7 +757,7 @@ func TestLockAccount(t *testing.T) { } } - lockedUser, err := service.GetUser(authCtx, user.ID) + lockedUser, err := service.GetUser(userCtx, user.ID) require.NoError(t, err) require.True(t, lockedUser.FailedLoginCount == consoleConfig.LoginAttemptsWithoutPenalty) require.True(t, lockedUser.LoginLockoutExpiration.After(now)) @@ -803,10 +765,10 @@ func TestLockAccount(t *testing.T) { // lock account once again and check if lockout expiration time increased. expDuration := time.Duration(math.Pow(consoleConfig.FailedLoginPenalty, float64(lockedUser.FailedLoginCount-1))) * time.Minute lockoutExpDate := now.Add(expDuration) - err = service.UpdateUsersFailedLoginState(authCtx, lockedUser, lockoutExpDate) + err = service.UpdateUsersFailedLoginState(userCtx, lockedUser, lockoutExpDate) require.NoError(t, err) - lockedUser, err = service.GetUser(authCtx, user.ID) + lockedUser, err = service.GetUser(userCtx, user.ID) require.NoError(t, err) require.True(t, lockedUser.FailedLoginCount == consoleConfig.LoginAttemptsWithoutPenalty+1) @@ -815,7 +777,7 @@ func TestLockAccount(t *testing.T) { // unlock account by successful login lockedUser.LoginLockoutExpiration = now.Add(-time.Second) - err = usersDB.Update(authCtx, lockedUser) + err = usersDB.Update(userCtx, lockedUser) require.NoError(t, err) authUser.Password = newUser.FullName @@ -823,7 +785,7 @@ func TestLockAccount(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, token) - unlockedUser, err := service.GetUser(authCtx, user.ID) + unlockedUser, err := service.GetUser(userCtx, user.ID) require.NoError(t, err) require.Zero(t, unlockedUser.FailedLoginCount) @@ -839,7 +801,7 @@ func TestLockAccount(t *testing.T) { } } - lockedUser, err = service.GetUser(authCtx, user.ID) + lockedUser, err = service.GetUser(userCtx, user.ID) require.NoError(t, err) require.True(t, lockedUser.FailedLoginCount == consoleConfig.LoginAttemptsWithoutPenalty) require.True(t, lockedUser.LoginLockoutExpiration.After(now)) @@ -847,7 +809,7 @@ func TestLockAccount(t *testing.T) { // unlock account lockedUser.LoginLockoutExpiration = time.Time{} lockedUser.FailedLoginCount = 0 - err = usersDB.Update(authCtx, lockedUser) + err = usersDB.Update(userCtx, lockedUser) require.NoError(t, err) // check if user's account gets locked because of providing wrong mfa recovery code. @@ -863,7 +825,7 @@ func TestLockAccount(t *testing.T) { } } - lockedUser, err = service.GetUser(authCtx, user.ID) + lockedUser, err = service.GetUser(userCtx, user.ID) require.NoError(t, err) require.True(t, lockedUser.FailedLoginCount == consoleConfig.LoginAttemptsWithoutPenalty) require.True(t, lockedUser.LoginLockoutExpiration.After(now)) @@ -882,3 +844,43 @@ func TestWalletJsonMarshall(t *testing.T) { require.Contains(t, string(out), "\"balance\":100") } + +func TestSessionExpiration(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, + Reconfigure: testplanet.Reconfigure{ + Satellite: func(log *zap.Logger, index int, config *satellite.Config) { + config.Console.SessionDuration = time.Hour + }, + }, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + sat := planet.Satellites[0] + service := sat.API.Console.Service + + user, err := sat.AddUser(ctx, console.CreateUser{ + FullName: "Test User", + Email: "test@mail.test", + }, 1) + require.NoError(t, err) + + // Session should be added to DB after token request + token, err := service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName}) + require.NoError(t, err) + + _, err = service.TokenAuth(ctx, token, time.Now()) + require.NoError(t, err) + + sessionID, err := uuid.FromBytes(token.Payload) + require.NoError(t, err) + + _, err = sat.DB.Console().WebappSessions().GetBySessionID(ctx, sessionID) + require.NoError(t, err) + + // Session should be removed from DB after it has expired + _, err = service.TokenAuth(ctx, token, time.Now().Add(2*time.Hour)) + require.True(t, console.ErrTokenExpiration.Has(err)) + + _, err = sat.DB.Console().WebappSessions().GetBySessionID(ctx, sessionID) + require.ErrorIs(t, sql.ErrNoRows, err) + }) +} diff --git a/satellite/console/users.go b/satellite/console/users.go index f5b712c00..f7e80e806 100644 --- a/satellite/console/users.go +++ b/satellite/console/users.go @@ -117,6 +117,8 @@ type AuthUser struct { Password string `json:"password"` MFAPasscode string `json:"mfaPasscode"` MFARecoveryCode string `json:"mfaRecoveryCode"` + IP string `json:"-"` + UserAgent string `json:"-"` } // UserStatus - is used to indicate status of the users account. @@ -193,3 +195,34 @@ type ResponseUser struct { MFAEnabled bool `json:"isMFAEnabled"` MFARecoveryCodeCount int `json:"mfaRecoveryCodeCount"` } + +// key is a context value key type. +type key int + +// userKey is context key for User. +const userKey key = 0 + +// WithUser creates new context with User. +func WithUser(ctx context.Context, user *User) context.Context { + return context.WithValue(ctx, userKey, user) +} + +// WithUserFailure creates new context with User failure. +func WithUserFailure(ctx context.Context, err error) context.Context { + return context.WithValue(ctx, userKey, err) +} + +// GetUser gets User from context. +func GetUser(ctx context.Context) (*User, error) { + value := ctx.Value(userKey) + + if user, ok := value.(*User); ok { + return user, nil + } + + if err, ok := value.(error); ok { + return nil, Error.Wrap(err) + } + + return nil, Error.New("user is not in context") +} diff --git a/satellite/metainfo/attribution_test.go b/satellite/metainfo/attribution_test.go index a5f9677a8..6c995e33f 100644 --- a/satellite/metainfo/attribution_test.go +++ b/satellite/metainfo/attribution_test.go @@ -103,10 +103,10 @@ func TestBucketAttribution(t *testing.T) { require.NoError(t, err) createBucketAndCheckAttribution := func(userID uuid.UUID, apiKeyName, bucketName string) { - authCtx, err := satellite.AuthenticatedContext(ctx, userID) + userCtx, err := satellite.UserContext(ctx, userID) require.NoError(t, err, errTag) - _, apiKeyInfo, err := satellite.API.Console.Service.CreateAPIKey(authCtx, satProject.ID, apiKeyName) + _, apiKeyInfo, err := satellite.API.Console.Service.CreateAPIKey(userCtx, satProject.ID, apiKeyName) require.NoError(t, err, errTag) config := uplink.Config{ @@ -168,10 +168,10 @@ func TestQueryAttribution(t *testing.T) { satProject, err := satellite.AddProject(ctx, user.ID, "test") require.NoError(t, err) - authCtx, err := satellite.AuthenticatedContext(ctx, user.ID) + userCtx, err := satellite.UserContext(ctx, user.ID) require.NoError(t, err) - _, apiKeyInfo, err := satellite.API.Console.Service.CreateAPIKey(authCtx, satProject.ID, "root") + _, apiKeyInfo, err := satellite.API.Console.Service.CreateAPIKey(userCtx, satProject.ID, "root") require.NoError(t, err) access, err := uplink.RequestAccessWithPassphrase(ctx, satellite.NodeURL().String(), apiKeyInfo.Serialize(), "mypassphrase") diff --git a/satellite/oidc/endpoint.go b/satellite/oidc/endpoint.go index 67d721dc4..7cc226115 100644 --- a/satellite/oidc/endpoint.go +++ b/satellite/oidc/endpoint.go @@ -53,12 +53,12 @@ func NewEndpoint( svr := server.NewDefaultServer(manager) svr.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { - auth, err := console.GetAuth(r.Context()) + user, err := console.GetUser(r.Context()) if err != nil { return "", console.ErrUnauthorized.Wrap(err) } - return auth.User.ID.String(), nil + return user.ID.String(), nil }) // externalAddress _should_ end with a '/' suffix based on the calling path diff --git a/satellite/oidc/integration_test.go b/satellite/oidc/integration_test.go index e6bd2e69a..1fb3d77ea 100644 --- a/satellite/oidc/integration_test.go +++ b/satellite/oidc/integration_test.go @@ -30,7 +30,6 @@ import ( "storj.io/storj/private/testplanet" "storj.io/storj/satellite" "storj.io/storj/satellite/console" - "storj.io/storj/satellite/console/consoleauth" "storj.io/storj/satellite/oidc" "storj.io/uplink" ) @@ -122,18 +121,15 @@ func TestOIDC(t *testing.T) { activationToken, err := sat.API.Console.Service.GenerateActivationToken(ctx, user.ID, user.Email) require.NoError(t, err) - consoleToken, err := sat.API.Console.Service.ActivateAccount(ctx, activationToken) + user, err = sat.API.Console.Service.ActivateAccount(ctx, activationToken) + require.NoError(t, err) + + sessionToken, err := sat.API.Console.Service.GenerateSessionToken(ctx, user.ID, user.Email, "", "") require.NoError(t, err) // Set up a test project and bucket - authed := console.WithAuth(ctx, console.Authorization{ - User: *user, - Claims: consoleauth.Claims{ - ID: user.ID, - Email: user.Email, - }, - }) + authed := console.WithUser(ctx, user) project, err := sat.API.Console.Service.CreateProject(authed, console.ProjectInfo{ Name: "test", @@ -250,7 +246,7 @@ func TestOIDC(t *testing.T) { { body := strings.NewReader(consent.Encode()) - send(t, body, &token, http.StatusOK, authEndpoint, http.MethodPost, consoleToken, "application/x-www-form-urlencoded") + send(t, body, &token, http.StatusOK, authEndpoint, http.MethodPost, sessionToken.String(), "application/x-www-form-urlencoded") } require.Equal(t, "Bearer", token.TokenType) diff --git a/satellite/oidc/oauth_generates.go b/satellite/oidc/oauth_generates.go index 0438e20b3..4c07be368 100644 --- a/satellite/oidc/oauth_generates.go +++ b/satellite/oidc/oauth_generates.go @@ -15,7 +15,6 @@ import ( "storj.io/common/macaroon" "storj.io/common/uuid" "storj.io/storj/satellite/console" - "storj.io/storj/satellite/console/consoleauth" ) // UUIDAuthorizeGenerate generates an auth code using Storj's uuid. @@ -65,13 +64,7 @@ func (a *MacaroonAccessGenerate) apiKeyForProject(ctx context.Context, data *oau return nil, err } - ctx = console.WithAuth(ctx, console.Authorization{ - User: *user, - Claims: consoleauth.Claims{ - ID: user.ID, - Email: user.Email, - }, - }) + ctx = console.WithUser(ctx, user) oauthClient := data.Client.(OAuthClient) name := oauthClient.AppName + " / " + oauthClient.ID.String() diff --git a/scripts/testdata/satellite-config.yaml.lock b/scripts/testdata/satellite-config.yaml.lock index 4793d0ccb..f2fba104f 100755 --- a/scripts/testdata/satellite-config.yaml.lock +++ b/scripts/testdata/satellite-config.yaml.lock @@ -85,7 +85,7 @@ compensation.rates.put-tb: "0" # comma separated monthly withheld percentage rates compensation.withheld-percents: 75,75,75,50,50,50,25,25,25,0,0,0,0,0,0 -# expiration time for auth tokens, account recovery tokens, and activation tokens +# expiration time for account recovery and activation tokens # console-auth.token-expiration-time: 24h0m0s # url link for account activation redirect @@ -244,6 +244,9 @@ compensation.withheld-percents: 75,75,75,50,50,50,25,25,25,0,0,0,0,0,0 # used to communicate with web crawlers and other web robots # console.seo: "User-agent: *\nDisallow: \nDisallow: /cgi-bin/" +# duration a session is valid for +# console.session-duration: 168h0m0s + # path to static resources # console.static-dir: ""