storj/private/web/ratelimiter_test.go

71 lines
2.1 KiB
Go
Raw Normal View History

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package web_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"storj.io/common/testcontext"
"storj.io/private/cfgstruct"
"storj.io/storj/private/web"
)
// TestNewIPRateLimiter drives a rate-limited handler from several remote
// addresses and checks, via testWithAddress, that each address is allowed
// a full burst of requests before being rejected.
func TestNewIPRateLimiter(t *testing.T) {
	// start from the dev defaults and override only NumLimits
	var cfg web.RateLimiterConfig
	cfgstruct.Bind(&pflag.FlagSet{}, &cfg, cfgstruct.UseDevDefaults())
	cfg.NumLimits = 2
	limiter := web.NewIPRateLimiter(cfg)

	// keep the limiter's Run loop going until the test returns; the deferred
	// cancel fires before ctx.Cleanup, so the goroutine is stopped first
	ctx := testcontext.New(t)
	defer ctx.Cleanup()
	runCtx, stop := context.WithCancel(ctx)
	defer stop()
	ctx.Go(func() error {
		limiter.Run(runCtx)
		return nil
	})

	// the wrapped handler always answers StatusOK, so any other status
	// observed by the checks below comes from the limiter
	alwaysOK := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	limited := limiter.Limit(http.HandlerFunc(alwaysOK))

	// Each address should get a full burst and then be rejected. The first
	// address is repeated at the end: it is expected to work again because
	// NumLimits == 2 and two other addresses have been seen in between.
	addresses := []string{
		"192.168.1.1:5000",
		"127.0.0.1:5000",
		"127.0.0.100:5000",
		"192.168.1.1:5000",
	}
	for _, address := range addresses {
		testWithAddress(ctx, t, address, limiter.Burst(), limited)
	}
}
// testWithAddress sends burst+1 GET requests to handler, all carrying
// remoteAddress as their RemoteAddr. The first burst requests must be
// answered with http.StatusOK and the final one with
// http.StatusTooManyRequests. Status mismatches are reported through t
// without aborting the test, with remoteAddress attached to the message.
func testWithAddress(ctx context.Context, t *testing.T, remoteAddress string, burst int, handler http.Handler) {
	// report failures at the caller's line rather than inside this helper
	t.Helper()

	// create HTTP request; the same request is reused for every call
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
	require.NoError(t, err)
	req.RemoteAddr = remoteAddress

	// expect burst number of successes
	for x := 0; x < burst; x++ {
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)
		// testify expects (expected, actual); this order keeps the labels
		// in failure output correct
		assert.Equal(t, http.StatusOK, rr.Code, remoteAddress)
	}

	// then expect failure
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusTooManyRequests, rr.Code, remoteAddress)
}