// Copyright (C) 2019 Storj Labs, Inc. // See LICENSE for copying information. package web_test import ( "context" "net/http" "net/http/httptest" "testing" "github.com/spf13/pflag" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "storj.io/common/testcontext" "storj.io/private/cfgstruct" "storj.io/storj/private/web" ) func TestNewIPRateLimiter(t *testing.T) { //create a rate limiter with defaults except NumLimits = 2 config := web.IPRateLimiterConfig{} cfgstruct.Bind(&pflag.FlagSet{}, &config, cfgstruct.UseDevDefaults()) config.NumLimits = 2 rateLimiter := web.NewIPRateLimiter(config) //run ratelimiter cleanup until end of test ctx := testcontext.New(t) defer ctx.Cleanup() ctx2, cancel := context.WithCancel(ctx) defer cancel() ctx.Go(func() error { rateLimiter.Run(ctx2) return nil }) //make the default HTTP handler return StatusOK handler := rateLimiter.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) //expect burst number of successes testWithAddress(t, "192.168.1.1:5000", rateLimiter.Burst(), handler) //expect similar results for a different IP testWithAddress(t, "127.0.0.1:5000", rateLimiter.Burst(), handler) //expect similar results for a different IP testWithAddress(t, "127.0.0.100:5000", rateLimiter.Burst(), handler) //expect original IP to work again because numLimits == 2 testWithAddress(t, "192.168.1.1:5000", rateLimiter.Burst(), handler) } func testWithAddress(t *testing.T, remoteAddress string, burst int, handler http.Handler) { //create HTTP request req, err := http.NewRequest("GET", "", nil) require.NoError(t, err) req.RemoteAddr = remoteAddress //expect burst number of successes for x := 0; x < burst; x++ { rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) assert.Equal(t, rr.Code, http.StatusOK, remoteAddress) } //then expect failure rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) assert.Equal(t, rr.Code, http.StatusTooManyRequests, remoteAddress) }