satellite/console,private/web: Rate limit coupon code application

Rate limits application of coupon codes by user ID to prevent
brute forcing. Refactors the rate limiter to allow limiting based
on arbitrary criteria and not just by IP.

Change-Id: I99d6749bd5b5e47d7e1aeb0314e363a8e7259dba
This commit is contained in:
Jeremy Wharton 2021-08-17 14:38:34 -05:00 committed by Jeremy Wharton
parent 0344790c20
commit 6a6cc28fc1
6 changed files with 162 additions and 64 deletions

View File

@ -14,36 +14,43 @@ import (
"golang.org/x/time/rate"
)
// IPRateLimiterConfig configures an IPRateLimiter.
type IPRateLimiterConfig struct {
// RateLimiterConfig configures a RateLimiter.
type RateLimiterConfig struct {
Duration time.Duration `help:"the rate at which request are allowed" default:"5m"`
Burst int `help:"number of events before the limit kicks in" default:"5" testDefault:"3"`
NumLimits int `help:"number of IPs whose rate limits we store" default:"1000" testDefault:"10"`
NumLimits int `help:"number of clients whose rate limits we store" default:"1000" testDefault:"10"`
}
// IPRateLimiter imposes a rate limit per HTTP user IP.
type IPRateLimiter struct {
config IPRateLimiterConfig
mu sync.Mutex
ipLimits map[string]*userLimit
// RateLimiter imposes a rate limit per key.
type RateLimiter struct {
config RateLimiterConfig
mu sync.Mutex
limits map[string]*userLimit
keyFunc func(*http.Request) (string, error)
}
// userLimit is the per-IP limiter.
// userLimit is the per-key limiter.
type userLimit struct {
limiter *rate.Limiter
lastSeen time.Time
}
// NewIPRateLimiter constructs an IPRateLimiter.
func NewIPRateLimiter(config IPRateLimiterConfig) *IPRateLimiter {
return &IPRateLimiter{
config: config,
ipLimits: make(map[string]*userLimit),
// NewIPRateLimiter constructs a RateLimiter that limits based on IP address.
func NewIPRateLimiter(config RateLimiterConfig) *RateLimiter {
return NewRateLimiter(config, GetRequestIP)
}
// NewRateLimiter constructs a RateLimiter.
func NewRateLimiter(config RateLimiterConfig, keyFunc func(*http.Request) (string, error)) *RateLimiter {
return &RateLimiter{
config: config,
limits: make(map[string]*userLimit),
keyFunc: keyFunc,
}
}
// Run occasionally cleans old rate-limiting data, until context cancel.
func (rl *IPRateLimiter) Run(ctx context.Context) {
func (rl *RateLimiter) Run(ctx context.Context) {
cleanupTicker := time.NewTicker(rl.config.Duration)
defer cleanupTicker.Stop()
for {
@ -57,26 +64,26 @@ func (rl *IPRateLimiter) Run(ctx context.Context) {
}
// cleanupLimiters removes old rate limits to free memory.
func (rl *IPRateLimiter) cleanupLimiters() {
func (rl *RateLimiter) cleanupLimiters() {
rl.mu.Lock()
defer rl.mu.Unlock()
for ip, v := range rl.ipLimits {
for k, v := range rl.limits {
if time.Since(v.lastSeen) > rl.config.Duration {
delete(rl.ipLimits, ip)
delete(rl.limits, k)
}
}
}
// Limit applies a per IP rate limiting as an HTTP Handler.
func (rl *IPRateLimiter) Limit(next http.Handler) http.Handler {
// Limit applies per-key rate limiting as an HTTP Handler.
func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip, err := GetRequestIP(r)
key, err := rl.keyFunc(r)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
ipLimit := rl.getUserLimit(ip)
if !ipLimit.Allow() {
limit := rl.getUserLimit(key)
if !limit.Allow() {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return
}
@ -104,37 +111,37 @@ func GetRequestIP(r *http.Request) (ip string, err error) {
return ip, err
}
// getUserLimit returns a rate limiter for an IP.
func (rl *IPRateLimiter) getUserLimit(ip string) *rate.Limiter {
// getUserLimit returns a rate limiter for a key.
func (rl *RateLimiter) getUserLimit(key string) *rate.Limiter {
rl.mu.Lock()
defer rl.mu.Unlock()
v, exists := rl.ipLimits[ip]
v, exists := rl.limits[key]
if !exists {
if len(rl.ipLimits) >= rl.config.NumLimits {
if len(rl.limits) >= rl.config.NumLimits {
// Tracking only N limits prevents an out-of-memory DOS attack
// Returning StatusTooManyRequests would be just as bad
// The least-bad option may be to remove the oldest key
oldestKey := ""
var oldestTime *time.Time
for ip, v := range rl.ipLimits {
for key, v := range rl.limits {
// while we're looping, we'd prefer to just delete expired records
if time.Since(v.lastSeen) > rl.config.Duration {
delete(rl.ipLimits, ip)
delete(rl.limits, key)
}
// but we're prepared to delete the oldest non-expired
if oldestTime == nil || v.lastSeen.Before(*oldestTime) {
oldestTime = &v.lastSeen
oldestKey = ip
oldestKey = key
}
}
// only delete the oldest non-expired if there's still an issue
if oldestKey != "" && len(rl.ipLimits) >= rl.config.NumLimits {
delete(rl.ipLimits, oldestKey)
if oldestKey != "" && len(rl.limits) >= rl.config.NumLimits {
delete(rl.limits, oldestKey)
}
}
limiter := rate.NewLimiter(rate.Limit(time.Second)/rate.Limit(rl.config.Duration), rl.config.Burst)
rl.ipLimits[ip] = &userLimit{limiter, time.Now()}
rl.limits[key] = &userLimit{limiter, time.Now()}
return limiter
}
v.lastSeen = time.Now()
@ -142,11 +149,11 @@ func (rl *IPRateLimiter) getUserLimit(ip string) *rate.Limiter {
}
// Burst returns the number of events that happen before the rate limit.
func (rl *IPRateLimiter) Burst() int {
func (rl *RateLimiter) Burst() int {
return rl.config.Burst
}
// Duration returns the amount of time required between events.
func (rl *IPRateLimiter) Duration() time.Duration {
func (rl *RateLimiter) Duration() time.Duration {
return rl.config.Duration
}

View File

@ -20,7 +20,7 @@ import (
func TestNewIPRateLimiter(t *testing.T) {
// create a rate limiter with defaults except NumLimits = 2
config := web.IPRateLimiterConfig{}
config := web.RateLimiterConfig{}
cfgstruct.Bind(&pflag.FlagSet{}, &config, cfgstruct.UseDevDefaults())
config.NumLimits = 2
rateLimiter := web.NewIPRateLimiter(config)

View File

@ -91,7 +91,8 @@ type Config struct {
LinksharingURL string `help:"url link for linksharing requests" default:"https://link.us1.storjshare.io"`
PathwayOverviewEnabled bool `help:"indicates if the overview onboarding step should render with pathways" default:"true"`
RateLimit web.IPRateLimiterConfig
// RateLimit defines the configuration for the IP and userID rate limiters.
RateLimit web.RateLimiterConfig
console.Config
}
@ -141,11 +142,12 @@ type Server struct {
partners *rewards.PartnersService
analytics *analytics.Service
listener net.Listener
server http.Server
cookieAuth *consolewebauth.CookieAuth
rateLimiter *web.IPRateLimiter
nodeURL storj.NodeURL
listener net.Listener
server http.Server
cookieAuth *consolewebauth.CookieAuth
ipRateLimiter *web.RateLimiter
userIDRateLimiter *web.RateLimiter
nodeURL storj.NodeURL
stripePublicKey string
@ -163,17 +165,18 @@ type Server struct {
// NewServer creates new instance of console server.
func NewServer(logger *zap.Logger, config Config, service *console.Service, mailService *mailservice.Service, partners *rewards.PartnersService, analytics *analytics.Service, listener net.Listener, stripePublicKey string, pricing paymentsconfig.PricingValues, nodeURL storj.NodeURL) *Server {
server := Server{
log: logger,
config: config,
listener: listener,
service: service,
mailService: mailService,
partners: partners,
analytics: analytics,
stripePublicKey: stripePublicKey,
rateLimiter: web.NewIPRateLimiter(config.RateLimit),
nodeURL: nodeURL,
pricing: pricing,
log: logger,
config: config,
listener: listener,
service: service,
mailService: mailService,
partners: partners,
analytics: analytics,
stripePublicKey: stripePublicKey,
ipRateLimiter: web.NewIPRateLimiter(config.RateLimit),
userIDRateLimiter: NewUserIDRateLimiter(config.RateLimit),
nodeURL: nodeURL,
pricing: pricing,
}
logger.Debug("Starting Satellite UI.", zap.Stringer("Address", server.listener.Addr()))
@ -225,11 +228,11 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, mail
authRouter.Handle("/mfa/generate-secret-key", server.withAuth(http.HandlerFunc(authController.GenerateMFASecretKey))).Methods(http.MethodPost)
authRouter.Handle("/mfa/generate-recovery-codes", server.withAuth(http.HandlerFunc(authController.GenerateMFARecoveryCodes))).Methods(http.MethodPost)
authRouter.HandleFunc("/logout", authController.Logout).Methods(http.MethodPost)
authRouter.Handle("/token", server.rateLimiter.Limit(http.HandlerFunc(authController.Token))).Methods(http.MethodPost)
authRouter.Handle("/register", server.rateLimiter.Limit(http.HandlerFunc(authController.Register))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/forgot-password/{email}", server.rateLimiter.Limit(http.HandlerFunc(authController.ForgotPassword))).Methods(http.MethodPost)
authRouter.Handle("/resend-email/{id}", server.rateLimiter.Limit(http.HandlerFunc(authController.ResendEmail))).Methods(http.MethodPost)
authRouter.Handle("/reset-password", server.rateLimiter.Limit(http.HandlerFunc(authController.ResetPassword))).Methods(http.MethodPost)
authRouter.Handle("/token", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Token))).Methods(http.MethodPost)
authRouter.Handle("/register", server.ipRateLimiter.Limit(http.HandlerFunc(authController.Register))).Methods(http.MethodPost, http.MethodOptions)
authRouter.Handle("/forgot-password/{email}", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ForgotPassword))).Methods(http.MethodPost)
authRouter.Handle("/resend-email/{id}", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResendEmail))).Methods(http.MethodPost)
authRouter.Handle("/reset-password", server.ipRateLimiter.Limit(http.HandlerFunc(authController.ResetPassword))).Methods(http.MethodPost)
paymentController := consoleapi.NewPayments(logger, service)
paymentsRouter := router.PathPrefix("/api/v0/payments").Subrouter()
@ -243,7 +246,7 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, mail
paymentsRouter.HandleFunc("/account", paymentController.SetupAccount).Methods(http.MethodPost)
paymentsRouter.HandleFunc("/billing-history", paymentController.BillingHistory).Methods(http.MethodGet)
paymentsRouter.HandleFunc("/tokens/deposit", paymentController.TokenDeposit).Methods(http.MethodPost)
paymentsRouter.HandleFunc("/coupon/apply", paymentController.ApplyCouponCode).Methods(http.MethodPatch)
paymentsRouter.Handle("/coupon/apply", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.ApplyCouponCode))).Methods(http.MethodPatch)
paymentsRouter.HandleFunc("/coupon", paymentController.GetCoupon).Methods(http.MethodGet)
bucketsController := consoleapi.NewBuckets(logger, service)
@ -299,7 +302,7 @@ func (server *Server) Run(ctx context.Context) (err error) {
return server.server.Shutdown(context.Background())
})
group.Go(func() error {
server.rateLimiter.Run(ctx)
server.ipRateLimiter.Run(ctx)
return nil
})
group.Go(func() error {
@ -782,3 +785,14 @@ func (server *Server) initializeTemplates() (err error) {
return nil
}
// NewUserIDRateLimiter constructs a RateLimiter that limits based on user ID.
func NewUserIDRateLimiter(config web.RateLimiterConfig) *web.RateLimiter {
return web.NewRateLimiter(config, func(r *http.Request) (string, error) {
auth, err := console.GetAuth(r.Context())
if err != nil {
return "", err
}
return auth.User.ID.String(), nil
})
}

View File

@ -4,13 +4,18 @@
package consoleweb_test
import (
"bytes"
"fmt"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"storj.io/common/testcontext"
"storj.io/storj/private/testplanet"
"storj.io/storj/satellite"
"storj.io/storj/satellite/console"
)
@ -58,3 +63,70 @@ func TestActivationRouting(t *testing.T) {
checkActivationRedirect("Activation - Used Token", loginURL+"?activated=false")
})
}
func TestUserIDRateLimiter(t *testing.T) {
numLimits := 2
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
Reconfigure: testplanet.Reconfigure{
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
config.Console.RateLimit.NumLimits = numLimits
},
},
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
sat := planet.Satellites[0]
applyCouponStatus := func(token string) int {
urlLink := "http://" + sat.API.Console.Listener.Addr().String() + "/api/v0/payments/coupon/apply"
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, urlLink, bytes.NewBufferString("PROMO_CODE"))
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "_tokenKey",
Path: "/",
Value: token,
Expires: time.Now().AddDate(0, 0, 1),
})
result, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.NoError(t, result.Body.Close())
return result.StatusCode
}
var firstToken string
for userNum := 1; userNum <= numLimits+1; userNum++ {
t.Run(fmt.Sprintf("TestUserIDRateLimit_%d", userNum), func(t *testing.T) {
user, err := sat.AddUser(ctx, console.CreateUser{
FullName: fmt.Sprintf("Test User %d", userNum),
Email: fmt.Sprintf("test%d@mail.test", userNum),
}, 1)
require.NoError(t, err)
// sat.AddUser sets password to full name.
token, err := sat.API.Console.Service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName})
require.NoError(t, err)
if userNum == 1 {
firstToken = token
}
// Expect burst number of successes.
for burstNum := 0; burstNum < sat.Config.Console.RateLimit.Burst; burstNum++ {
require.NotEqual(t, http.StatusTooManyRequests, applyCouponStatus(token))
}
// Expect failure.
require.Equal(t, http.StatusTooManyRequests, applyCouponStatus(token))
})
}
// Expect original user to work again because numLimits == 2.
for burstNum := 0; burstNum < sat.Config.Console.RateLimit.Burst; burstNum++ {
require.NotEqual(t, http.StatusTooManyRequests, applyCouponStatus(firstToken))
}
require.Equal(t, http.StatusTooManyRequests, applyCouponStatus(firstToken))
})
}

View File

@ -148,7 +148,7 @@ compensation.withheld-percents: 75,75,75,50,50,50,25,25,25,0,0,0,0,0,0
# the rate at which request are allowed
# console.rate-limit.duration: 5m0s
# number of IPs whose rate limits we store
# number of clients whose rate limits we store
# console.rate-limit.num-limits: 1000
# whether or not reCAPTCHA is enabled for user registration

View File

@ -13,6 +13,7 @@ import {
} from '@/types/payments';
import { HttpClient } from '@/utils/httpClient';
import { Time } from '@/utils/time';
import { ErrorTooManyRequests } from './errors/ErrorTooManyRequests';
/**
* PaymentsHttpApi is a http implementation of Payments API.
@ -285,10 +286,14 @@ export class PaymentsHttpApi implements PaymentsApi {
);
}
if (response.status === 401) {
throw new ErrorUnauthorized();
switch (response.status) {
case 429:
throw new ErrorTooManyRequests('You\'ve exceeded limit of attempts, try again in 5 minutes');
case 401:
throw new ErrorUnauthorized(errMsg);
default:
throw new Error(errMsg);
}
throw new Error(errMsg);
}
/**