satellite/console: add CORS middleware to satellite UI and API

Add some basic handling to set cross-origin resource sharing headers for
the satellite UI app handler as well as API endpoints used by the
satellite UI.

This change also removes some no-longer-necessary CORS functionality on
the account registration endpoint. Previously, these CORS headers were
used to enable account registration cross-origin from www.storj.io.
However, we have since removed the ability to sign up via www.storj.io.

With these changes, browsers will prevent any requests to the affected
endpoints, unless the browser is making the request from the same host
as the satellite.

see https://github.com/storj/storj-private/issues/242

Change-Id: Ifd98be4a142a2e61e26392d97242d911e051fe8a
This commit is contained in:
Moby von Briesen 2023-06-12 16:25:53 -04:00 committed by Storj Robot
parent 361f9fdba5
commit 7530a3a83d
3 changed files with 91 additions and 173 deletions

View File

@ -31,12 +31,6 @@ var (
// errNotImplemented is the error value used by handlers of this package to // errNotImplemented is the error value used by handlers of this package to
// response with status Not Implemented. // response with status Not Implemented.
errNotImplemented = errs.New("not implemented") errNotImplemented = errs.New("not implemented")
// supportedCORSOrigins allows us to support visitors who sign up from the website.
supportedCORSOrigins = map[string]bool{
"https://storj.io": true,
"https://www.storj.io": true,
}
) )
// Auth is an api controller that exposes all auth functionality. // Auth is an api controller that exposes all auth functionality.
@ -210,19 +204,6 @@ func (a *Auth) Register(w http.ResponseWriter, r *http.Request) {
var err error var err error
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
origin := r.Header.Get("Origin")
if supportedCORSOrigins[origin] {
// we should send the exact origin back, rather than a wildcard
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
}
// OPTIONS is a pre-flight check for cross-origin (CORS) permissions
if r.Method == "OPTIONS" {
return
}
var registerData struct { var registerData struct {
FullName string `json:"fullName"` FullName string `json:"fullName"`
ShortName string `json:"shortName"` ShortName string `json:"shortName"`
@ -352,7 +333,7 @@ func (a *Auth) Register(w http.ResponseWriter, r *http.Request) {
FullName: user.FullName, FullName: user.FullName,
Email: user.Email, Email: user.Email,
Type: analytics.Personal, Type: analytics.Personal,
OriginHeader: origin, OriginHeader: r.Header.Get("Origin"),
Referrer: referrer, Referrer: referrer,
HubspotUTK: hubspotUTK, HubspotUTK: hubspotUTK,
UserAgent: string(user.UserAgent), UserAgent: string(user.UserAgent),

View File

@ -107,103 +107,6 @@ func TestAuth_Register(t *testing.T) {
}) })
} }
func TestAuth_Register_CORS(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
Reconfigure: testplanet.Reconfigure{
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
config.Console.OpenRegistrationEnabled = true
config.Console.RateLimit.Burst = 10
config.Mail.AuthType = "nomail"
},
},
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
email := "user@test.com"
fullName := "testuser"
jsonBody := []byte(fmt.Sprintf(`{"email":"%s","fullName":"%s","password":"abc123","shortName":"test"}`, email, fullName))
url := planet.Satellites[0].ConsoleURL() + "/api/v0/auth/register"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBody))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
// 1. OPTIONS request
// 1.1 CORS headers should not be set with origin other than storj.io or www.storj.io
req.Header.Set("Origin", "https://someexternalorigin.test")
req.Method = http.MethodOptions
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "", resp.Header.Get("Access-Control-Allow-Origin"))
require.Equal(t, "", resp.Header.Get("Access-Control-Allow-Methods"))
require.Equal(t, "", resp.Header.Get("Access-Control-Allow-Headers"))
require.NoError(t, resp.Body.Close())
// 1.2 CORS headers should be set with a domain of storj.io
req.Header.Set("Origin", "https://storj.io")
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "https://storj.io", resp.Header.Get("Access-Control-Allow-Origin"))
require.Equal(t, "POST, OPTIONS", resp.Header.Get("Access-Control-Allow-Methods"))
allowedHeaders := strings.Split(resp.Header.Get("Access-Control-Allow-Headers"), ", ")
require.ElementsMatch(t, allowedHeaders, []string{
"Content-Type",
"Content-Length",
"Accept",
"Accept-Encoding",
"X-CSRF-Token",
"Authorization",
})
require.NoError(t, resp.Body.Close())
// 1.3 CORS headers should be set with a domain of www.storj.io
req.Header.Set("Origin", "https://www.storj.io")
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "https://www.storj.io", resp.Header.Get("Access-Control-Allow-Origin"))
require.Equal(t, "POST, OPTIONS", resp.Header.Get("Access-Control-Allow-Methods"))
allowedHeaders = strings.Split(resp.Header.Get("Access-Control-Allow-Headers"), ", ")
require.ElementsMatch(t, allowedHeaders, []string{
"Content-Type",
"Content-Length",
"Accept",
"Accept-Encoding",
"X-CSRF-Token",
"Authorization",
})
require.NoError(t, resp.Body.Close())
// 2. POST request with origin www.storj.io
req.Method = http.MethodPost
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
defer func() {
err = resp.Body.Close()
require.NoError(t, err)
}()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "https://www.storj.io", resp.Header.Get("Access-Control-Allow-Origin"))
require.Equal(t, "POST, OPTIONS", resp.Header.Get("Access-Control-Allow-Methods"))
allowedHeaders = strings.Split(resp.Header.Get("Access-Control-Allow-Headers"), ", ")
require.ElementsMatch(t, allowedHeaders, []string{
"Content-Type",
"Content-Length",
"Accept",
"Accept-Encoding",
"X-CSRF-Token",
"Authorization",
})
require.Len(t, planet.Satellites, 1)
// this works only because we configured 'nomail' above. Mail send simulator won't click to activation link.
_, users, err := planet.Satellites[0].API.Console.Service.GetUserByEmailWithUnverified(ctx, email)
require.NoError(t, err)
require.Len(t, users, 1)
require.Equal(t, fullName, users[0].FullName)
})
}
func TestDeleteAccount(t *testing.T) { func TestDeleteAccount(t *testing.T) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
log := testplanet.NewLogger(t) log := testplanet.NewLogger(t)

View File

@ -132,6 +132,7 @@ type Server struct {
listener net.Listener listener net.Listener
server http.Server server http.Server
router *mux.Router
cookieAuth *consolewebauth.CookieAuth cookieAuth *consolewebauth.CookieAuth
ipRateLimiter *web.RateLimiter ipRateLimiter *web.RateLimiter
userIDRateLimiter *web.RateLimiter userIDRateLimiter *web.RateLimiter
@ -239,6 +240,7 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc
} }
router := mux.NewRouter() router := mux.NewRouter()
server.router = router
// N.B. This middleware has to be the first one because it has to be called // N.B. This middleware has to be the first one because it has to be called
// the earliest in the HTTP chain. // the earliest in the HTTP chain.
router.Use(newTraceRequestMiddleware(logger, router)) router.Use(newTraceRequestMiddleware(logger, router))
@ -252,95 +254,104 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc
consoleapi.NewUserManagement(logger, mon, server.service, router, &apiAuth{&server}) consoleapi.NewUserManagement(logger, mon, server.service, router, &apiAuth{&server})
} }
router.HandleFunc("/api/v0/config", server.frontendConfigHandler) router.Handle("/api/v0/config", server.withCORS(http.HandlerFunc(server.frontendConfigHandler)))
router.Handle("/api/v0/graphql", server.withAuth(http.HandlerFunc(server.graphqlHandler))) router.Handle("/api/v0/graphql", server.withCORS(server.withAuth(http.HandlerFunc(server.graphqlHandler))))
router.HandleFunc("/registrationToken/", server.createRegistrationTokenHandler) router.HandleFunc("/registrationToken/", server.createRegistrationTokenHandler)
router.HandleFunc("/robots.txt", server.seoHandler) router.HandleFunc("/robots.txt", server.seoHandler)
projectsController := consoleapi.NewProjects(logger, service) projectsController := consoleapi.NewProjects(logger, service)
projectsRouter := router.PathPrefix("/api/v0/projects").Subrouter() projectsRouter := router.PathPrefix("/api/v0/projects").Subrouter()
projectsRouter.Handle("/{id}/salt", server.withAuth(http.HandlerFunc(projectsController.GetSalt))).Methods(http.MethodGet) projectsRouter.Use(server.withCORS)
projectsRouter.Handle("/{id}/invite", server.withAuth(http.HandlerFunc(projectsController.InviteUsers))).Methods(http.MethodPost) projectsRouter.Use(server.withAuth)
projectsRouter.Handle("/{id}/invite-link", server.withAuth(http.HandlerFunc(projectsController.GetInviteLink))).Methods(http.MethodGet) projectsRouter.Handle("/{id}/salt", http.HandlerFunc(projectsController.GetSalt)).Methods(http.MethodGet, http.MethodOptions)
projectsRouter.Handle("/invitations", server.withAuth(http.HandlerFunc(projectsController.GetUserInvitations))).Methods(http.MethodGet) projectsRouter.Handle("/{id}/invite", http.HandlerFunc(projectsController.InviteUsers)).Methods(http.MethodPost, http.MethodOptions)
projectsRouter.Handle("/invitations/{id}/respond", server.withAuth(http.HandlerFunc(projectsController.RespondToInvitation))).Methods(http.MethodPost) projectsRouter.Handle("/{id}/invite-link", http.HandlerFunc(projectsController.GetInviteLink)).Methods(http.MethodGet, http.MethodOptions)
projectsRouter.Handle("/invitations", http.HandlerFunc(projectsController.GetUserInvitations)).Methods(http.MethodGet, http.MethodOptions)
projectsRouter.Handle("/invitations/{id}/respond", http.HandlerFunc(projectsController.RespondToInvitation)).Methods(http.MethodPost, http.MethodOptions)
usageLimitsController := consoleapi.NewUsageLimits(logger, service) usageLimitsController := consoleapi.NewUsageLimits(logger, service)
projectsRouter.Handle("/{id}/usage-limits", server.withAuth(http.HandlerFunc(usageLimitsController.ProjectUsageLimits))).Methods(http.MethodGet) projectsRouter.Handle("/{id}/usage-limits", http.HandlerFunc(usageLimitsController.ProjectUsageLimits)).Methods(http.MethodGet, http.MethodOptions)
projectsRouter.Handle("/usage-limits", server.withAuth(http.HandlerFunc(usageLimitsController.TotalUsageLimits))).Methods(http.MethodGet) projectsRouter.Handle("/usage-limits", http.HandlerFunc(usageLimitsController.TotalUsageLimits)).Methods(http.MethodGet, http.MethodOptions)
projectsRouter.Handle("/{id}/daily-usage", server.withAuth(http.HandlerFunc(usageLimitsController.DailyUsage))).Methods(http.MethodGet) projectsRouter.Handle("/{id}/daily-usage", http.HandlerFunc(usageLimitsController.DailyUsage)).Methods(http.MethodGet, http.MethodOptions)
authController := consoleapi.NewAuth(logger, service, accountFreezeService, mailService, server.cookieAuth, server.analytics, config.SatelliteName, server.config.ExternalAddress, config.LetUsKnowURL, config.TermsAndConditionsURL, config.ContactInfoURL, config.GeneralRequestURL) authController := consoleapi.NewAuth(logger, service, accountFreezeService, mailService, server.cookieAuth, server.analytics, config.SatelliteName, server.config.ExternalAddress, config.LetUsKnowURL, config.TermsAndConditionsURL, config.ContactInfoURL, config.GeneralRequestURL)
authRouter := router.PathPrefix("/api/v0/auth").Subrouter() authRouter := router.PathPrefix("/api/v0/auth").Subrouter()
authRouter.Handle("/account", server.withAuth(http.HandlerFunc(authController.GetAccount))).Methods(http.MethodGet) authRouter.Use(server.withCORS)
authRouter.Handle("/account", server.withAuth(http.HandlerFunc(authController.UpdateAccount))).Methods(http.MethodPatch) authRouter.Handle("/account", server.withAuth(http.HandlerFunc(authController.GetAccount))).Methods(http.MethodGet, http.MethodOptions)
authRouter.Handle("/account/change-email", server.withAuth(http.HandlerFunc(authController.ChangeEmail))).Methods(http.MethodPost) authRouter.Handle("/account", server.withAuth(http.HandlerFunc(authController.UpdateAccount))).Methods(http.MethodPatch, http.MethodOptions)
authRouter.Handle("/account/change-password", server.withAuth(server.userIDRateLimiter.Limit(http.HandlerFunc(authController.ChangePassword)))).Methods(http.MethodPost) authRouter.Handle("/account/change-email", server.withAuth(http.HandlerFunc(authController.ChangeEmail))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/account/freezestatus", server.withAuth(http.HandlerFunc(authController.GetFreezeStatus))).Methods(http.MethodGet) authRouter.Handle("/account/change-password", server.withAuth(server.userIDRateLimiter.Limit(http.HandlerFunc(authController.ChangePassword)))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/account/settings", server.withAuth(http.HandlerFunc(authController.GetUserSettings))).Methods(http.MethodGet) authRouter.Handle("/account/freezestatus", server.withAuth(http.HandlerFunc(authController.GetFreezeStatus))).Methods(http.MethodGet, http.MethodOptions)
authRouter.Handle("/account/settings", server.withAuth(http.HandlerFunc(authController.SetUserSettings))).Methods(http.MethodPatch) authRouter.Handle("/account/settings", server.withAuth(http.HandlerFunc(authController.GetUserSettings))).Methods(http.MethodGet, http.MethodOptions)
authRouter.Handle("/account/onboarding", server.withAuth(http.HandlerFunc(authController.SetOnboardingStatus))).Methods(http.MethodPatch) authRouter.Handle("/account/settings", server.withAuth(http.HandlerFunc(authController.SetUserSettings))).Methods(http.MethodPatch, http.MethodOptions)
authRouter.Handle("/account/delete", server.withAuth(http.HandlerFunc(authController.DeleteAccount))).Methods(http.MethodPost) authRouter.Handle("/account/onboarding", server.withAuth(http.HandlerFunc(authController.SetOnboardingStatus))).Methods(http.MethodPatch, http.MethodOptions)
authRouter.Handle("/mfa/enable", server.withAuth(http.HandlerFunc(authController.EnableUserMFA))).Methods(http.MethodPost) authRouter.Handle("/account/delete", server.withAuth(http.HandlerFunc(authController.DeleteAccount))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/mfa/disable", server.withAuth(http.HandlerFunc(authController.DisableUserMFA))).Methods(http.MethodPost) authRouter.Handle("/mfa/enable", server.withAuth(http.HandlerFunc(authController.EnableUserMFA))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/mfa/generate-secret-key", server.withAuth(http.HandlerFunc(authController.GenerateMFASecretKey))).Methods(http.MethodPost) authRouter.Handle("/mfa/disable", server.withAuth(http.HandlerFunc(authController.DisableUserMFA))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/mfa/generate-recovery-codes", server.withAuth(http.HandlerFunc(authController.GenerateMFARecoveryCodes))).Methods(http.MethodPost) authRouter.Handle("/mfa/generate-secret-key", server.withAuth(http.HandlerFunc(authController.GenerateMFASecretKey))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/logout", server.withAuth(http.HandlerFunc(authController.Logout))).Methods(http.MethodPost) authRouter.Handle("/mfa/generate-recovery-codes", server.withAuth(http.HandlerFunc(authController.GenerateMFARecoveryCodes))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/token", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Token))).Methods(http.MethodPost) authRouter.Handle("/logout", server.withAuth(http.HandlerFunc(authController.Logout))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/token-by-api-key", server.ipRateLimiter.Limit(http.HandlerFunc(authController.TokenByAPIKey))).Methods(http.MethodPost) authRouter.Handle("/token", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Token))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/token-by-api-key", server.ipRateLimiter.Limit(http.HandlerFunc(authController.TokenByAPIKey))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/register", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Register))).Methods(http.MethodPost, http.MethodOptions) authRouter.Handle("/register", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Register))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/forgot-password", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ForgotPassword))).Methods(http.MethodPost) authRouter.Handle("/forgot-password", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ForgotPassword))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/resend-email/{email}", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResendEmail))).Methods(http.MethodPost) authRouter.Handle("/resend-email/{email}", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResendEmail))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/reset-password", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResetPassword))).Methods(http.MethodPost) authRouter.Handle("/reset-password", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResetPassword))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/refresh-session", server.withAuth(http.HandlerFunc(authController.RefreshSession))).Methods(http.MethodPost) authRouter.Handle("/refresh-session", server.withAuth(http.HandlerFunc(authController.RefreshSession))).Methods(http.MethodPost, http.MethodOptions)
if config.ABTesting.Enabled { if config.ABTesting.Enabled {
abController := consoleapi.NewABTesting(logger, abTesting) abController := consoleapi.NewABTesting(logger, abTesting)
abRouter := router.PathPrefix("/api/v0/ab").Subrouter() abRouter := router.PathPrefix("/api/v0/ab").Subrouter()
abRouter.Handle("/values", server.withAuth(http.HandlerFunc(abController.GetABValues))).Methods(http.MethodGet) abRouter.Use(server.withCORS)
abRouter.Handle("/hit/{action}", server.withAuth(http.HandlerFunc(abController.SendHit))).Methods(http.MethodPost) abRouter.Use(server.withAuth)
abRouter.Handle("/values", http.HandlerFunc(abController.GetABValues)).Methods(http.MethodGet, http.MethodOptions)
abRouter.Handle("/hit/{action}", http.HandlerFunc(abController.SendHit)).Methods(http.MethodPost, http.MethodOptions)
} }
paymentController := consoleapi.NewPayments(logger, service, accountFreezeService, packagePlans) paymentController := consoleapi.NewPayments(logger, service, accountFreezeService, packagePlans)
paymentsRouter := router.PathPrefix("/api/v0/payments").Subrouter() paymentsRouter := router.PathPrefix("/api/v0/payments").Subrouter()
paymentsRouter.Use(server.withCORS)
paymentsRouter.Use(server.withAuth) paymentsRouter.Use(server.withAuth)
paymentsRouter.Handle("/cards", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.AddCreditCard))).Methods(http.MethodPost) paymentsRouter.Handle("/cards", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.AddCreditCard))).Methods(http.MethodPost, http.MethodOptions)
paymentsRouter.HandleFunc("/cards", paymentController.MakeCreditCardDefault).Methods(http.MethodPatch) paymentsRouter.HandleFunc("/cards", paymentController.MakeCreditCardDefault).Methods(http.MethodPatch, http.MethodOptions)
paymentsRouter.HandleFunc("/cards", paymentController.ListCreditCards).Methods(http.MethodGet) paymentsRouter.HandleFunc("/cards", paymentController.ListCreditCards).Methods(http.MethodGet, http.MethodOptions)
paymentsRouter.HandleFunc("/cards/{cardId}", paymentController.RemoveCreditCard).Methods(http.MethodDelete) paymentsRouter.HandleFunc("/cards/{cardId}", paymentController.RemoveCreditCard).Methods(http.MethodDelete, http.MethodOptions)
paymentsRouter.HandleFunc("/account/charges", paymentController.ProjectsCharges).Methods(http.MethodGet) paymentsRouter.HandleFunc("/account/charges", paymentController.ProjectsCharges).Methods(http.MethodGet, http.MethodOptions)
paymentsRouter.HandleFunc("/account/balance", paymentController.AccountBalance).Methods(http.MethodGet) paymentsRouter.HandleFunc("/account/balance", paymentController.AccountBalance).Methods(http.MethodGet, http.MethodOptions)
paymentsRouter.HandleFunc("/account", paymentController.SetupAccount).Methods(http.MethodPost) paymentsRouter.HandleFunc("/account", paymentController.SetupAccount).Methods(http.MethodPost, http.MethodOptions)
paymentsRouter.HandleFunc("/wallet", paymentController.GetWallet).Methods(http.MethodGet) paymentsRouter.HandleFunc("/wallet", paymentController.GetWallet).Methods(http.MethodGet, http.MethodOptions)
paymentsRouter.HandleFunc("/wallet", paymentController.ClaimWallet).Methods(http.MethodPost) paymentsRouter.HandleFunc("/wallet", paymentController.ClaimWallet).Methods(http.MethodPost, http.MethodOptions)
paymentsRouter.HandleFunc("/wallet/payments", paymentController.WalletPayments).Methods(http.MethodGet) paymentsRouter.HandleFunc("/wallet/payments", paymentController.WalletPayments).Methods(http.MethodGet, http.MethodOptions)
paymentsRouter.HandleFunc("/billing-history", paymentController.BillingHistory).Methods(http.MethodGet) paymentsRouter.HandleFunc("/billing-history", paymentController.BillingHistory).Methods(http.MethodGet, http.MethodOptions)
paymentsRouter.Handle("/coupon/apply", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.ApplyCouponCode))).Methods(http.MethodPatch) paymentsRouter.Handle("/coupon/apply", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.ApplyCouponCode))).Methods(http.MethodPatch, http.MethodOptions)
paymentsRouter.HandleFunc("/coupon", paymentController.GetCoupon).Methods(http.MethodGet) paymentsRouter.HandleFunc("/coupon", paymentController.GetCoupon).Methods(http.MethodGet, http.MethodOptions)
paymentsRouter.HandleFunc("/pricing", paymentController.GetProjectUsagePriceModel).Methods(http.MethodGet) paymentsRouter.HandleFunc("/pricing", paymentController.GetProjectUsagePriceModel).Methods(http.MethodGet, http.MethodOptions)
if config.PricingPackagesEnabled { if config.PricingPackagesEnabled {
paymentsRouter.HandleFunc("/purchase-package", paymentController.PurchasePackage).Methods(http.MethodPost) paymentsRouter.HandleFunc("/purchase-package", paymentController.PurchasePackage).Methods(http.MethodPost, http.MethodOptions)
paymentsRouter.HandleFunc("/package-available", paymentController.PackageAvailable).Methods(http.MethodGet) paymentsRouter.HandleFunc("/package-available", paymentController.PackageAvailable).Methods(http.MethodGet, http.MethodOptions)
} }
bucketsController := consoleapi.NewBuckets(logger, service) bucketsController := consoleapi.NewBuckets(logger, service)
bucketsRouter := router.PathPrefix("/api/v0/buckets").Subrouter() bucketsRouter := router.PathPrefix("/api/v0/buckets").Subrouter()
bucketsRouter.Use(server.withCORS)
bucketsRouter.Use(server.withAuth) bucketsRouter.Use(server.withAuth)
bucketsRouter.HandleFunc("/bucket-names", bucketsController.AllBucketNames).Methods(http.MethodGet) bucketsRouter.HandleFunc("/bucket-names", bucketsController.AllBucketNames).Methods(http.MethodGet, http.MethodOptions)
apiKeysController := consoleapi.NewAPIKeys(logger, service) apiKeysController := consoleapi.NewAPIKeys(logger, service)
apiKeysRouter := router.PathPrefix("/api/v0/api-keys").Subrouter() apiKeysRouter := router.PathPrefix("/api/v0/api-keys").Subrouter()
apiKeysRouter.Use(server.withCORS)
apiKeysRouter.Use(server.withAuth) apiKeysRouter.Use(server.withAuth)
apiKeysRouter.HandleFunc("/delete-by-name", apiKeysController.DeleteByNameAndProjectID).Methods(http.MethodDelete) apiKeysRouter.HandleFunc("/delete-by-name", apiKeysController.DeleteByNameAndProjectID).Methods(http.MethodDelete, http.MethodOptions)
apiKeysRouter.HandleFunc("/api-key-names", apiKeysController.GetAllAPIKeyNames).Methods(http.MethodGet) apiKeysRouter.HandleFunc("/api-key-names", apiKeysController.GetAllAPIKeyNames).Methods(http.MethodGet, http.MethodOptions)
analyticsController := consoleapi.NewAnalytics(logger, service, server.analytics) analyticsController := consoleapi.NewAnalytics(logger, service, server.analytics)
analyticsRouter := router.PathPrefix("/api/v0/analytics").Subrouter() analyticsRouter := router.PathPrefix("/api/v0/analytics").Subrouter()
analyticsRouter.Use(server.withCORS)
analyticsRouter.Use(server.withAuth) analyticsRouter.Use(server.withAuth)
analyticsRouter.HandleFunc("/event", analyticsController.EventTriggered).Methods(http.MethodPost) analyticsRouter.HandleFunc("/event", analyticsController.EventTriggered).Methods(http.MethodPost, http.MethodOptions)
analyticsRouter.HandleFunc("/page", analyticsController.PageEventTriggered).Methods(http.MethodPost) analyticsRouter.HandleFunc("/page", analyticsController.PageEventTriggered).Methods(http.MethodPost, http.MethodOptions)
if server.config.StaticDir != "" { if server.config.StaticDir != "" {
oidc := oidc.NewEndpoint( oidc := oidc.NewEndpoint(
@ -356,7 +367,7 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc
router.Handle("/oauth/v2/clients/{id}", server.withAuth(http.HandlerFunc(oidc.GetClient))).Methods(http.MethodGet) router.Handle("/oauth/v2/clients/{id}", server.withAuth(http.HandlerFunc(oidc.GetClient))).Methods(http.MethodGet)
fs := http.FileServer(http.Dir(server.config.StaticDir)) fs := http.FileServer(http.Dir(server.config.StaticDir))
router.PathPrefix("/static/").Handler(server.brotliMiddleware(http.StripPrefix("/static", fs))) router.PathPrefix("/static/").Handler(server.withCORS(server.brotliMiddleware(http.StripPrefix("/static", fs))))
router.HandleFunc("/invited", server.handleInvited) router.HandleFunc("/invited", server.handleInvited)
@ -367,9 +378,9 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc
slashRouter.HandleFunc("/cancel-password-recovery", server.cancelPasswordRecoveryHandler) slashRouter.HandleFunc("/cancel-password-recovery", server.cancelPasswordRecoveryHandler)
if server.config.UseVuetifyProject { if server.config.UseVuetifyProject {
router.PathPrefix("/vuetifypoc").Handler(http.HandlerFunc(server.vuetifyAppHandler)) router.PathPrefix("/vuetifypoc").Handler(server.withCORS(http.HandlerFunc(server.vuetifyAppHandler)))
} }
router.PathPrefix("/").Handler(http.HandlerFunc(server.appHandler)) router.PathPrefix("/").Handler(server.withCORS(http.HandlerFunc(server.appHandler)))
} }
server.server = http.Server{ server.server = http.Server{
@ -506,6 +517,29 @@ func (server *Server) vuetifyAppHandler(w http.ResponseWriter, r *http.Request)
http.ServeContent(w, r, path, info.ModTime(), file) http.ServeContent(w, r, path, info.ModTime(), file)
} }
// withCORS handles setting CORS-related headers on an http request.
func (server *Server) withCORS(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", strings.Trim(server.config.ExternalAddress, "/"))
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
w.Header().Set("Access-Control-Expose-Headers", "*, Authorization")
if r.Method == http.MethodOptions {
match := &mux.RouteMatch{}
if server.router.Match(r, match) {
methods, err := match.Route.GetMethods()
if err == nil && len(methods) > 0 {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(methods, ", "))
}
}
return
}
handler.ServeHTTP(w, r)
})
}
// withAuth performs initial authorization before every request. // withAuth performs initial authorization before every request.
func (server *Server) withAuth(handler http.Handler) http.Handler { func (server *Server) withAuth(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {