From 7530a3a83da655d56546baa04c9d51099d527e26 Mon Sep 17 00:00:00 2001 From: Moby von Briesen Date: Mon, 12 Jun 2023 16:25:53 -0400 Subject: [PATCH] satellite/console: add CORS middleware to satellite UI and API Add some basic handling to set cross-origin resource sharing headers for the satellite UI app handler as well as API endpoints used by the satellite UI. This change also removes some no-longer-necessary CORS functionality on the account registration endpoint. Previously, these CORS headers were used to enable account registration cross-origin from www.storj.io. However, we have since removed the ability to sign up via www.storj.io. With these changes, browsers will prevent any requests to the affected endpoints, unless the browser is making the request from the same host as the satellite. see https://github.com/storj/storj-private/issues/242 Change-Id: Ifd98be4a142a2e61e26392d97242d911e051fe8a --- .../console/consoleweb/consoleapi/auth.go | 21 +-- .../consoleweb/consoleapi/auth_test.go | 97 ------------ satellite/console/consoleweb/server.go | 146 +++++++++++------- 3 files changed, 91 insertions(+), 173 deletions(-) diff --git a/satellite/console/consoleweb/consoleapi/auth.go b/satellite/console/consoleweb/consoleapi/auth.go index 298615f33..f49f1279f 100644 --- a/satellite/console/consoleweb/consoleapi/auth.go +++ b/satellite/console/consoleweb/consoleapi/auth.go @@ -31,12 +31,6 @@ var ( // errNotImplemented is the error value used by handlers of this package to // response with status Not Implemented. errNotImplemented = errs.New("not implemented") - - // supportedCORSOrigins allows us to support visitors who sign up from the website. - supportedCORSOrigins = map[string]bool{ - "https://storj.io": true, - "https://www.storj.io": true, - } ) // Auth is an api controller that exposes all auth functionality. @@ -210,19 +204,6 @@ func (a *Auth) Register(w http.ResponseWriter, r *http.Request) { var err error defer mon.Task()(&ctx)(&err) - origin := r.Header.Get("Origin") - if supportedCORSOrigins[origin] { - // we should send the exact origin back, rather than a wildcard - w.Header().Set("Access-Control-Allow-Origin", origin) - w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization") - } - - // OPTIONS is a pre-flight check for cross-origin (CORS) permissions - if r.Method == "OPTIONS" { - return - } - var registerData struct { FullName string `json:"fullName"` ShortName string `json:"shortName"` @@ -352,7 +333,7 @@ func (a *Auth) Register(w http.ResponseWriter, r *http.Request) { FullName: user.FullName, Email: user.Email, Type: analytics.Personal, - OriginHeader: origin, + OriginHeader: r.Header.Get("Origin"), Referrer: referrer, HubspotUTK: hubspotUTK, UserAgent: string(user.UserAgent), diff --git a/satellite/console/consoleweb/consoleapi/auth_test.go b/satellite/console/consoleweb/consoleapi/auth_test.go index 2c2991234..240446532 100644 --- a/satellite/console/consoleweb/consoleapi/auth_test.go +++ b/satellite/console/consoleweb/consoleapi/auth_test.go @@ -107,103 +107,6 @@ func TestAuth_Register(t *testing.T) { }) } -func TestAuth_Register_CORS(t *testing.T) { - testplanet.Run(t, testplanet.Config{ - SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, - Reconfigure: testplanet.Reconfigure{ - Satellite: func(log *zap.Logger, index int, config *satellite.Config) { - config.Console.OpenRegistrationEnabled = true - config.Console.RateLimit.Burst = 10 - config.Mail.AuthType = "nomail" - }, - }, - }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { - email := "user@test.com" - fullName := "testuser" - jsonBody := []byte(fmt.Sprintf(`{"email":"%s","fullName":"%s","password":"abc123","shortName":"test"}`, email, fullName)) - url := planet.Satellites[0].ConsoleURL() + "/api/v0/auth/register" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - - // 1. OPTIONS request - // 1.1 CORS headers should not be set with origin other than storj.io or www.storj.io - req.Header.Set("Origin", "https://someexternalorigin.test") - req.Method = http.MethodOptions - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - require.Equal(t, "", resp.Header.Get("Access-Control-Allow-Origin")) - require.Equal(t, "", resp.Header.Get("Access-Control-Allow-Methods")) - require.Equal(t, "", resp.Header.Get("Access-Control-Allow-Headers")) - require.NoError(t, resp.Body.Close()) - - // 1.2 CORS headers should be set with a domain of storj.io - req.Header.Set("Origin", "https://storj.io") - resp, err = http.DefaultClient.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - require.Equal(t, "https://storj.io", resp.Header.Get("Access-Control-Allow-Origin")) - require.Equal(t, "POST, OPTIONS", resp.Header.Get("Access-Control-Allow-Methods")) - allowedHeaders := strings.Split(resp.Header.Get("Access-Control-Allow-Headers"), ", ") - require.ElementsMatch(t, allowedHeaders, []string{ - "Content-Type", - "Content-Length", - "Accept", - "Accept-Encoding", - "X-CSRF-Token", - "Authorization", - }) - require.NoError(t, resp.Body.Close()) - - // 1.3 CORS headers should be set with a domain of www.storj.io - req.Header.Set("Origin", "https://www.storj.io") - resp, err = http.DefaultClient.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - require.Equal(t, "https://www.storj.io", resp.Header.Get("Access-Control-Allow-Origin")) - require.Equal(t, "POST, OPTIONS", resp.Header.Get("Access-Control-Allow-Methods")) - allowedHeaders = strings.Split(resp.Header.Get("Access-Control-Allow-Headers"), ", ") - require.ElementsMatch(t, allowedHeaders, []string{ - "Content-Type", - "Content-Length", - "Accept", - "Accept-Encoding", - "X-CSRF-Token", - "Authorization", - }) - require.NoError(t, resp.Body.Close()) - - // 2. POST request with origin www.storj.io - req.Method = http.MethodPost - resp, err = http.DefaultClient.Do(req) - require.NoError(t, err) - defer func() { - err = resp.Body.Close() - require.NoError(t, err) - }() - require.Equal(t, http.StatusOK, resp.StatusCode) - require.Equal(t, "https://www.storj.io", resp.Header.Get("Access-Control-Allow-Origin")) - require.Equal(t, "POST, OPTIONS", resp.Header.Get("Access-Control-Allow-Methods")) - allowedHeaders = strings.Split(resp.Header.Get("Access-Control-Allow-Headers"), ", ") - require.ElementsMatch(t, allowedHeaders, []string{ - "Content-Type", - "Content-Length", - "Accept", - "Accept-Encoding", - "X-CSRF-Token", - "Authorization", - }) - - require.Len(t, planet.Satellites, 1) - // this works only because we configured 'nomail' above. Mail send simulator won't click to activation link. - _, users, err := planet.Satellites[0].API.Console.Service.GetUserByEmailWithUnverified(ctx, email) - require.NoError(t, err) - require.Len(t, users, 1) - require.Equal(t, fullName, users[0].FullName) - }) -} - func TestDeleteAccount(t *testing.T) { ctx := testcontext.New(t) log := testplanet.NewLogger(t) diff --git a/satellite/console/consoleweb/server.go b/satellite/console/consoleweb/server.go index 05deffe5b..6aa270f62 100644 --- a/satellite/console/consoleweb/server.go +++ b/satellite/console/consoleweb/server.go @@ -132,6 +132,7 @@ type Server struct { listener net.Listener server http.Server + router *mux.Router cookieAuth *consolewebauth.CookieAuth ipRateLimiter *web.RateLimiter userIDRateLimiter *web.RateLimiter @@ -239,6 +240,7 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc } router := mux.NewRouter() + server.router = router // N.B. This middleware has to be the first one because it has to be called // the earliest in the HTTP chain. router.Use(newTraceRequestMiddleware(logger, router)) @@ -252,95 +254,104 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc consoleapi.NewUserManagement(logger, mon, server.service, router, &apiAuth{&server}) } - router.HandleFunc("/api/v0/config", server.frontendConfigHandler) + router.Handle("/api/v0/config", server.withCORS(http.HandlerFunc(server.frontendConfigHandler))) - router.Handle("/api/v0/graphql", server.withAuth(http.HandlerFunc(server.graphqlHandler))) + router.Handle("/api/v0/graphql", server.withCORS(server.withAuth(http.HandlerFunc(server.graphqlHandler)))) router.HandleFunc("/registrationToken/", server.createRegistrationTokenHandler) router.HandleFunc("/robots.txt", server.seoHandler) projectsController := consoleapi.NewProjects(logger, service) projectsRouter := router.PathPrefix("/api/v0/projects").Subrouter() - projectsRouter.Handle("/{id}/salt", server.withAuth(http.HandlerFunc(projectsController.GetSalt))).Methods(http.MethodGet) - projectsRouter.Handle("/{id}/invite", server.withAuth(http.HandlerFunc(projectsController.InviteUsers))).Methods(http.MethodPost) - projectsRouter.Handle("/{id}/invite-link", server.withAuth(http.HandlerFunc(projectsController.GetInviteLink))).Methods(http.MethodGet) - projectsRouter.Handle("/invitations", server.withAuth(http.HandlerFunc(projectsController.GetUserInvitations))).Methods(http.MethodGet) - projectsRouter.Handle("/invitations/{id}/respond", server.withAuth(http.HandlerFunc(projectsController.RespondToInvitation))).Methods(http.MethodPost) + projectsRouter.Use(server.withCORS) + projectsRouter.Use(server.withAuth) + projectsRouter.Handle("/{id}/salt", http.HandlerFunc(projectsController.GetSalt)).Methods(http.MethodGet, http.MethodOptions) + projectsRouter.Handle("/{id}/invite", http.HandlerFunc(projectsController.InviteUsers)).Methods(http.MethodPost, http.MethodOptions) + projectsRouter.Handle("/{id}/invite-link", http.HandlerFunc(projectsController.GetInviteLink)).Methods(http.MethodGet, http.MethodOptions) + projectsRouter.Handle("/invitations", http.HandlerFunc(projectsController.GetUserInvitations)).Methods(http.MethodGet, http.MethodOptions) + projectsRouter.Handle("/invitations/{id}/respond", http.HandlerFunc(projectsController.RespondToInvitation)).Methods(http.MethodPost, http.MethodOptions) usageLimitsController := consoleapi.NewUsageLimits(logger, service) - projectsRouter.Handle("/{id}/usage-limits", server.withAuth(http.HandlerFunc(usageLimitsController.ProjectUsageLimits))).Methods(http.MethodGet) - projectsRouter.Handle("/usage-limits", server.withAuth(http.HandlerFunc(usageLimitsController.TotalUsageLimits))).Methods(http.MethodGet) - projectsRouter.Handle("/{id}/daily-usage", server.withAuth(http.HandlerFunc(usageLimitsController.DailyUsage))).Methods(http.MethodGet) + projectsRouter.Handle("/{id}/usage-limits", http.HandlerFunc(usageLimitsController.ProjectUsageLimits)).Methods(http.MethodGet, http.MethodOptions) + projectsRouter.Handle("/usage-limits", http.HandlerFunc(usageLimitsController.TotalUsageLimits)).Methods(http.MethodGet, http.MethodOptions) + projectsRouter.Handle("/{id}/daily-usage", http.HandlerFunc(usageLimitsController.DailyUsage)).Methods(http.MethodGet, http.MethodOptions) authController := consoleapi.NewAuth(logger, service, accountFreezeService, mailService, server.cookieAuth, server.analytics, config.SatelliteName, server.config.ExternalAddress, config.LetUsKnowURL, config.TermsAndConditionsURL, config.ContactInfoURL, config.GeneralRequestURL) authRouter := router.PathPrefix("/api/v0/auth").Subrouter() - authRouter.Handle("/account", server.withAuth(http.HandlerFunc(authController.GetAccount))).Methods(http.MethodGet) - authRouter.Handle("/account", server.withAuth(http.HandlerFunc(authController.UpdateAccount))).Methods(http.MethodPatch) - authRouter.Handle("/account/change-email", server.withAuth(http.HandlerFunc(authController.ChangeEmail))).Methods(http.MethodPost) - authRouter.Handle("/account/change-password", server.withAuth(server.userIDRateLimiter.Limit(http.HandlerFunc(authController.ChangePassword)))).Methods(http.MethodPost) - authRouter.Handle("/account/freezestatus", server.withAuth(http.HandlerFunc(authController.GetFreezeStatus))).Methods(http.MethodGet) - authRouter.Handle("/account/settings", server.withAuth(http.HandlerFunc(authController.GetUserSettings))).Methods(http.MethodGet) - authRouter.Handle("/account/settings", server.withAuth(http.HandlerFunc(authController.SetUserSettings))).Methods(http.MethodPatch) - authRouter.Handle("/account/onboarding", server.withAuth(http.HandlerFunc(authController.SetOnboardingStatus))).Methods(http.MethodPatch) - authRouter.Handle("/account/delete", server.withAuth(http.HandlerFunc(authController.DeleteAccount))).Methods(http.MethodPost) - authRouter.Handle("/mfa/enable", server.withAuth(http.HandlerFunc(authController.EnableUserMFA))).Methods(http.MethodPost) - authRouter.Handle("/mfa/disable", server.withAuth(http.HandlerFunc(authController.DisableUserMFA))).Methods(http.MethodPost) - authRouter.Handle("/mfa/generate-secret-key", server.withAuth(http.HandlerFunc(authController.GenerateMFASecretKey))).Methods(http.MethodPost) - authRouter.Handle("/mfa/generate-recovery-codes", server.withAuth(http.HandlerFunc(authController.GenerateMFARecoveryCodes))).Methods(http.MethodPost) - authRouter.Handle("/logout", server.withAuth(http.HandlerFunc(authController.Logout))).Methods(http.MethodPost) - authRouter.Handle("/token", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Token))).Methods(http.MethodPost) - authRouter.Handle("/token-by-api-key", server.ipRateLimiter.Limit(http.HandlerFunc(authController.TokenByAPIKey))).Methods(http.MethodPost) + authRouter.Use(server.withCORS) + authRouter.Handle("/account", server.withAuth(http.HandlerFunc(authController.GetAccount))).Methods(http.MethodGet, http.MethodOptions) + authRouter.Handle("/account", server.withAuth(http.HandlerFunc(authController.UpdateAccount))).Methods(http.MethodPatch, http.MethodOptions) + authRouter.Handle("/account/change-email", server.withAuth(http.HandlerFunc(authController.ChangeEmail))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/account/change-password", server.withAuth(server.userIDRateLimiter.Limit(http.HandlerFunc(authController.ChangePassword)))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/account/freezestatus", server.withAuth(http.HandlerFunc(authController.GetFreezeStatus))).Methods(http.MethodGet, http.MethodOptions) + authRouter.Handle("/account/settings", server.withAuth(http.HandlerFunc(authController.GetUserSettings))).Methods(http.MethodGet, http.MethodOptions) + authRouter.Handle("/account/settings", server.withAuth(http.HandlerFunc(authController.SetUserSettings))).Methods(http.MethodPatch, http.MethodOptions) + authRouter.Handle("/account/onboarding", server.withAuth(http.HandlerFunc(authController.SetOnboardingStatus))).Methods(http.MethodPatch, http.MethodOptions) + authRouter.Handle("/account/delete", server.withAuth(http.HandlerFunc(authController.DeleteAccount))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/mfa/enable", server.withAuth(http.HandlerFunc(authController.EnableUserMFA))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/mfa/disable", server.withAuth(http.HandlerFunc(authController.DisableUserMFA))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/mfa/generate-secret-key", server.withAuth(http.HandlerFunc(authController.GenerateMFASecretKey))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/mfa/generate-recovery-codes", server.withAuth(http.HandlerFunc(authController.GenerateMFARecoveryCodes))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/logout", server.withAuth(http.HandlerFunc(authController.Logout))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/token", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Token))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/token-by-api-key", server.ipRateLimiter.Limit(http.HandlerFunc(authController.TokenByAPIKey))).Methods(http.MethodPost, http.MethodOptions) authRouter.Handle("/register", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Register))).Methods(http.MethodPost, http.MethodOptions) - authRouter.Handle("/forgot-password", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ForgotPassword))).Methods(http.MethodPost) - authRouter.Handle("/resend-email/{email}", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResendEmail))).Methods(http.MethodPost) - authRouter.Handle("/reset-password", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResetPassword))).Methods(http.MethodPost) - authRouter.Handle("/refresh-session", server.withAuth(http.HandlerFunc(authController.RefreshSession))).Methods(http.MethodPost) + authRouter.Handle("/forgot-password", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ForgotPassword))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/resend-email/{email}", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResendEmail))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/reset-password", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResetPassword))).Methods(http.MethodPost, http.MethodOptions) + authRouter.Handle("/refresh-session", server.withAuth(http.HandlerFunc(authController.RefreshSession))).Methods(http.MethodPost, http.MethodOptions) if config.ABTesting.Enabled { abController := consoleapi.NewABTesting(logger, abTesting) abRouter := router.PathPrefix("/api/v0/ab").Subrouter() - abRouter.Handle("/values", server.withAuth(http.HandlerFunc(abController.GetABValues))).Methods(http.MethodGet) - abRouter.Handle("/hit/{action}", server.withAuth(http.HandlerFunc(abController.SendHit))).Methods(http.MethodPost) + abRouter.Use(server.withCORS) + abRouter.Use(server.withAuth) + abRouter.Handle("/values", http.HandlerFunc(abController.GetABValues)).Methods(http.MethodGet, http.MethodOptions) + abRouter.Handle("/hit/{action}", http.HandlerFunc(abController.SendHit)).Methods(http.MethodPost, http.MethodOptions) } paymentController := consoleapi.NewPayments(logger, service, accountFreezeService, packagePlans) paymentsRouter := router.PathPrefix("/api/v0/payments").Subrouter() + paymentsRouter.Use(server.withCORS) paymentsRouter.Use(server.withAuth) - paymentsRouter.Handle("/cards", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.AddCreditCard))).Methods(http.MethodPost) - paymentsRouter.HandleFunc("/cards", paymentController.MakeCreditCardDefault).Methods(http.MethodPatch) - paymentsRouter.HandleFunc("/cards", paymentController.ListCreditCards).Methods(http.MethodGet) - paymentsRouter.HandleFunc("/cards/{cardId}", paymentController.RemoveCreditCard).Methods(http.MethodDelete) - paymentsRouter.HandleFunc("/account/charges", paymentController.ProjectsCharges).Methods(http.MethodGet) - paymentsRouter.HandleFunc("/account/balance", paymentController.AccountBalance).Methods(http.MethodGet) - paymentsRouter.HandleFunc("/account", paymentController.SetupAccount).Methods(http.MethodPost) - paymentsRouter.HandleFunc("/wallet", paymentController.GetWallet).Methods(http.MethodGet) - paymentsRouter.HandleFunc("/wallet", paymentController.ClaimWallet).Methods(http.MethodPost) - paymentsRouter.HandleFunc("/wallet/payments", paymentController.WalletPayments).Methods(http.MethodGet) - paymentsRouter.HandleFunc("/billing-history", paymentController.BillingHistory).Methods(http.MethodGet) - paymentsRouter.Handle("/coupon/apply", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.ApplyCouponCode))).Methods(http.MethodPatch) - paymentsRouter.HandleFunc("/coupon", paymentController.GetCoupon).Methods(http.MethodGet) - paymentsRouter.HandleFunc("/pricing", paymentController.GetProjectUsagePriceModel).Methods(http.MethodGet) + paymentsRouter.Handle("/cards", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.AddCreditCard))).Methods(http.MethodPost, http.MethodOptions) + paymentsRouter.HandleFunc("/cards", paymentController.MakeCreditCardDefault).Methods(http.MethodPatch, http.MethodOptions) + paymentsRouter.HandleFunc("/cards", paymentController.ListCreditCards).Methods(http.MethodGet, http.MethodOptions) + paymentsRouter.HandleFunc("/cards/{cardId}", paymentController.RemoveCreditCard).Methods(http.MethodDelete, http.MethodOptions) + paymentsRouter.HandleFunc("/account/charges", paymentController.ProjectsCharges).Methods(http.MethodGet, http.MethodOptions) + paymentsRouter.HandleFunc("/account/balance", paymentController.AccountBalance).Methods(http.MethodGet, http.MethodOptions) + paymentsRouter.HandleFunc("/account", paymentController.SetupAccount).Methods(http.MethodPost, http.MethodOptions) + paymentsRouter.HandleFunc("/wallet", paymentController.GetWallet).Methods(http.MethodGet, http.MethodOptions) + paymentsRouter.HandleFunc("/wallet", paymentController.ClaimWallet).Methods(http.MethodPost, http.MethodOptions) + paymentsRouter.HandleFunc("/wallet/payments", paymentController.WalletPayments).Methods(http.MethodGet, http.MethodOptions) + paymentsRouter.HandleFunc("/billing-history", paymentController.BillingHistory).Methods(http.MethodGet, http.MethodOptions) + paymentsRouter.Handle("/coupon/apply", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.ApplyCouponCode))).Methods(http.MethodPatch, http.MethodOptions) + paymentsRouter.HandleFunc("/coupon", paymentController.GetCoupon).Methods(http.MethodGet, http.MethodOptions) + paymentsRouter.HandleFunc("/pricing", paymentController.GetProjectUsagePriceModel).Methods(http.MethodGet, http.MethodOptions) if config.PricingPackagesEnabled { - paymentsRouter.HandleFunc("/purchase-package", paymentController.PurchasePackage).Methods(http.MethodPost) - paymentsRouter.HandleFunc("/package-available", paymentController.PackageAvailable).Methods(http.MethodGet) + paymentsRouter.HandleFunc("/purchase-package", paymentController.PurchasePackage).Methods(http.MethodPost, http.MethodOptions) + paymentsRouter.HandleFunc("/package-available", paymentController.PackageAvailable).Methods(http.MethodGet, http.MethodOptions) } bucketsController := consoleapi.NewBuckets(logger, service) bucketsRouter := router.PathPrefix("/api/v0/buckets").Subrouter() + bucketsRouter.Use(server.withCORS) bucketsRouter.Use(server.withAuth) - bucketsRouter.HandleFunc("/bucket-names", bucketsController.AllBucketNames).Methods(http.MethodGet) + bucketsRouter.HandleFunc("/bucket-names", bucketsController.AllBucketNames).Methods(http.MethodGet, http.MethodOptions) apiKeysController := consoleapi.NewAPIKeys(logger, service) apiKeysRouter := router.PathPrefix("/api/v0/api-keys").Subrouter() + apiKeysRouter.Use(server.withCORS) apiKeysRouter.Use(server.withAuth) - apiKeysRouter.HandleFunc("/delete-by-name", apiKeysController.DeleteByNameAndProjectID).Methods(http.MethodDelete) - apiKeysRouter.HandleFunc("/api-key-names", apiKeysController.GetAllAPIKeyNames).Methods(http.MethodGet) + apiKeysRouter.HandleFunc("/delete-by-name", apiKeysController.DeleteByNameAndProjectID).Methods(http.MethodDelete, http.MethodOptions) + apiKeysRouter.HandleFunc("/api-key-names", apiKeysController.GetAllAPIKeyNames).Methods(http.MethodGet, http.MethodOptions) analyticsController := consoleapi.NewAnalytics(logger, service, server.analytics) analyticsRouter := router.PathPrefix("/api/v0/analytics").Subrouter() + analyticsRouter.Use(server.withCORS) analyticsRouter.Use(server.withAuth) - analyticsRouter.HandleFunc("/event", analyticsController.EventTriggered).Methods(http.MethodPost) - analyticsRouter.HandleFunc("/page", analyticsController.PageEventTriggered).Methods(http.MethodPost) + analyticsRouter.HandleFunc("/event", analyticsController.EventTriggered).Methods(http.MethodPost, http.MethodOptions) + analyticsRouter.HandleFunc("/page", analyticsController.PageEventTriggered).Methods(http.MethodPost, http.MethodOptions) if server.config.StaticDir != "" { oidc := oidc.NewEndpoint( @@ -356,7 +367,7 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc router.Handle("/oauth/v2/clients/{id}", server.withAuth(http.HandlerFunc(oidc.GetClient))).Methods(http.MethodGet) fs := http.FileServer(http.Dir(server.config.StaticDir)) - router.PathPrefix("/static/").Handler(server.brotliMiddleware(http.StripPrefix("/static", fs))) + router.PathPrefix("/static/").Handler(server.withCORS(server.brotliMiddleware(http.StripPrefix("/static", fs)))) router.HandleFunc("/invited", server.handleInvited) @@ -367,9 +378,9 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc slashRouter.HandleFunc("/cancel-password-recovery", server.cancelPasswordRecoveryHandler) if server.config.UseVuetifyProject { - router.PathPrefix("/vuetifypoc").Handler(http.HandlerFunc(server.vuetifyAppHandler)) + router.PathPrefix("/vuetifypoc").Handler(server.withCORS(http.HandlerFunc(server.vuetifyAppHandler))) } - router.PathPrefix("/").Handler(http.HandlerFunc(server.appHandler)) + router.PathPrefix("/").Handler(server.withCORS(http.HandlerFunc(server.appHandler))) } server.server = http.Server{ @@ -506,6 +517,29 @@ func (server *Server) vuetifyAppHandler(w http.ResponseWriter, r *http.Request) http.ServeContent(w, r, path, info.ModTime(), file) } +// withCORS handles setting CORS-related headers on an http request. +func (server *Server) withCORS(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", strings.Trim(server.config.ExternalAddress, "/")) + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization") + w.Header().Set("Access-Control-Expose-Headers", "*, Authorization") + + if r.Method == http.MethodOptions { + match := &mux.RouteMatch{} + if server.router.Match(r, match) { + methods, err := match.Route.GetMethods() + if err == nil && len(methods) > 0 { + w.Header().Set("Access-Control-Allow-Methods", strings.Join(methods, ", ")) + } + } + return + } + + handler.ServeHTTP(w, r) + }) +} + // withAuth performs initial authorization before every request. func (server *Server) withAuth(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {