satellite/console: Implement MFA backend
Added MFA passcode and recovery code field for token requests. Added endpoints for MFA-related activity: enabling MFA, disabling MFA, generating a new MFA secret key, and generating new MFA recovery codes. Change-Id: Ia1443f05d3a2fecaa7f170f56d73c7a4e9b69ad5
This commit is contained in:
parent
420d2f6275
commit
dae6ed7d03
1
go.mod
1
go.mod
@ -29,6 +29,7 @@ require (
|
|||||||
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
|
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
|
||||||
github.com/nsf/termbox-go v0.0.0-20200418040025-38ba6e5628f1
|
github.com/nsf/termbox-go v0.0.0-20200418040025-38ba6e5628f1
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
|
github.com/pquerna/otp v1.3.0
|
||||||
github.com/segmentio/backo-go v0.0.0-20200129164019-23eae7c10bd3 // indirect
|
github.com/segmentio/backo-go v0.0.0-20200129164019-23eae7c10bd3 // indirect
|
||||||
github.com/shopspring/decimal v1.2.0
|
github.com/shopspring/decimal v1.2.0
|
||||||
github.com/spacemonkeygo/monkit/v3 v3.0.14
|
github.com/spacemonkeygo/monkit/v3 v3.0.14
|
||||||
|
4
go.sum
4
go.sum
@ -53,6 +53,8 @@ github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdn
|
|||||||
github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
|
github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
|
||||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
||||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
||||||
|
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
||||||
|
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||||
github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g=
|
github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g=
|
||||||
github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ=
|
github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ=
|
||||||
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA=
|
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA=
|
||||||
@ -400,6 +402,8 @@ github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZ
|
|||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI=
|
github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI=
|
||||||
|
github.com/pquerna/otp v1.3.0 h1:oJV/SkzR33anKXwQU3Of42rL4wbrffP4uvUf1SvS5Xs=
|
||||||
|
github.com/pquerna/otp v1.3.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||||
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||||
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||||
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
|
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
|
||||||
|
@ -10,7 +10,9 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pquerna/otp/totp"
|
||||||
"github.com/spf13/pflag"
|
"github.com/spf13/pflag"
|
||||||
"github.com/zeebo/errs"
|
"github.com/zeebo/errs"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@ -252,7 +254,15 @@ func (system *Satellite) AuthenticatedContext(ctx context.Context, userID uuid.U
|
|||||||
}
|
}
|
||||||
|
|
||||||
// we are using full name as a password
|
// we are using full name as a password
|
||||||
token, err := system.API.Console.Service.Token(ctx, user.Email, user.FullName)
|
request := console.AuthUser{Email: user.Email, Password: user.FullName}
|
||||||
|
if user.MFAEnabled {
|
||||||
|
code, err := totp.GenerateCode(user.MFASecretKey, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
request.MFAPasscode = code
|
||||||
|
}
|
||||||
|
token, err := system.API.Console.Service.Token(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -58,7 +58,7 @@ func Test_DeleteAPIKeyByNameAndProjectID(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// we are using full name as a password
|
// we are using full name as a password
|
||||||
token, err := sat.API.Console.Service.Token(ctx, user.Email, user.FullName)
|
token, err := sat.API.Console.Service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
client := http.Client{}
|
client := http.Client{}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/zeebo/errs"
|
"github.com/zeebo/errs"
|
||||||
@ -68,20 +69,18 @@ func (a *Auth) Token(w http.ResponseWriter, r *http.Request) {
|
|||||||
var err error
|
var err error
|
||||||
defer mon.Task()(&ctx)(&err)
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
var tokenRequest struct {
|
tokenRequest := console.AuthUser{}
|
||||||
Email string `json:"email"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
err = json.NewDecoder(r.Body).Decode(&tokenRequest)
|
err = json.NewDecoder(r.Body).Decode(&tokenRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.serveJSONError(w, err)
|
a.serveJSONError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := a.service.Token(ctx, tokenRequest.Email, tokenRequest.Password)
|
token, err := a.service.Token(ctx, tokenRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Info("Error authenticating token request", zap.String("email", tokenRequest.Email), zap.Error(ErrAuthAPI.Wrap(err)))
|
if !console.ErrMFAPasscodeRequired.Has(err) {
|
||||||
|
a.log.Info("Error authenticating token request", zap.String("email", tokenRequest.Email), zap.Error(ErrAuthAPI.Wrap(err)))
|
||||||
|
}
|
||||||
a.serveJSONError(w, err)
|
a.serveJSONError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -465,6 +464,86 @@ func (a *Auth) ResendEmail(w http.ResponseWriter, r *http.Request) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EnableUserMFA enables multi-factor authentication for the user.
|
||||||
|
func (a *Auth) EnableUserMFA(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
var err error
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
var passcode string
|
||||||
|
err = json.NewDecoder(r.Body).Decode(&passcode)
|
||||||
|
if err != nil {
|
||||||
|
a.serveJSONError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.service.EnableUserMFA(ctx, passcode, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
a.serveJSONError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableUserMFA disables multi-factor authentication for the user.
|
||||||
|
func (a *Auth) DisableUserMFA(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
var err error
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
var passcode string
|
||||||
|
err = json.NewDecoder(r.Body).Decode(&passcode)
|
||||||
|
if err != nil {
|
||||||
|
a.serveJSONError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.service.DisableUserMFA(ctx, passcode, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
a.serveJSONError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateMFASecretKey creates a new TOTP secret key for the user.
|
||||||
|
func (a *Auth) GenerateMFASecretKey(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
var err error
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
key, err := a.service.ResetMFASecretKey(ctx)
|
||||||
|
if err != nil {
|
||||||
|
a.serveJSONError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w).Encode(key)
|
||||||
|
if err != nil {
|
||||||
|
a.log.Error("could not encode MFA secret key", zap.Error(ErrAuthAPI.Wrap(err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateMFARecoveryCodes creates a new set of MFA recovery codes for the user.
|
||||||
|
func (a *Auth) GenerateMFARecoveryCodes(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
var err error
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
codes, err := a.service.ResetMFARecoveryCodes(ctx)
|
||||||
|
if err != nil {
|
||||||
|
a.serveJSONError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w).Encode(codes)
|
||||||
|
if err != nil {
|
||||||
|
a.log.Error("could not encode MFA recovery codes", zap.Error(ErrAuthAPI.Wrap(err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// serveJSONError writes JSON error to response output stream.
|
// serveJSONError writes JSON error to response output stream.
|
||||||
func (a *Auth) serveJSONError(w http.ResponseWriter, err error) {
|
func (a *Auth) serveJSONError(w http.ResponseWriter, err error) {
|
||||||
status := a.getStatusCode(err)
|
status := a.getStatusCode(err)
|
||||||
@ -482,6 +561,8 @@ func (a *Auth) getStatusCode(err error) int {
|
|||||||
return http.StatusConflict
|
return http.StatusConflict
|
||||||
case errors.Is(err, errNotImplemented):
|
case errors.Is(err, errNotImplemented):
|
||||||
return http.StatusNotImplemented
|
return http.StatusNotImplemented
|
||||||
|
case console.ErrMFAPasscodeRequired.Has(err):
|
||||||
|
return http.StatusContinue
|
||||||
default:
|
default:
|
||||||
return http.StatusInternalServerError
|
return http.StatusInternalServerError
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"testing/quick"
|
"testing/quick"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@ -27,6 +28,7 @@ import (
|
|||||||
"storj.io/common/uuid"
|
"storj.io/common/uuid"
|
||||||
"storj.io/storj/private/testplanet"
|
"storj.io/storj/private/testplanet"
|
||||||
"storj.io/storj/satellite"
|
"storj.io/storj/satellite"
|
||||||
|
"storj.io/storj/satellite/console"
|
||||||
"storj.io/storj/satellite/console/consoleweb/consoleapi"
|
"storj.io/storj/satellite/console/consoleweb/consoleapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -237,3 +239,156 @@ returned response:
|
|||||||
`, cerr.Count, cerr.In, cerr.Out1[0], cerr.Out1[1], cerr.Out2[0], cerr.Out2[1])
|
`, cerr.Count, cerr.In, cerr.Out1[0], cerr.Out1[1], cerr.Out2[0], cerr.Out2[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMFAEndpoints(t *testing.T) {
|
||||||
|
testplanet.Run(t, testplanet.Config{
|
||||||
|
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
|
||||||
|
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||||
|
sat := planet.Satellites[0]
|
||||||
|
|
||||||
|
user, err := sat.AddUser(ctx, console.CreateUser{
|
||||||
|
FullName: "MFA Test User",
|
||||||
|
Email: "mfauser@mail.test",
|
||||||
|
}, 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := sat.API.Console.Service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, token)
|
||||||
|
|
||||||
|
doRequest := func(urlSuffix string, body interface{}) *http.Response {
|
||||||
|
url := "http://" + sat.API.Console.Listener.Addr().String() + "/api/v0/auth/mfa" + urlSuffix
|
||||||
|
var buf io.Reader
|
||||||
|
|
||||||
|
if body != nil {
|
||||||
|
bodyBytes, err := json.Marshal(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
buf = bytes.NewBuffer(bodyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: "_tokenKey",
|
||||||
|
Path: "/",
|
||||||
|
Value: token,
|
||||||
|
Expires: time.Now().AddDate(0, 0, 1),
|
||||||
|
})
|
||||||
|
|
||||||
|
if body != nil {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := http.DefaultClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect failure because MFA is not enabled.
|
||||||
|
result := doRequest("/generate-recovery-codes", "")
|
||||||
|
require.Equal(t, http.StatusUnauthorized, result.StatusCode)
|
||||||
|
require.NoError(t, result.Body.Close())
|
||||||
|
|
||||||
|
// Expect failure due to not having generated a secret key.
|
||||||
|
result = doRequest("/enable", "123456")
|
||||||
|
require.Equal(t, http.StatusBadRequest, result.StatusCode)
|
||||||
|
require.NoError(t, result.Body.Close())
|
||||||
|
|
||||||
|
// Expect success when generating a secret key.
|
||||||
|
result = doRequest("/generate-secret-key", "")
|
||||||
|
require.Equal(t, http.StatusOK, result.StatusCode)
|
||||||
|
|
||||||
|
var key string
|
||||||
|
err = json.NewDecoder(result.Body).Decode(&key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, result.Body.Close())
|
||||||
|
|
||||||
|
// Expect failure due to prodiving empty passcode.
|
||||||
|
result = doRequest("/enable", "")
|
||||||
|
require.Equal(t, http.StatusBadRequest, result.StatusCode)
|
||||||
|
require.NoError(t, result.Body.Close())
|
||||||
|
|
||||||
|
// Expect failure due to providing invalid passcode.
|
||||||
|
badCode, err := console.NewMFAPasscode(key, time.Now().Add(time.Hour))
|
||||||
|
require.NoError(t, err)
|
||||||
|
result = doRequest("/enable", badCode)
|
||||||
|
require.Equal(t, http.StatusBadRequest, result.StatusCode)
|
||||||
|
require.NoError(t, result.Body.Close())
|
||||||
|
|
||||||
|
// Expect success when providing valid passcode.
|
||||||
|
goodCode, err := console.NewMFAPasscode(key, time.Now())
|
||||||
|
require.NoError(t, err)
|
||||||
|
result = doRequest("/enable", goodCode)
|
||||||
|
require.Equal(t, http.StatusOK, result.StatusCode)
|
||||||
|
require.NoError(t, result.Body.Close())
|
||||||
|
|
||||||
|
// Expect 10 recovery codes to be generated.
|
||||||
|
result = doRequest("/generate-recovery-codes", "")
|
||||||
|
require.Equal(t, http.StatusOK, result.StatusCode)
|
||||||
|
|
||||||
|
var codes []string
|
||||||
|
err = json.NewDecoder(result.Body).Decode(&codes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, codes, console.MFARecoveryCodeCount)
|
||||||
|
require.NoError(t, result.Body.Close())
|
||||||
|
|
||||||
|
// Expect no token due to missing passcode.
|
||||||
|
newToken, err := sat.API.Console.Service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName})
|
||||||
|
require.True(t, console.ErrMFAPasscodeRequired.Has(err))
|
||||||
|
require.Empty(t, newToken)
|
||||||
|
|
||||||
|
// Expect token when providing valid passcode.
|
||||||
|
newToken, err = sat.API.Console.Service.Token(ctx, console.AuthUser{
|
||||||
|
Email: user.Email,
|
||||||
|
Password: user.FullName,
|
||||||
|
MFAPasscode: goodCode,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, newToken)
|
||||||
|
|
||||||
|
// Expect no token when providing invalid recovery code.
|
||||||
|
newToken, err = sat.API.Console.Service.Token(ctx, console.AuthUser{
|
||||||
|
Email: user.Email,
|
||||||
|
Password: user.FullName,
|
||||||
|
MFARecoveryCode: "BADCODE",
|
||||||
|
})
|
||||||
|
require.True(t, console.ErrUnauthorized.Has(err))
|
||||||
|
require.Empty(t, newToken)
|
||||||
|
|
||||||
|
for _, code := range codes {
|
||||||
|
opts := console.AuthUser{
|
||||||
|
Email: user.Email,
|
||||||
|
Password: user.FullName,
|
||||||
|
MFARecoveryCode: code,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect token when providing valid recovery code.
|
||||||
|
newToken, err = sat.API.Console.Service.Token(ctx, opts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, newToken)
|
||||||
|
|
||||||
|
// Expect error when providing expired recovery code.
|
||||||
|
newToken, err = sat.API.Console.Service.Token(ctx, opts)
|
||||||
|
require.True(t, console.ErrUnauthorized.Has(err))
|
||||||
|
require.Empty(t, newToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect failure due to disabling MFA with no passcode.
|
||||||
|
result = doRequest("/disable", "")
|
||||||
|
require.Equal(t, http.StatusBadRequest, result.StatusCode)
|
||||||
|
require.NoError(t, result.Body.Close())
|
||||||
|
|
||||||
|
// Expect failure due to disabling MFA with invalid passcode.
|
||||||
|
result = doRequest("/disable", badCode)
|
||||||
|
require.Equal(t, http.StatusBadRequest, result.StatusCode)
|
||||||
|
require.NoError(t, result.Body.Close())
|
||||||
|
|
||||||
|
// Expect success when disabling MFA with valid passcode.
|
||||||
|
result = doRequest("/disable", goodCode)
|
||||||
|
require.Equal(t, http.StatusOK, result.StatusCode)
|
||||||
|
require.NoError(t, result.Body.Close())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -64,7 +64,7 @@ func Test_AllBucketNames(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// we are using full name as a password
|
// we are using full name as a password
|
||||||
token, err := sat.API.Console.Service.Token(ctx, user.Email, user.FullName)
|
token, err := sat.API.Console.Service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
client := http.Client{}
|
client := http.Client{}
|
||||||
|
@ -67,7 +67,7 @@ func Test_TotalUsageLimits(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// we are using full name as a password
|
// we are using full name as a password
|
||||||
token, err := sat.API.Console.Service.Token(ctx, user.Email, user.FullName)
|
token, err := sat.API.Console.Service.Token(ctx, console.AuthUser{Email: user.Email, Password: user.FullName})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
client := http.Client{}
|
client := http.Client{}
|
||||||
|
@ -150,7 +150,7 @@ func TestGraphqlMutation(t *testing.T) {
|
|||||||
err = service.ActivateAccount(ctx, activationToken)
|
err = service.ActivateAccount(ctx, activationToken)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
token, err := service.Token(ctx, createUser.Email, createUser.Password)
|
token, err := service.Token(ctx, console.AuthUser{Email: createUser.Email, Password: createUser.Password})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
sauth, err := service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token)))
|
sauth, err := service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token)))
|
||||||
@ -176,7 +176,7 @@ func TestGraphqlMutation(t *testing.T) {
|
|||||||
return result.Data, nil
|
return result.Data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err = service.Token(ctx, rootUser.Email, createUser.Password)
|
token, err = service.Token(ctx, console.AuthUser{Email: rootUser.Email, Password: createUser.Password})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
sauth, err = service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token)))
|
sauth, err = service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token)))
|
||||||
|
@ -144,7 +144,7 @@ func TestGraphqlQuery(t *testing.T) {
|
|||||||
rootUser.Email = "mtest@mail.test"
|
rootUser.Email = "mtest@mail.test"
|
||||||
})
|
})
|
||||||
|
|
||||||
token, err := service.Token(ctx, createUser.Email, createUser.Password)
|
token, err := service.Token(ctx, console.AuthUser{Email: createUser.Email, Password: createUser.Password})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
sauth, err := service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token)))
|
sauth, err := service.Authorize(consoleauth.WithAPIKey(ctx, []byte(token)))
|
||||||
|
@ -223,6 +223,10 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, mail
|
|||||||
authRouter.Handle("/account/change-email", server.withAuth(http.HandlerFunc(authController.ChangeEmail))).Methods(http.MethodPost)
|
authRouter.Handle("/account/change-email", server.withAuth(http.HandlerFunc(authController.ChangeEmail))).Methods(http.MethodPost)
|
||||||
authRouter.Handle("/account/change-password", server.withAuth(http.HandlerFunc(authController.ChangePassword))).Methods(http.MethodPost)
|
authRouter.Handle("/account/change-password", server.withAuth(http.HandlerFunc(authController.ChangePassword))).Methods(http.MethodPost)
|
||||||
authRouter.Handle("/account/delete", server.withAuth(http.HandlerFunc(authController.DeleteAccount))).Methods(http.MethodPost)
|
authRouter.Handle("/account/delete", server.withAuth(http.HandlerFunc(authController.DeleteAccount))).Methods(http.MethodPost)
|
||||||
|
authRouter.Handle("/mfa/enable", server.withAuth(http.HandlerFunc(authController.EnableUserMFA))).Methods(http.MethodPost)
|
||||||
|
authRouter.Handle("/mfa/disable", server.withAuth(http.HandlerFunc(authController.DisableUserMFA))).Methods(http.MethodPost)
|
||||||
|
authRouter.Handle("/mfa/generate-secret-key", server.withAuth(http.HandlerFunc(authController.GenerateMFASecretKey))).Methods(http.MethodPost)
|
||||||
|
authRouter.Handle("/mfa/generate-recovery-codes", server.withAuth(http.HandlerFunc(authController.GenerateMFARecoveryCodes))).Methods(http.MethodPost)
|
||||||
authRouter.HandleFunc("/logout", authController.Logout).Methods(http.MethodPost)
|
authRouter.HandleFunc("/logout", authController.Logout).Methods(http.MethodPost)
|
||||||
authRouter.Handle("/token", server.rateLimiter.Limit(http.HandlerFunc(authController.Token))).Methods(http.MethodPost)
|
authRouter.Handle("/token", server.rateLimiter.Limit(http.HandlerFunc(authController.Token))).Methods(http.MethodPost)
|
||||||
authRouter.Handle("/register", server.rateLimiter.Limit(http.HandlerFunc(authController.Register))).Methods(http.MethodPost)
|
authRouter.Handle("/register", server.rateLimiter.Limit(http.HandlerFunc(authController.Register))).Methods(http.MethodPost)
|
||||||
|
201
satellite/console/mfa.go
Normal file
201
satellite/console/mfa.go
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
// Copyright (C) 2021 Storj Labs, Inc.
|
||||||
|
// See LICENSE for copying information.
|
||||||
|
|
||||||
|
package console
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"math/big"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pquerna/otp"
|
||||||
|
"github.com/pquerna/otp/totp"
|
||||||
|
"github.com/zeebo/errs"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MFARecoveryCodeCount specifies how many MFA recovery codes to generate.
|
||||||
|
MFARecoveryCodeCount = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error messages.
|
||||||
|
const (
|
||||||
|
mfaPasscodeInvalidErrMsg = "The MFA passcode is not valid or has expired"
|
||||||
|
mfaPasscodeRequiredErrMsg = "A MFA passcode or recovery code is required"
|
||||||
|
mfaRecoveryInvalidErrMsg = "The MFA recovery code is not valid or has been previously used"
|
||||||
|
mfaRecoveryGenerationErrMsg = "MFA recovery codes cannot be generated while MFA is disabled."
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrMFAPasscodeRequired is error type that occurs when a token request is incomplete
|
||||||
|
// due to missing MFA passcode and recovery code.
|
||||||
|
ErrMFAPasscodeRequired = errs.Class("MFA passcode required")
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewMFAValidationOpts returns the options used to validate TOTP passcodes.
|
||||||
|
// These settings are also used to generate MFA secret keys for use in testing.
|
||||||
|
func NewMFAValidationOpts() totp.ValidateOpts {
|
||||||
|
return totp.ValidateOpts{
|
||||||
|
Period: 30,
|
||||||
|
Skew: 1,
|
||||||
|
Digits: 6,
|
||||||
|
Algorithm: otp.AlgorithmSHA1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateMFAPasscode returns whether the TOTP passcode is valid for the secret key at the given time.
|
||||||
|
func ValidateMFAPasscode(passcode string, secretKey string, t time.Time) (bool, error) {
|
||||||
|
valid, err := totp.ValidateCustom(passcode, secretKey, t, NewMFAValidationOpts())
|
||||||
|
return valid, Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMFAPasscode derives a TOTP passcode from a secret key using a timestamp.
|
||||||
|
func NewMFAPasscode(secretKey string, t time.Time) (string, error) {
|
||||||
|
code, err := totp.GenerateCodeCustom(secretKey, t, NewMFAValidationOpts())
|
||||||
|
return code, Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMFASecretKey generates a new TOTP secret key.
|
||||||
|
func NewMFASecretKey() (string, error) {
|
||||||
|
opts := NewMFAValidationOpts()
|
||||||
|
key, err := totp.Generate(totp.GenerateOpts{
|
||||||
|
Issuer: " ",
|
||||||
|
AccountName: " ",
|
||||||
|
Period: opts.Period,
|
||||||
|
Digits: otp.DigitsSix,
|
||||||
|
Algorithm: opts.Algorithm,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", Error.Wrap(err)
|
||||||
|
}
|
||||||
|
return key.Secret(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableUserMFA enables multi-factor authentication for the user if the given secret key and password are valid.
|
||||||
|
func (s *Service) EnableUserMFA(ctx context.Context, passcode string, t time.Time) (err error) {
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
auth, err := s.getAuthAndAuditLog(ctx, "enable MFA")
|
||||||
|
if err != nil {
|
||||||
|
return Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := ValidateMFAPasscode(passcode, auth.User.MFASecretKey, t)
|
||||||
|
if err != nil {
|
||||||
|
return ErrValidation.Wrap(err)
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
return ErrValidation.New(mfaPasscodeInvalidErrMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.User.MFAEnabled = true
|
||||||
|
err = s.store.Users().Update(ctx, &auth.User)
|
||||||
|
if err != nil {
|
||||||
|
return Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableUserMFA disables multi-factor authentication for the user if the given secret key and password are valid.
|
||||||
|
func (s *Service) DisableUserMFA(ctx context.Context, passcode string, t time.Time) (err error) {
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
auth, err := s.getAuthAndAuditLog(ctx, "disable MFA")
|
||||||
|
if err != nil {
|
||||||
|
return Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := ValidateMFAPasscode(passcode, auth.User.MFASecretKey, t)
|
||||||
|
if err != nil {
|
||||||
|
return ErrValidation.Wrap(err)
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
return ErrValidation.New(mfaPasscodeInvalidErrMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.User.MFAEnabled = false
|
||||||
|
auth.User.MFASecretKey = ""
|
||||||
|
auth.User.MFARecoveryCodes = nil
|
||||||
|
err = s.store.Users().Update(ctx, &auth.User)
|
||||||
|
if err != nil {
|
||||||
|
return Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMFARecoveryCode returns a randomly generated MFA recovery code.
|
||||||
|
// Recovery codes are uppercase and alphanumeric. They are of the form XXXX-XXXX-XXXX.
|
||||||
|
func NewMFARecoveryCode() (string, error) {
|
||||||
|
const chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
b := make([]byte, 14)
|
||||||
|
max := big.NewInt(int64(len(chars)))
|
||||||
|
for i := 0; i < 14; i++ {
|
||||||
|
if (i+1)%5 == 0 {
|
||||||
|
b[i] = '-'
|
||||||
|
} else {
|
||||||
|
num, err := rand.Int(rand.Reader, max)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
b[i] = chars[num.Int64()]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetMFASecretKey creates a new TOTP secret key for the user.
|
||||||
|
func (s *Service) ResetMFASecretKey(ctx context.Context) (key string, err error) {
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
auth, err := s.getAuthAndAuditLog(ctx, "reset MFA secret key")
|
||||||
|
if err != nil {
|
||||||
|
return "", Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err = NewMFASecretKey()
|
||||||
|
if err != nil {
|
||||||
|
return "", Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.User.MFASecretKey = key
|
||||||
|
err = s.store.Users().Update(ctx, &auth.User)
|
||||||
|
if err != nil {
|
||||||
|
return "", Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetMFARecoveryCodes creates a new set of MFA recovery codes for the user.
|
||||||
|
func (s *Service) ResetMFARecoveryCodes(ctx context.Context) (codes []string, err error) {
|
||||||
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
auth, err := s.getAuthAndAuditLog(ctx, "reset MFA recovery codes")
|
||||||
|
if err != nil {
|
||||||
|
return nil, Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !auth.User.MFAEnabled {
|
||||||
|
return nil, ErrUnauthorized.New(mfaRecoveryGenerationErrMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
codes = make([]string, MFARecoveryCodeCount)
|
||||||
|
for i := 0; i < MFARecoveryCodeCount; i++ {
|
||||||
|
code, err := NewMFARecoveryCode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, Error.Wrap(err)
|
||||||
|
}
|
||||||
|
codes[i] = code
|
||||||
|
}
|
||||||
|
auth.User.MFARecoveryCodes = codes
|
||||||
|
|
||||||
|
err = s.store.Users().Update(ctx, &auth.User)
|
||||||
|
if err != nil {
|
||||||
|
return nil, Error.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return codes, nil
|
||||||
|
}
|
@ -798,19 +798,53 @@ func (s *Service) RevokeResetPasswordToken(ctx context.Context, resetPasswordTok
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Token authenticates User by credentials and returns auth token.
|
// Token authenticates User by credentials and returns auth token.
|
||||||
func (s *Service) Token(ctx context.Context, email, password string) (token string, err error) {
|
func (s *Service) Token(ctx context.Context, request AuthUser) (token string, err error) {
|
||||||
defer mon.Task()(&ctx)(&err)
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
user, err := s.store.Users().GetByEmail(ctx, email)
|
user, err := s.store.Users().GetByEmail(ctx, request.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ErrUnauthorized.New(credentialsErrMsg)
|
return "", ErrUnauthorized.New(credentialsErrMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = bcrypt.CompareHashAndPassword(user.PasswordHash, []byte(password))
|
err = bcrypt.CompareHashAndPassword(user.PasswordHash, []byte(request.Password))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ErrUnauthorized.New(credentialsErrMsg)
|
return "", ErrUnauthorized.New(credentialsErrMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.MFAEnabled {
|
||||||
|
if request.MFARecoveryCode != "" {
|
||||||
|
found := false
|
||||||
|
codeIndex := -1
|
||||||
|
for i, code := range user.MFARecoveryCodes {
|
||||||
|
if code == request.MFARecoveryCode {
|
||||||
|
found = true
|
||||||
|
codeIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return "", ErrUnauthorized.New(mfaRecoveryInvalidErrMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
user.MFARecoveryCodes = append(user.MFARecoveryCodes[:codeIndex], user.MFARecoveryCodes[codeIndex+1:]...)
|
||||||
|
|
||||||
|
err = s.store.Users().Update(ctx, user)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
} else if request.MFAPasscode != "" {
|
||||||
|
valid, err := ValidateMFAPasscode(request.MFAPasscode, user.MFASecretKey, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return "", ErrUnauthorized.Wrap(err)
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
return "", ErrUnauthorized.New(mfaPasscodeInvalidErrMsg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return "", ErrMFAPasscodeRequired.New(mfaPasscodeRequiredErrMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
claims := consoleauth.Claims{
|
claims := consoleauth.Claims{
|
||||||
ID: user.ID,
|
ID: user.ID,
|
||||||
Expiration: time.Now().Add(tokenExpirationTime),
|
Expiration: time.Now().Add(tokenExpirationTime),
|
||||||
|
@ -4,7 +4,9 @@
|
|||||||
package console_test
|
package console_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@ -278,3 +280,153 @@ func TestPaidTier(t *testing.T) {
|
|||||||
require.Equal(t, usageConfig.Bandwidth.Paid, *proj2.BandwidthLimit)
|
require.Equal(t, usageConfig.Bandwidth.Paid, *proj2.BandwidthLimit)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMFA(t *testing.T) {
|
||||||
|
testplanet.Run(t, testplanet.Config{
|
||||||
|
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
|
||||||
|
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||||
|
sat := planet.Satellites[0]
|
||||||
|
service := sat.API.Console.Service
|
||||||
|
|
||||||
|
user, err := sat.AddUser(ctx, console.CreateUser{
|
||||||
|
FullName: "MFA Test User",
|
||||||
|
Email: "mfauser@mail.test",
|
||||||
|
}, 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var auth console.Authorization
|
||||||
|
var authCtx context.Context
|
||||||
|
updateAuth := func() {
|
||||||
|
authCtx, err = sat.AuthenticatedContext(ctx, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
auth, err = console.GetAuth(authCtx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
updateAuth()
|
||||||
|
|
||||||
|
var key string
|
||||||
|
t.Run("TestResetMFASecretKey", func(t *testing.T) {
|
||||||
|
key, err = service.ResetMFASecretKey(authCtx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
require.NotEmpty(t, auth.User.MFASecretKey)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestEnableUserMFABadPasscode", func(t *testing.T) {
|
||||||
|
// Expect MFA-enabling attempt to be rejected when providing stale passcode.
|
||||||
|
badCode, err := console.NewMFAPasscode(key, time.Time{}.Add(time.Hour))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = service.EnableUserMFA(authCtx, badCode, time.Time{})
|
||||||
|
require.True(t, console.ErrValidation.Has(err))
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
_, err = service.ResetMFARecoveryCodes(authCtx)
|
||||||
|
require.True(t, console.ErrUnauthorized.Has(err))
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
require.False(t, auth.User.MFAEnabled)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestEnableUserMFAGoodPasscode", func(t *testing.T) {
|
||||||
|
// Expect MFA-enabling attempt to succeed when providing valid passcode.
|
||||||
|
goodCode, err := console.NewMFAPasscode(key, time.Time{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
err = service.EnableUserMFA(authCtx, goodCode, time.Time{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
require.True(t, auth.User.MFAEnabled)
|
||||||
|
require.Equal(t, auth.User.MFASecretKey, key)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestMFAGetToken", func(t *testing.T) {
|
||||||
|
request := console.AuthUser{Email: user.Email, Password: user.FullName}
|
||||||
|
|
||||||
|
// Expect no token due to lack of MFA passcode.
|
||||||
|
token, err := service.Token(ctx, request)
|
||||||
|
require.True(t, console.ErrMFAPasscodeRequired.Has(err))
|
||||||
|
require.Empty(t, token)
|
||||||
|
|
||||||
|
// Expect no token due to bad MFA passcode.
|
||||||
|
wrongCode, err := console.NewMFAPasscode(key, time.Now().Add(time.Minute))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
request.MFAPasscode = wrongCode
|
||||||
|
token, err = service.Token(ctx, request)
|
||||||
|
require.True(t, console.ErrUnauthorized.Has(err))
|
||||||
|
require.Empty(t, token)
|
||||||
|
|
||||||
|
// Expect token when providing valid passcode.
|
||||||
|
goodCode, err := console.NewMFAPasscode(key, time.Now())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
request.MFAPasscode = goodCode
|
||||||
|
token, err = service.Token(ctx, request)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, token)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestMFARecoveryCodes", func(t *testing.T) {
|
||||||
|
_, err = service.ResetMFARecoveryCodes(authCtx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
require.Len(t, auth.User.MFARecoveryCodes, console.MFARecoveryCodeCount)
|
||||||
|
|
||||||
|
for _, code := range auth.User.MFARecoveryCodes {
|
||||||
|
// Ensure code is of the form XXXX-XXXX-XXXX where X is A-Z or 0-9.
|
||||||
|
require.Regexp(t, "^([A-Z0-9]{4})((-[A-Z0-9]{4})){2}$", code)
|
||||||
|
|
||||||
|
// Expect token when providing valid recovery code.
|
||||||
|
request := console.AuthUser{Email: user.Email, Password: user.FullName, MFARecoveryCode: code}
|
||||||
|
token, err := service.Token(ctx, request)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, token)
|
||||||
|
|
||||||
|
// Expect no token due to providing previously-used recovery code.
|
||||||
|
token, err = service.Token(ctx, request)
|
||||||
|
require.True(t, console.ErrUnauthorized.Has(err))
|
||||||
|
require.Empty(t, token)
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = service.ResetMFARecoveryCodes(authCtx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestDisableUserMFABadPasscode", func(t *testing.T) {
|
||||||
|
// Expect MFA-disabling attempt to fail when providing valid passcode.
|
||||||
|
badCode, err := console.NewMFAPasscode(key, time.Time{}.Add(time.Hour))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
err = service.DisableUserMFA(authCtx, badCode, time.Time{})
|
||||||
|
require.True(t, console.ErrValidation.Has(err))
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
require.True(t, auth.User.MFAEnabled)
|
||||||
|
require.NotEmpty(t, auth.User.MFASecretKey)
|
||||||
|
require.NotEmpty(t, auth.User.MFARecoveryCodes)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestDisableUserMFAGoodPasscode", func(t *testing.T) {
|
||||||
|
// Expect MFA-disabling attempt to succeed when providing valid passcode.
|
||||||
|
goodCode, err := console.NewMFAPasscode(key, time.Time{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
err = service.DisableUserMFA(authCtx, goodCode, time.Time{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updateAuth()
|
||||||
|
require.False(t, auth.User.MFAEnabled)
|
||||||
|
require.Empty(t, auth.User.MFASecretKey)
|
||||||
|
require.Empty(t, auth.User.MFARecoveryCodes)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -87,6 +87,14 @@ func (user *CreateUser) IsValid() error {
|
|||||||
return errs.Combine()
|
return errs.Combine()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthUser holds info for user authentication token requests.
|
||||||
|
type AuthUser struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
MFAPasscode string `json:"mfaPasscode"`
|
||||||
|
MFARecoveryCode string `json:"mfaRecoveryCode"`
|
||||||
|
}
|
||||||
|
|
||||||
// UserStatus - is used to indicate status of the users account.
|
// UserStatus - is used to indicate status of the users account.
|
||||||
type UserStatus int
|
type UserStatus int
|
||||||
|
|
||||||
@ -125,4 +133,8 @@ type User struct {
|
|||||||
EmployeeCount string `json:"employeeCount"`
|
EmployeeCount string `json:"employeeCount"`
|
||||||
|
|
||||||
HaveSalesContact bool `json:"haveSalesContact"`
|
HaveSalesContact bool `json:"haveSalesContact"`
|
||||||
|
|
||||||
|
MFAEnabled bool `json:"mfaEnabled"`
|
||||||
|
MFASecretKey string `json:"mfaSecretKey"`
|
||||||
|
MFARecoveryCodes []string `json:"mfaRecoveryCodes"`
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,7 @@ const (
|
|||||||
employeeCount = "0"
|
employeeCount = "0"
|
||||||
workingOn = "workingOn"
|
workingOn = "workingOn"
|
||||||
isProfessional = true
|
isProfessional = true
|
||||||
|
mfaSecretKey = "mfaSecretKey"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUserRepository(t *testing.T) {
|
func TestUserRepository(t *testing.T) {
|
||||||
@ -178,6 +179,9 @@ func testUsers(ctx context.Context, t *testing.T, repository console.Users, user
|
|||||||
assert.Equal(t, lastName, userByEmail.ShortName)
|
assert.Equal(t, lastName, userByEmail.ShortName)
|
||||||
assert.Equal(t, user.PartnerID, userByEmail.PartnerID)
|
assert.Equal(t, user.PartnerID, userByEmail.PartnerID)
|
||||||
assert.False(t, user.PaidTier)
|
assert.False(t, user.PaidTier)
|
||||||
|
assert.False(t, user.MFAEnabled)
|
||||||
|
assert.Empty(t, user.MFASecretKey)
|
||||||
|
assert.Empty(t, user.MFARecoveryCodes)
|
||||||
if user.IsProfessional {
|
if user.IsProfessional {
|
||||||
assert.Equal(t, workingOn, userByEmail.WorkingOn)
|
assert.Equal(t, workingOn, userByEmail.WorkingOn)
|
||||||
assert.Equal(t, position, userByEmail.Position)
|
assert.Equal(t, position, userByEmail.Position)
|
||||||
@ -195,6 +199,9 @@ func testUsers(ctx context.Context, t *testing.T, repository console.Users, user
|
|||||||
assert.Equal(t, name, userByID.FullName)
|
assert.Equal(t, name, userByID.FullName)
|
||||||
assert.Equal(t, lastName, userByID.ShortName)
|
assert.Equal(t, lastName, userByID.ShortName)
|
||||||
assert.Equal(t, user.PartnerID, userByID.PartnerID)
|
assert.Equal(t, user.PartnerID, userByID.PartnerID)
|
||||||
|
assert.False(t, user.MFAEnabled)
|
||||||
|
assert.Empty(t, user.MFASecretKey)
|
||||||
|
assert.Empty(t, user.MFARecoveryCodes)
|
||||||
|
|
||||||
if user.IsProfessional {
|
if user.IsProfessional {
|
||||||
assert.Equal(t, workingOn, userByID.WorkingOn)
|
assert.Equal(t, workingOn, userByID.WorkingOn)
|
||||||
@ -226,20 +233,23 @@ func testUsers(ctx context.Context, t *testing.T, repository console.Users, user
|
|||||||
oldUser, err := repository.GetByEmail(ctx, email)
|
oldUser, err := repository.GetByEmail(ctx, email)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
newUser := &console.User{
|
newUserInfo := &console.User{
|
||||||
ID: oldUser.ID,
|
ID: oldUser.ID,
|
||||||
FullName: newName,
|
FullName: newName,
|
||||||
ShortName: newLastName,
|
ShortName: newLastName,
|
||||||
Email: newEmail,
|
Email: newEmail,
|
||||||
Status: console.Active,
|
Status: console.Active,
|
||||||
PaidTier: true,
|
PaidTier: true,
|
||||||
PasswordHash: []byte(newPass),
|
MFAEnabled: true,
|
||||||
|
MFASecretKey: mfaSecretKey,
|
||||||
|
MFARecoveryCodes: []string{"1", "2"},
|
||||||
|
PasswordHash: []byte(newPass),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = repository.Update(ctx, newUser)
|
err = repository.Update(ctx, newUserInfo)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
newUser, err = repository.Get(ctx, oldUser.ID)
|
newUser, err := repository.Get(ctx, oldUser.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, oldUser.ID, newUser.ID)
|
assert.Equal(t, oldUser.ID, newUser.ID)
|
||||||
assert.Equal(t, newName, newUser.FullName)
|
assert.Equal(t, newName, newUser.FullName)
|
||||||
@ -247,6 +257,9 @@ func testUsers(ctx context.Context, t *testing.T, repository console.Users, user
|
|||||||
assert.Equal(t, newEmail, newUser.Email)
|
assert.Equal(t, newEmail, newUser.Email)
|
||||||
assert.Equal(t, []byte(newPass), newUser.PasswordHash)
|
assert.Equal(t, []byte(newPass), newUser.PasswordHash)
|
||||||
assert.True(t, newUser.PaidTier)
|
assert.True(t, newUser.PaidTier)
|
||||||
|
assert.True(t, newUser.MFAEnabled)
|
||||||
|
assert.Equal(t, mfaSecretKey, newUser.MFASecretKey)
|
||||||
|
assert.Equal(t, newUserInfo.MFARecoveryCodes, newUser.MFARecoveryCodes)
|
||||||
// PartnerID should not change
|
// PartnerID should not change
|
||||||
assert.Equal(t, user.PartnerID, newUser.PartnerID)
|
assert.Equal(t, user.PartnerID, newUser.PartnerID)
|
||||||
assert.Equal(t, oldUser.CreatedAt, newUser.CreatedAt)
|
assert.Equal(t, oldUser.CreatedAt, newUser.CreatedAt)
|
||||||
|
@ -5,6 +5,7 @@ package satellitedb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/zeebo/errs"
|
"github.com/zeebo/errs"
|
||||||
@ -99,10 +100,15 @@ func (users *users) Delete(ctx context.Context, id uuid.UUID) (err error) {
|
|||||||
func (users *users) Update(ctx context.Context, user *console.User) (err error) {
|
func (users *users) Update(ctx context.Context, user *console.User) (err error) {
|
||||||
defer mon.Task()(&ctx)(&err)
|
defer mon.Task()(&ctx)(&err)
|
||||||
|
|
||||||
|
updateFields, err := toUpdateUser(user)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
_, err = users.db.Update_User_By_Id(
|
_, err = users.db.Update_User_By_Id(
|
||||||
ctx,
|
ctx,
|
||||||
dbx.User_Id(user.ID[:]),
|
dbx.User_Id(user.ID[:]),
|
||||||
toUpdateUser(user),
|
*updateFields,
|
||||||
)
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@ -135,7 +141,7 @@ func (users *users) GetProjectLimit(ctx context.Context, id uuid.UUID) (limit in
|
|||||||
}
|
}
|
||||||
|
|
||||||
// toUpdateUser creates dbx.User_Update_Fields with only non-empty fields as updatable.
|
// toUpdateUser creates dbx.User_Update_Fields with only non-empty fields as updatable.
|
||||||
func toUpdateUser(user *console.User) dbx.User_Update_Fields {
|
func toUpdateUser(user *console.User) (*dbx.User_Update_Fields, error) {
|
||||||
update := dbx.User_Update_Fields{
|
update := dbx.User_Update_Fields{
|
||||||
FullName: dbx.User_FullName(user.FullName),
|
FullName: dbx.User_FullName(user.FullName),
|
||||||
ShortName: dbx.User_ShortName(user.ShortName),
|
ShortName: dbx.User_ShortName(user.ShortName),
|
||||||
@ -144,14 +150,22 @@ func toUpdateUser(user *console.User) dbx.User_Update_Fields {
|
|||||||
Status: dbx.User_Status(int(user.Status)),
|
Status: dbx.User_Status(int(user.Status)),
|
||||||
ProjectLimit: dbx.User_ProjectLimit(user.ProjectLimit),
|
ProjectLimit: dbx.User_ProjectLimit(user.ProjectLimit),
|
||||||
PaidTier: dbx.User_PaidTier(user.PaidTier),
|
PaidTier: dbx.User_PaidTier(user.PaidTier),
|
||||||
|
MfaEnabled: dbx.User_MfaEnabled(user.MFAEnabled),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
recoveryBytes, err := json.Marshal(user.MFARecoveryCodes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
update.MfaRecoveryCodes = dbx.User_MfaRecoveryCodes(string(recoveryBytes))
|
||||||
|
update.MfaSecretKey = dbx.User_MfaSecretKey(user.MFASecretKey)
|
||||||
|
|
||||||
// extra password check to update only calculated hash from service
|
// extra password check to update only calculated hash from service
|
||||||
if len(user.PasswordHash) != 0 {
|
if len(user.PasswordHash) != 0 {
|
||||||
update.PasswordHash = dbx.User_PasswordHash(user.PasswordHash)
|
update.PasswordHash = dbx.User_PasswordHash(user.PasswordHash)
|
||||||
}
|
}
|
||||||
|
|
||||||
return update
|
return &update, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// userFromDBX is used for creating User entity from autogenerated dbx.User struct.
|
// userFromDBX is used for creating User entity from autogenerated dbx.User struct.
|
||||||
@ -166,6 +180,14 @@ func userFromDBX(ctx context.Context, user *dbx.User) (_ *console.User, err erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var recoveryCodes []string
|
||||||
|
if user.MfaRecoveryCodes != nil {
|
||||||
|
err = json.Unmarshal([]byte(*user.MfaRecoveryCodes), &recoveryCodes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
result := console.User{
|
result := console.User{
|
||||||
ID: id,
|
ID: id,
|
||||||
FullName: user.FullName,
|
FullName: user.FullName,
|
||||||
@ -177,6 +199,7 @@ func userFromDBX(ctx context.Context, user *dbx.User) (_ *console.User, err erro
|
|||||||
PaidTier: user.PaidTier,
|
PaidTier: user.PaidTier,
|
||||||
IsProfessional: user.IsProfessional,
|
IsProfessional: user.IsProfessional,
|
||||||
HaveSalesContact: user.HaveSalesContact,
|
HaveSalesContact: user.HaveSalesContact,
|
||||||
|
MFAEnabled: user.MfaEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.PartnerId != nil {
|
if user.PartnerId != nil {
|
||||||
@ -206,6 +229,14 @@ func userFromDBX(ctx context.Context, user *dbx.User) (_ *console.User, err erro
|
|||||||
result.EmployeeCount = *user.EmployeeCount
|
result.EmployeeCount = *user.EmployeeCount
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.MfaSecretKey != nil {
|
||||||
|
result.MFASecretKey = *user.MfaSecretKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.MfaRecoveryCodes != nil {
|
||||||
|
result.MFARecoveryCodes = recoveryCodes
|
||||||
|
}
|
||||||
|
|
||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user