diff --git a/cmd/uplink/cmd/access.go b/cmd/uplink/cmd/access.go index 484f0d6e5..22d346804 100644 --- a/cmd/uplink/cmd/access.go +++ b/cmd/uplink/cmd/access.go @@ -9,6 +9,7 @@ import ( "fmt" "io/ioutil" "net/http" + "time" "github.com/btcsuite/btcutil/base58" "github.com/spf13/cobra" @@ -20,6 +21,8 @@ import ( "storj.io/uplink" ) +const defaultAccessRegisterTimeout = 15 * time.Second + type registerConfig struct { AuthService string `help:"the address to the service you wish to register your access with" default:"" basic-help:"true"` Public bool `help:"if the access should be public" default:"false" basic-help:"true"` @@ -137,7 +140,7 @@ func accessRegister(cmd *cobra.Command, args []string) (err error) { return errs.New("no access specified: %w", err) } - accessKey, secretKey, endpoint, err := RegisterAccess(access, registerCfg.AuthService, registerCfg.Public) + accessKey, secretKey, endpoint, err := RegisterAccess(access, registerCfg.AuthService, registerCfg.Public, defaultAccessRegisterTimeout) if err != nil { return err } @@ -181,7 +184,7 @@ func getAccessFromArgZeroOrConfig(config AccessConfig, args []string) (access *u } // RegisterAccess registers an access grant with a Gateway Authorization Service. -func RegisterAccess(access *uplink.Access, authService string, public bool) (accessKey, secretKey, endpoint string, err error) { +func RegisterAccess(access *uplink.Access, authService string, public bool, timeout time.Duration) (accessKey, secretKey, endpoint string, err error) { if authService == "" { return "", "", "", errs.New("no auth service address provided") } @@ -197,7 +200,11 @@ func RegisterAccess(access *uplink.Access, authService string, public bool) (acc return accessKey, "", "", errs.Wrap(err) } - resp, err := http.Post(fmt.Sprintf("%s/v1/access", authService), "application/json", bytes.NewReader(postData)) + client := &http.Client{ + Timeout: timeout, + } + + resp, err := client.Post(fmt.Sprintf("%s/v1/access", authService), "application/json", bytes.NewReader(postData)) if err != nil { return "", "", "", err } diff --git a/cmd/uplink/cmd/access_test.go b/cmd/uplink/cmd/access_test.go index f553cdad7..2c3d96f1c 100644 --- a/cmd/uplink/cmd/access_test.go +++ b/cmd/uplink/cmd/access_test.go @@ -8,31 +8,49 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "storj.io/common/testcontext" "storj.io/storj/cmd/uplink/cmd" - "storj.io/storj/private/testplanet" + "storj.io/uplink" ) +const testAccess = "12edqrJX1V243n5fWtUrwpMQXL8gKdY2wbyqRPSG3rsA1tzmZiQjtCyF896egifN2C2qdY6g5S1t6e8iDhMUon9Pb7HdecBFheAcvmN8652mqu8hRx5zcTUaRTWfFCKS2S6DHmTeqPUHJLEp6cJGXNHcdqegcKfeahVZGP4rTagHvFGEraXjYRJ3knAcWDGW6BxACqogEWez6r274JiUBfs4yRSbRNRqUEURd28CwDXMSHLRKKA7TEDKEdQ" + func TestRegisterAccess(t *testing.T) { - testplanet.Run(t, testplanet.Config{ - SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 1, - }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { - // mock the auth service - ts := httptest.NewServer( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, `{"access_key_id":"1", "secret_key":"2", "endpoint":"3"}`) - })) - defer ts.Close() - // make sure we get back things - access := planet.Uplinks[0].Access[planet.Satellites[0].ID()] - accessKey, secretKey, endpoint, err := cmd.RegisterAccess(access, ts.URL, true) - require.NoError(t, err) - assert.Equal(t, "1", accessKey) - assert.Equal(t, "2", secretKey) - assert.Equal(t, "3", endpoint) - }) + // mock the auth service + ts := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"access_key_id":"1", "secret_key":"2", "endpoint":"3"}`) + })) + defer ts.Close() + // make sure we get back things + access, err := uplink.ParseAccess(testAccess) + require.NoError(t, err) + accessKey, secretKey, endpoint, err := cmd.RegisterAccess(access, ts.URL, true, 15*time.Second) + require.NoError(t, err) + assert.Equal(t, "1", accessKey) + assert.Equal(t, "2", secretKey) + assert.Equal(t, "3", endpoint) +} + +func TestRegisterAccessTimeout(t *testing.T) { + // mock the auth service + ch := make(chan struct{}) + ts := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-ch + })) + defer ts.Close() + // make sure we get back things + access, err := uplink.ParseAccess(testAccess) + require.NoError(t, err) + accessKey, secretKey, endpoint, err := cmd.RegisterAccess(access, ts.URL, true, 10*time.Millisecond) + require.Error(t, err) + assert.Equal(t, "", accessKey) + assert.Equal(t, "", secretKey) + assert.Equal(t, "", endpoint) + close(ch) } diff --git a/cmd/uplink/cmd/share.go b/cmd/uplink/cmd/share.go index 97cd2f2b8..a99ae6d31 100644 --- a/cmd/uplink/cmd/share.go +++ b/cmd/uplink/cmd/share.go @@ -70,7 +70,7 @@ func shareMain(cmd *cobra.Command, args []string) (err error) { if shareCfg.Register || shareCfg.URL || shareCfg.DNS != "" { isPublic := (shareCfg.Public || shareCfg.URL || shareCfg.DNS != "") - accessKey, _, _, err = RegisterAccess(newAccess, shareCfg.AuthService, isPublic) + accessKey, _, _, err = RegisterAccess(newAccess, shareCfg.AuthService, isPublic, defaultAccessRegisterTimeout) if err != nil { return err }