satellite/console,private/web: Rate limit coupon code application
Rate limits application of coupon codes by user ID to prevent brute forcing. Refactors the rate limiter to allow limiting based on arbitrary criteria and not just by IP. Change-Id: I99d6749bd5b5e47d7e1aeb0314e363a8e7259dba
This commit is contained in:
parent
0344790c20
commit
6a6cc28fc1
@ -14,36 +14,43 @@ import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// IPRateLimiterConfig configures an IPRateLimiter.
|
||||
type IPRateLimiterConfig struct {
|
||||
// RateLimiterConfig configures a RateLimiter.
|
||||
type RateLimiterConfig struct {
|
||||
Duration time.Duration `help:"the rate at which request are allowed" default:"5m"`
|
||||
Burst int `help:"number of events before the limit kicks in" default:"5" testDefault:"3"`
|
||||
NumLimits int `help:"number of IPs whose rate limits we store" default:"1000" testDefault:"10"`
|
||||
NumLimits int `help:"number of clients whose rate limits we store" default:"1000" testDefault:"10"`
|
||||
}
|
||||
|
||||
// IPRateLimiter imposes a rate limit per HTTP user IP.
|
||||
type IPRateLimiter struct {
|
||||
config IPRateLimiterConfig
|
||||
mu sync.Mutex
|
||||
ipLimits map[string]*userLimit
|
||||
// RateLimiter imposes a rate limit per key.
|
||||
type RateLimiter struct {
|
||||
config RateLimiterConfig
|
||||
mu sync.Mutex
|
||||
limits map[string]*userLimit
|
||||
keyFunc func(*http.Request) (string, error)
|
||||
}
|
||||
|
||||
// userLimit is the per-IP limiter.
|
||||
// userLimit is the per-key limiter.
|
||||
type userLimit struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// NewIPRateLimiter constructs an IPRateLimiter.
|
||||
func NewIPRateLimiter(config IPRateLimiterConfig) *IPRateLimiter {
|
||||
return &IPRateLimiter{
|
||||
config: config,
|
||||
ipLimits: make(map[string]*userLimit),
|
||||
// NewIPRateLimiter constructs a RateLimiter that limits based on IP address.
|
||||
func NewIPRateLimiter(config RateLimiterConfig) *RateLimiter {
|
||||
return NewRateLimiter(config, GetRequestIP)
|
||||
}
|
||||
|
||||
// NewRateLimiter constructs a RateLimiter.
|
||||
func NewRateLimiter(config RateLimiterConfig, keyFunc func(*http.Request) (string, error)) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
config: config,
|
||||
limits: make(map[string]*userLimit),
|
||||
keyFunc: keyFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// Run occasionally cleans old rate-limiting data, until context cancel.
|
||||
func (rl *IPRateLimiter) Run(ctx context.Context) {
|
||||
func (rl *RateLimiter) Run(ctx context.Context) {
|
||||
cleanupTicker := time.NewTicker(rl.config.Duration)
|
||||
defer cleanupTicker.Stop()
|
||||
for {
|
||||
@ -57,26 +64,26 @@ func (rl *IPRateLimiter) Run(ctx context.Context) {
|
||||
}
|
||||
|
||||
// cleanupLimiters removes old rate limits to free memory.
|
||||
func (rl *IPRateLimiter) cleanupLimiters() {
|
||||
func (rl *RateLimiter) cleanupLimiters() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
for ip, v := range rl.ipLimits {
|
||||
for k, v := range rl.limits {
|
||||
if time.Since(v.lastSeen) > rl.config.Duration {
|
||||
delete(rl.ipLimits, ip)
|
||||
delete(rl.limits, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Limit applies a per IP rate limiting as an HTTP Handler.
|
||||
func (rl *IPRateLimiter) Limit(next http.Handler) http.Handler {
|
||||
// Limit applies per-key rate limiting as an HTTP Handler.
|
||||
func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip, err := GetRequestIP(r)
|
||||
key, err := rl.keyFunc(r)
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
ipLimit := rl.getUserLimit(ip)
|
||||
if !ipLimit.Allow() {
|
||||
limit := rl.getUserLimit(key)
|
||||
if !limit.Allow() {
|
||||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
@ -104,37 +111,37 @@ func GetRequestIP(r *http.Request) (ip string, err error) {
|
||||
return ip, err
|
||||
}
|
||||
|
||||
// getUserLimit returns a rate limiter for an IP.
|
||||
func (rl *IPRateLimiter) getUserLimit(ip string) *rate.Limiter {
|
||||
// getUserLimit returns a rate limiter for a key.
|
||||
func (rl *RateLimiter) getUserLimit(key string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
v, exists := rl.ipLimits[ip]
|
||||
v, exists := rl.limits[key]
|
||||
if !exists {
|
||||
if len(rl.ipLimits) >= rl.config.NumLimits {
|
||||
if len(rl.limits) >= rl.config.NumLimits {
|
||||
// Tracking only N limits prevents an out-of-memory DOS attack
|
||||
// Returning StatusTooManyRequests would be just as bad
|
||||
// The least-bad option may be to remove the oldest key
|
||||
oldestKey := ""
|
||||
var oldestTime *time.Time
|
||||
for ip, v := range rl.ipLimits {
|
||||
for key, v := range rl.limits {
|
||||
// while we're looping, we'd prefer to just delete expired records
|
||||
if time.Since(v.lastSeen) > rl.config.Duration {
|
||||
delete(rl.ipLimits, ip)
|
||||
delete(rl.limits, key)
|
||||
}
|
||||
// but we're prepared to delete the oldest non-expired
|
||||
if oldestTime == nil || v.lastSeen.Before(*oldestTime) {
|
||||
oldestTime = &v.lastSeen
|
||||
oldestKey = ip
|
||||
oldestKey = key
|
||||
}
|
||||
}
|
||||
// only delete the oldest non-expired if there's still an issue
|
||||
if oldestKey != "" && len(rl.ipLimits) >= rl.config.NumLimits {
|
||||
delete(rl.ipLimits, oldestKey)
|
||||
if oldestKey != "" && len(rl.limits) >= rl.config.NumLimits {
|
||||
delete(rl.limits, oldestKey)
|
||||
}
|
||||
}
|
||||
limiter := rate.NewLimiter(rate.Limit(time.Second)/rate.Limit(rl.config.Duration), rl.config.Burst)
|
||||
rl.ipLimits[ip] = &userLimit{limiter, time.Now()}
|
||||
rl.limits[key] = &userLimit{limiter, time.Now()}
|
||||
return limiter
|
||||
}
|
||||
v.lastSeen = time.Now()
|
||||
@ -142,11 +149,11 @@ func (rl *IPRateLimiter) getUserLimit(ip string) *rate.Limiter {
|
||||
}
|
||||
|
||||
// Burst returns the number of events that happen before the rate limit.
|
||||
func (rl *IPRateLimiter) Burst() int {
|
||||
func (rl *RateLimiter) Burst() int {
|
||||
return rl.config.Burst
|
||||
}
|
||||
|
||||
// Duration returns the amount of time required between events.
|
||||
func (rl *IPRateLimiter) Duration() time.Duration {
|
||||
func (rl *RateLimiter) Duration() time.Duration {
|
||||
return rl.config.Duration
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ import (
|
||||
|
||||
func TestNewIPRateLimiter(t *testing.T) {
|
||||
// create a rate limiter with defaults except NumLimits = 2
|
||||
config := web.IPRateLimiterConfig{}
|
||||
config := web.RateLimiterConfig{}
|
||||
cfgstruct.Bind(&pflag.FlagSet{}, &config, cfgstruct.UseDevDefaults())
|
||||
config.NumLimits = 2
|
||||
rateLimiter := web.NewIPRateLimiter(config)
|
||||
|
@ -91,7 +91,8 @@ type Config struct {
|
||||
LinksharingURL string `help:"url link for linksharing requests" default:"https://link.us1.storjshare.io"`
|
||||
PathwayOverviewEnabled bool `help:"indicates if the overview onboarding step should render with pathways" default:"true"`
|
||||
|
||||
RateLimit web.IPRateLimiterConfig
|
||||
// RateLimit defines the configuration for the IP and userID rate limiters.
|
||||
RateLimit web.RateLimiterConfig
|
||||
|
||||
console.Config
|
||||
}
|
||||
@ -141,11 +142,12 @@ type Server struct {
|
||||
partners *rewards.PartnersService
|
||||
analytics *analytics.Service
|
||||
|
||||
listener net.Listener
|
||||
server http.Server
|
||||
cookieAuth *consolewebauth.CookieAuth
|
||||
rateLimiter *web.IPRateLimiter
|
||||
nodeURL storj.NodeURL
|
||||
listener net.Listener
|
||||
server http.Server
|
||||
cookieAuth *consolewebauth.CookieAuth
|
||||
ipRateLimiter *web.RateLimiter
|
||||
userIDRateLimiter *web.RateLimiter
|
||||
nodeURL storj.NodeURL
|
||||
|
||||
stripePublicKey string
|
||||
|
||||
@ -163,17 +165,18 @@ type Server struct {
|
||||
// NewServer creates new instance of console server.
|
||||
func NewServer(logger *zap.Logger, config Config, service *console.Service, mailService *mailservice.Service, partners *rewards.PartnersService, analytics *analytics.Service, listener net.Listener, stripePublicKey string, pricing paymentsconfig.PricingValues, nodeURL storj.NodeURL) *Server {
|
||||
server := Server{
|
||||
log: logger,
|
||||
config: config,
|
||||
listener: listener,
|
||||
service: service,
|
||||
mailService: mailService,
|
||||
partners: partners,
|
||||
analytics: analytics,
|
||||
stripePublicKey: stripePublicKey,
|
||||
rateLimiter: web.NewIPRateLimiter(config.RateLimit),
|
||||
nodeURL: nodeURL,
|
||||
pricing: pricing,
|
||||
log: logger,
|
||||
config: config,
|
||||
listener: listener,
|
||||
service: service,
|
||||
mailService: mailService,
|
||||
partners: partners,
|
||||
analytics: analytics,
|
||||
stripePublicKey: stripePublicKey,
|
||||
ipRateLimiter: web.NewIPRateLimiter(config.RateLimit),
|
||||
userIDRateLimiter: NewUserIDRateLimiter(config.RateLimit),
|
||||
nodeURL: nodeURL,
|
||||
pricing: pricing,
|
||||
}
|
||||
|
||||
logger.Debug("Starting Satellite UI.", zap.Stringer("Address", server.listener.Addr()))
|
||||
@ -225,11 +228,11 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, mail
|
||||
authRouter.Handle("/mfa/generate-secret-key", server.withAuth(http.HandlerFunc(authController.GenerateMFASecretKey))).Methods(http.MethodPost)
|
||||
authRouter.Handle("/mfa/generate-recovery-codes", server.withAuth(http.HandlerFunc(authController.GenerateMFARecoveryCodes))).Methods(http.MethodPost)
|
||||
authRouter.HandleFunc("/logout", authController.Logout).Methods(http.MethodPost)
|
||||
authRouter.Handle("/token", server.rateLimiter.Limit(http.HandlerFunc(authController.Token))).Methods(http.MethodPost)
|
||||
authRouter.Handle("/register", server.rateLimiter.Limit(http.HandlerFunc(authController.Register))).Methods(http.MethodPost, http.MethodOptions)
|
||||
authRouter.Handle("/forgot-password/{email}", server.rateLimiter.Limit(http.HandlerFunc(authController.ForgotPassword))).Methods(http.MethodPost)
|
||||
authRouter.Handle("/resend-email/{id}", server.rateLimiter.Limit(http.HandlerFunc(authController.ResendEmail))).Methods(http.MethodPost)
|
||||
authRouter.Handle("/reset-password", server.rateLimiter.Limit(http.HandlerFunc(authController.ResetPassword))).Methods(http.MethodPost)
|
||||
authRouter.Handle("/token", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Token))).Methods(http.MethodPost)
|
||||
authRouter.Handle("/register", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Register))).Methods(http.MethodPost, http.MethodOptions)
|
||||
authRouter.Handle("/forgot-password/{email}", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ForgotPassword))).Methods(http.MethodPost)
|
||||
authRouter.Handle("/resend-email/{id}", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResendEmail))).Methods(http.MethodPost)
|
||||
authRouter.Handle("/reset-password", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResetPassword))).Methods(http.MethodPost)
|
||||
|
||||
paymentController := consoleapi.NewPayments(logger, service)
|
||||
paymentsRouter := router.PathPrefix("/api/v0/payments").Subrouter()
|
||||
@ -243,7 +246,7 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, mail
|
||||
paymentsRouter.HandleFunc("/account", paymentController.SetupAccount).Methods(http.MethodPost)
|
||||
paymentsRouter.HandleFunc("/billing-history", paymentController.BillingHistory).Methods(http.MethodGet)
|
||||
paymentsRouter.HandleFunc("/tokens/deposit", paymentController.TokenDeposit).Methods(http.MethodPost)
|
||||
paymentsRouter.HandleFunc("/coupon/apply", paymentController.ApplyCouponCode).Methods(http.MethodPatch)
|
||||
paymentsRouter.Handle("/coupon/apply", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.ApplyCouponCode))).Methods(http.MethodPatch)
|
||||
paymentsRouter.HandleFunc("/coupon", paymentController.GetCoupon).Methods(http.MethodGet)
|
||||
|
||||
bucketsController := consoleapi.NewBuckets(logger, service)
|
||||
@ -299,7 +302,7 @@ func (server *Server) Run(ctx context.Context) (err error) {
|
||||
return server.server.Shutdown(context.Background())
|
||||
})
|
||||
group.Go(func() error {
|
||||
server.rateLimiter.Run(ctx)
|
||||
server.ipRateLimiter.Run(ctx)
|
||||
return nil
|
||||
})
|
||||
group.Go(func() error {
|
||||
@ -782,3 +785,14 @@ func (server *Server) initializeTemplates() (err error) {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewUserIDRateLimiter constructs a RateLimiter that limits based on user ID.
|
||||
func NewUserIDRateLimiter(config web.RateLimiterConfig) *web.RateLimiter {
|
||||
return web.NewRateLimiter(config, func(r *http.Request) (string, error) {
|
||||
auth, err := console.GetAuth(r.Context())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return auth.User.ID.String(), nil
|
||||
})
|
||||
}
|
||||
|
@ -4,13 +4,18 @@
|
||||
package consoleweb_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/storj/private/testplanet"
|
||||
"storj.io/storj/satellite"
|
||||
"storj.io/storj/satellite/console"
|
||||
)
|
||||
|
||||
@ -58,3 +63,70 @@ func TestActivationRouting(t *testing.T) {
|
||||
checkActivationRedirect("Activation - Used Token", loginURL+"?activated=false")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserIDRateLimiter(t *testing.T) {
|
||||
numLimits := 2
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
|
||||
Reconfigure: testplanet.Reconfigure{
|
||||
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
|
||||
config.Console.RateLimit.NumLimits = numLimits
|
||||
},
|
||||
},
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
sat := planet.Satellites[0]
|
||||
|
||||
applyCouponStatus := func(token string) int {
|
||||
urlLink := "http://" + sat.API.Console.Listener.Addr().String() + "/api/v0/payments/coupon/apply"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, urlLink, bytes.NewBufferString("PROMO_CODE"))
|
||||
require.NoError(t, err)
|
||||
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "_tokenKey",
|
||||
Path: "/",
|
||||
Value: token,
|
||||
Expires: time.Now().AddDate(0, 0, 1),
|
||||
})
|
||||
|
||||
result, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, result.Body.Close())
|
||||
|
||||
return result.StatusCode
|
||||
}
|
||||
|
||||
var firstToken string
|
||||
for userNum := 1; userNum <= numLimits+1; userNum++ {
|
||||
t.Run(fmt.Sprintf("TestUserIDRateLimit_%d", userNum), func(t *testing.T) {
|
||||
user, err := sat.AddUser(ctx, console.CreateUser{
|
||||
FullName: fmt.Sprintf("Test User %d", userNum),
|
||||
Email: fmt.Sprintf("test%d@mail.test", userNum),
|
||||
}, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// sat.AddUser sets password to full name.
|
||||
token, err := sat.API.Console.Service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName})
|
||||
require.NoError(t, err)
|
||||
|
||||
if userNum == 1 {
|
||||
firstToken = token
|
||||
}
|
||||
|
||||
// Expect burst number of successes.
|
||||
for burstNum := 0; burstNum < sat.Config.Console.RateLimit.Burst; burstNum++ {
|
||||
require.NotEqual(t, http.StatusTooManyRequests, applyCouponStatus(token))
|
||||
}
|
||||
|
||||
// Expect failure.
|
||||
require.Equal(t, http.StatusTooManyRequests, applyCouponStatus(token))
|
||||
})
|
||||
}
|
||||
|
||||
// Expect original user to work again because numLimits == 2.
|
||||
for burstNum := 0; burstNum < sat.Config.Console.RateLimit.Burst; burstNum++ {
|
||||
require.NotEqual(t, http.StatusTooManyRequests, applyCouponStatus(firstToken))
|
||||
}
|
||||
require.Equal(t, http.StatusTooManyRequests, applyCouponStatus(firstToken))
|
||||
})
|
||||
}
|
||||
|
2
scripts/testdata/satellite-config.yaml.lock
vendored
2
scripts/testdata/satellite-config.yaml.lock
vendored
@ -148,7 +148,7 @@ compensation.withheld-percents: 75,75,75,50,50,50,25,25,25,0,0,0,0,0,0
|
||||
# the rate at which request are allowed
|
||||
# console.rate-limit.duration: 5m0s
|
||||
|
||||
# number of IPs whose rate limits we store
|
||||
# number of clients whose rate limits we store
|
||||
# console.rate-limit.num-limits: 1000
|
||||
|
||||
# whether or not reCAPTCHA is enabled for user registration
|
||||
|
@ -13,6 +13,7 @@ import {
|
||||
} from '@/types/payments';
|
||||
import { HttpClient } from '@/utils/httpClient';
|
||||
import { Time } from '@/utils/time';
|
||||
import { ErrorTooManyRequests } from './errors/ErrorTooManyRequests';
|
||||
|
||||
/**
|
||||
* PaymentsHttpApi is a http implementation of Payments API.
|
||||
@ -285,10 +286,14 @@ export class PaymentsHttpApi implements PaymentsApi {
|
||||
);
|
||||
}
|
||||
|
||||
if (response.status === 401) {
|
||||
throw new ErrorUnauthorized();
|
||||
switch (response.status) {
|
||||
case 429:
|
||||
throw new ErrorTooManyRequests('You\'ve exceeded limit of attempts, try again in 5 minutes');
|
||||
case 401:
|
||||
throw new ErrorUnauthorized(errMsg);
|
||||
default:
|
||||
throw new Error(errMsg);
|
||||
}
|
||||
throw new Error(errMsg);
|
||||
}
|
||||
|
||||
/**
|
||||
|
Loading…
Reference in New Issue
Block a user