storj/private/web/ratelimiter.go

160 lines
4.3 KiB
Go
Raw Normal View History

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package web
import (
"context"
"net"
"net/http"
"strings"
"sync"
"time"
"golang.org/x/time/rate"
)
// RateLimiterConfig configures a RateLimiter.
//
// NOTE(review): the `help` text for Duration says "rate", but the value is used
// as an interval (one event per Duration per key); consider rewording the help
// string in a separate, config-lock-aware change.
type RateLimiterConfig struct {
// Duration is the interval at which each key's limiter regains one event
// (the limiter rate is one event per Duration). It is also the idle age after
// which a key's entry may be discarded, and the period at which Run cleans up.
Duration time.Duration `help:"the rate at which request are allowed" default:"5m"`
// Burst is the number of events a key may perform back-to-back before the
// limit kicks in; it is passed as the burst size of each per-key limiter.
Burst int `help:"number of events before the limit kicks in" default:"5" testDefault:"3"`
// NumLimits caps how many keys' limiters are kept in memory; when a new key
// arrives and the cap is reached, existing entries are evicted to make room.
NumLimits int `help:"number of clients whose rate limits we store" default:"1000" testDefault:"10"`
}
// RateLimiter imposes a rate limit per key.
//
// The key for each request is derived by keyFunc, and every distinct key gets
// its own limiter stored in limits. Access to limits is serialized by mu.
type RateLimiter struct {
// config holds the rate, burst and capacity settings shared by all keys.
config RateLimiterConfig
// mu guards limits and the lastSeen field of every entry in it.
mu sync.Mutex
// limits maps a request key to that key's limiter and last-seen time.
limits map[string]*userLimit
// keyFunc derives the rate-limit key from a request (for example GetRequestIP).
keyFunc func(*http.Request) (string, error)
}
// userLimit is the per-key limiter.
type userLimit struct {
// limiter decides whether a request for this key is allowed right now.
limiter *rate.Limiter
// lastSeen is when this key was last looked up; entries idle for longer than
// the configured Duration are eligible for removal.
lastSeen time.Time
}
// NewIPRateLimiter returns a RateLimiter whose keys are client IP addresses,
// as resolved from each request by GetRequestIP.
func NewIPRateLimiter(config RateLimiterConfig) *RateLimiter {
	keyByIP := GetRequestIP
	return NewRateLimiter(config, keyByIP)
}
// NewRateLimiter returns a RateLimiter that keys requests with keyFunc and
// applies the limits described by config. It starts with no tracked keys;
// Run can be used to discard idle entries periodically.
func NewRateLimiter(config RateLimiterConfig, keyFunc func(*http.Request) (string, error)) *RateLimiter {
	rl := new(RateLimiter)
	rl.config = config
	rl.keyFunc = keyFunc
	rl.limits = map[string]*userLimit{}
	return rl
}
// Run blocks until ctx is canceled, and once every configured Duration it
// discards per-key entries that have been idle for longer than Duration.
//
// NOTE: time.NewTicker panics if the configured Duration is not positive.
func (rl *RateLimiter) Run(ctx context.Context) {
	ticker := time.NewTicker(rl.config.Duration)
	defer ticker.Stop()

	done := ctx.Done()
	for {
		select {
		case <-ticker.C:
			rl.cleanupLimiters()
		case <-done:
			return
		}
	}
}
// cleanupLimiters frees memory by dropping every entry whose key has not been
// seen for longer than the configured Duration. It takes rl.mu itself.
func (rl *RateLimiter) cleanupLimiters() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Deleting from a map while ranging over it is safe in Go.
	for key, entry := range rl.limits {
		if time.Since(entry.lastSeen) <= rl.config.Duration {
			continue
		}
		delete(rl.limits, key)
	}
}
// Limit wraps next with per-key rate limiting.
//
// For each request the key is derived with the configured key function. If
// that fails, the client gets 500 Internal Server Error. If the key's limiter
// currently denies the event, the client gets 429 Too Many Requests.
// Otherwise the request is forwarded to next.
func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		key, err := rl.keyFunc(r)
		// Cases are evaluated in order, so the limiter is only consulted once
		// a key has been derived successfully.
		switch {
		case err != nil:
			status := http.StatusInternalServerError
			http.Error(w, http.StatusText(status), status)
		case !rl.getUserLimit(key).Allow():
			status := http.StatusTooManyRequests
			http.Error(w, http.StatusText(status), status)
		default:
			next.ServeHTTP(w, r)
		}
	}
	return http.HandlerFunc(serve)
}
// GetRequestIP returns the client IP address of r, consulting in order:
//   - the X-Real-IP header,
//   - the first (left-most) entry of the X-Forwarded-For header,
//   - the host part of r.RemoteAddr.
//
// Header values are trimmed of surrounding whitespace. X-Forwarded-For is split
// on "," so both the "a, b" and "a,b" list forms are handled. An empty value
// falls through to the next source, so unrelated clients are never lumped
// under an empty key. An error is returned only when RemoteAddr has to be used
// and cannot be split into host and port.
//
// NOTE(review): both headers are client-controlled unless a trusted proxy
// overwrites them, so a client reaching this server directly can pick its own
// rate-limit key. Confirm deployments always sit behind such a proxy.
func GetRequestIP(r *http.Request) (ip string, err error) {
	if realIP := strings.TrimSpace(r.Header.Get("X-REAL-IP")); realIP != "" {
		return realIP, nil
	}

	if forwarded := r.Header.Get("X-FORWARDED-FOR"); forwarded != "" {
		// SplitN on a non-empty string always yields at least one element.
		first := strings.TrimSpace(strings.SplitN(forwarded, ",", 2)[0])
		if first != "" {
			return first, nil
		}
	}

	ip, _, err = net.SplitHostPort(r.RemoteAddr)
	return ip, err
}
// getUserLimit returns the limiter for key, creating one if the key is new.
//
// An existing entry has its lastSeen refreshed. For a new key, room is made
// first when NumLimits entries are already tracked (see makeRoomLocked). The
// new limiter then allows one event per configured Duration, with bursts of up
// to Burst events. The returned limiter is used by callers after rl.mu is
// released.
func (rl *RateLimiter) getUserLimit(key string) *rate.Limiter {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if existing, ok := rl.limits[key]; ok {
		existing.lastSeen = time.Now()
		return existing.limiter
	}

	if len(rl.limits) >= rl.config.NumLimits {
		rl.makeRoomLocked()
	}

	// rate.Every(d) is one event per d, i.e. the same as time.Second/d events
	// per second; it yields rate.Inf for a non-positive d.
	limiter := rate.NewLimiter(rate.Every(rl.config.Duration), rl.config.Burst)
	rl.limits[key] = &userLimit{limiter: limiter, lastSeen: time.Now()}
	return limiter
}

// makeRoomLocked frees space in rl.limits before a new key is inserted.
// It removes every expired entry and then, only if the map is still at
// capacity, the least-recently-seen non-expired entry. The caller must hold
// rl.mu.
func (rl *RateLimiter) makeRoomLocked() {
	// Tracking only N limits prevents an out-of-memory DoS attack, and
	// rejecting unknown clients with StatusTooManyRequests would be just as
	// bad, so evicting the least-recently-seen key is the least-bad option.
	var (
		oldestKey  string
		oldestTime time.Time
		found      bool // a flag, not oldestKey != "", so an empty key can be evicted too
	)
	for k, entry := range rl.limits {
		if time.Since(entry.lastSeen) > rl.config.Duration {
			// Deleting while ranging is safe in Go; an expired entry must not
			// also be chosen as the eviction candidate.
			delete(rl.limits, k)
			continue
		}
		if !found || entry.lastSeen.Before(oldestTime) {
			oldestKey, oldestTime, found = k, entry.lastSeen, true
		}
	}
	if found && len(rl.limits) >= rl.config.NumLimits {
		delete(rl.limits, oldestKey)
	}
}
// Burst reports how many events a single key may perform back-to-back before
// its rate limit starts rejecting requests.
func (rl *RateLimiter) Burst() int {
	cfg := rl.config
	return cfg.Burst
}
// Duration reports the interval a key must wait, once its burst is used up,
// before another event is allowed.
func (rl *RateLimiter) Duration() time.Duration {
	cfg := rl.config
	return cfg.Duration
}