66 lines
2.0 KiB
Go
66 lines
2.0 KiB
Go
|
// Copyright (C) 2019 Storj Labs, Inc.
|
||
|
// See LICENSE for copying information.
|
||
|
|
||
|
package web_test
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/spf13/pflag"
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
|
||
|
"storj.io/common/testcontext"
|
||
|
"storj.io/private/cfgstruct"
|
||
|
"storj.io/storj/private/web"
|
||
|
)
|
||
|
|
||
|
func TestNewIPRateLimiter(t *testing.T) {
|
||
|
//create a rate limiter with defaults except NumLimits = 2
|
||
|
config := web.IPRateLimiterConfig{}
|
||
|
cfgstruct.Bind(&pflag.FlagSet{}, &config, cfgstruct.UseDevDefaults())
|
||
|
config.NumLimits = 2
|
||
|
rateLimiter := web.NewIPRateLimiter(config)
|
||
|
//run ratelimiter cleanup until end of test
|
||
|
ctx := testcontext.New(t)
|
||
|
defer ctx.Cleanup()
|
||
|
ctx2, cancel := context.WithCancel(ctx)
|
||
|
defer cancel()
|
||
|
ctx.Go(func() error {
|
||
|
rateLimiter.Run(ctx2)
|
||
|
return nil
|
||
|
})
|
||
|
//make the default HTTP handler return StatusOK
|
||
|
handler := rateLimiter.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
w.WriteHeader(http.StatusOK)
|
||
|
}))
|
||
|
//expect burst number of successes
|
||
|
testWithAddress(t, "192.168.1.1:5000", rateLimiter.Burst(), handler)
|
||
|
//expect similar results for a different IP
|
||
|
testWithAddress(t, "127.0.0.1:5000", rateLimiter.Burst(), handler)
|
||
|
//expect similar results for a different IP
|
||
|
testWithAddress(t, "127.0.0.100:5000", rateLimiter.Burst(), handler)
|
||
|
//expect original IP to work again because numLimits == 2
|
||
|
testWithAddress(t, "192.168.1.1:5000", rateLimiter.Burst(), handler)
|
||
|
}
|
||
|
|
||
|
func testWithAddress(t *testing.T, remoteAddress string, burst int, handler http.Handler) {
|
||
|
//create HTTP request
|
||
|
req, err := http.NewRequest("GET", "", nil)
|
||
|
require.NoError(t, err)
|
||
|
req.RemoteAddr = remoteAddress
|
||
|
//expect burst number of successes
|
||
|
for x := 0; x < burst; x++ {
|
||
|
rr := httptest.NewRecorder()
|
||
|
handler.ServeHTTP(rr, req)
|
||
|
assert.Equal(t, rr.Code, http.StatusOK, remoteAddress)
|
||
|
}
|
||
|
//then expect failure
|
||
|
rr := httptest.NewRecorder()
|
||
|
handler.ServeHTTP(rr, req)
|
||
|
assert.Equal(t, rr.Code, http.StatusTooManyRequests, remoteAddress)
|
||
|
}
|