2268cc1df3
Change-Id: Ia01404dbb6bdd19a146fa10ff7302e08f87a8c95
71 lines
2.0 KiB
Go
71 lines
2.0 KiB
Go
// Copyright (C) 2019 Storj Labs, Inc.
|
|
// See LICENSE for copying information.
|
|
|
|
package web_test
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/spf13/pflag"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"storj.io/common/testcontext"
|
|
"storj.io/private/cfgstruct"
|
|
"storj.io/storj/private/web"
|
|
)
|
|
|
|
func TestNewIPRateLimiter(t *testing.T) {
|
|
// create a rate limiter with defaults except NumLimits = 2
|
|
config := web.IPRateLimiterConfig{}
|
|
cfgstruct.Bind(&pflag.FlagSet{}, &config, cfgstruct.UseDevDefaults())
|
|
config.NumLimits = 2
|
|
rateLimiter := web.NewIPRateLimiter(config)
|
|
|
|
// run ratelimiter cleanup until end of test
|
|
ctx := testcontext.New(t)
|
|
defer ctx.Cleanup()
|
|
ctx2, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
ctx.Go(func() error {
|
|
rateLimiter.Run(ctx2)
|
|
return nil
|
|
})
|
|
|
|
// make the default HTTP handler return StatusOK
|
|
handler := rateLimiter.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// expect burst number of successes
|
|
testWithAddress(t, "192.168.1.1:5000", rateLimiter.Burst(), handler)
|
|
// expect similar results for a different IP
|
|
testWithAddress(t, "127.0.0.1:5000", rateLimiter.Burst(), handler)
|
|
// expect similar results for a different IP
|
|
testWithAddress(t, "127.0.0.100:5000", rateLimiter.Burst(), handler)
|
|
// expect original IP to work again because numLimits == 2
|
|
testWithAddress(t, "192.168.1.1:5000", rateLimiter.Burst(), handler)
|
|
}
|
|
|
|
func testWithAddress(t *testing.T, remoteAddress string, burst int, handler http.Handler) {
|
|
// create HTTP request
|
|
req, err := http.NewRequest("GET", "", nil)
|
|
require.NoError(t, err)
|
|
req.RemoteAddr = remoteAddress
|
|
|
|
// expect burst number of successes
|
|
for x := 0; x < burst; x++ {
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
assert.Equal(t, rr.Code, http.StatusOK, remoteAddress)
|
|
}
|
|
|
|
// then expect failure
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
assert.Equal(t, rr.Code, http.StatusTooManyRequests, remoteAddress)
|
|
}
|