Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: refactor mfa tests #1322

Merged
merged 5 commits into from
Nov 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 88 additions & 80 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/gofrs/uuid"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -25,11 +26,14 @@ import (

type MFATestSuite struct {
suite.Suite
API *API
Config *conf.GlobalConfiguration
TestDomain string
TestEmail string
TestOTPKey *otp.Key
API *API
Config *conf.GlobalConfiguration
TestDomain string
TestEmail string
TestOTPKey *otp.Key
TestPassword string
TestUser *models.User
TestSession *models.Session
}

func TestMFA(t *testing.T) {
Expand All @@ -53,20 +57,24 @@ func (ts *MFATestSuite) SetupTest() {
f, err := models.NewFactor(u, "test_factor", models.TOTP, models.FactorStateUnverified, "secretkey")
require.NoError(ts.T(), err, "Error creating test factor model")
require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor")
// Create corresponding sessoin
// Create corresponding session
s, err := models.NewSession()
require.NoError(ts.T(), err, "Error creating test session")
s.UserID = u.ID
s.FactorID = &f.ID
require.NoError(ts.T(), ts.API.db.Create(s), "Error saving test session")

ts.TestUser = u
ts.TestSession = s

// Generate TOTP related settings
emailValue, err := u.Email.Value()
require.NoError(ts.T(), err)
testEmail := emailValue.(string)
testDomain := strings.Split(testEmail, "@")[1]
ts.TestDomain = testDomain
ts.TestEmail = testEmail
ts.TestPassword = "password"

key, err := totp.Generate(totp.GenerateOpts{
Issuer: ts.TestDomain,
Expand All @@ -80,6 +88,12 @@ func (ts *MFATestSuite) SetupTest() {
func (ts *MFATestSuite) TestEnrollFactor() {
testFriendlyName := "bob"
alternativeFriendlyName := "john"
user, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
ts.Require().NoError(err)

token, _, err := generateAccessToken(ts.API.db, user, nil, &ts.Config.JWT)

require.NoError(ts.T(), err)
var cases = []struct {
desc string
friendlyName string
Expand Down Expand Up @@ -119,20 +133,8 @@ func (ts *MFATestSuite) TestEnrollFactor() {
}
for _, c := range cases {
ts.Run(c.desc, func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": c.friendlyName, "factor_type": c.factorType, "issuer": c.issuer}))
J0 marked this conversation as resolved.
Show resolved Hide resolved
user, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
ts.Require().NoError(err)

token, _, err := generateAccessToken(ts.API.db, user, nil, &ts.Config.JWT)
require.NoError(ts.T(), err)

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/factors", &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), c.expectedCode, w.Code)
w := performEnrollFlow(ts, token, c.friendlyName, c.factorType, c.issuer, c.expectedCode)

factors, err := models.FindFactorsByUser(ts.API.db, user)
ts.Require().NoError(err)
Expand Down Expand Up @@ -165,17 +167,13 @@ func (ts *MFATestSuite) TestChallengeFactor() {
token, _, err := generateAccessToken(ts.API.db, u, nil, &ts.Config.JWT)
require.NoError(ts.T(), err, "Error generating access token")

var buffer bytes.Buffer
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", f.ID), &buffer)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
w := performChallengeFlow(ts, f.ID, token)
require.Equal(ts.T(), http.StatusOK, w.Code)
}

func (ts *MFATestSuite) TestMFAVerifyFactor() {
user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.TestEmail, ts.Config.JWT.Aud)
J0 marked this conversation as resolved.
Show resolved Hide resolved
ts.Require().NoError(err)
cases := []struct {
desc string
validChallenge bool
Expand Down Expand Up @@ -204,8 +202,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
for _, v := range cases {
ts.Run(v.desc, func() {
// Authenticate users and set secret
user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.TestEmail, ts.Config.JWT.Aud)
ts.Require().NoError(err)

var buffer bytes.Buffer
r, err := models.GrantAuthenticatedUser(ts.API.db, user, models.GrantParams{})
require.NoError(ts.T(), err)
Expand Down Expand Up @@ -272,6 +269,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
}

func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {

cases := []struct {
desc string
isAAL2 bool
Expand All @@ -289,25 +287,20 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
},
}
for _, v := range cases {

ts.Run(v.desc, func() {
// Create User
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
s, err := models.FindSessionByUserID(ts.API.db, u.ID)
require.NoError(ts.T(), err)
if v.isAAL2 {
s.UpdateAssociatedAAL(ts.API.db, models.AAL2.String())
ts.TestSession.UpdateAssociatedAAL(ts.API.db, models.AAL2.String())
}
var secondarySession *models.Session

// Create Session to test behaviour which downgrades other sessions
factors, err := models.FindFactorsByUser(ts.API.db, u)
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err, "error finding factors")
f := factors[0]
secondarySession, err = models.NewSession()
require.NoError(ts.T(), err, "Error creating test session")
secondarySession.UserID = u.ID
secondarySession.UserID = ts.TestUser.ID
secondarySession.FactorID = &f.ID
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

Expand All @@ -319,7 +312,7 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {

var buffer bytes.Buffer

token, _, err := generateAccessToken(ts.API.db, u, &s.ID, &ts.Config.JWT)
token, _, err := generateAccessToken(ts.API.db, ts.TestUser, &ts.TestSession.ID, &ts.Config.JWT)
require.NoError(ts.T(), err)

w := httptest.NewRecorder()
Expand All @@ -342,17 +335,13 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
}

func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
s, err := models.FindSessionByUserID(ts.API.db, u.ID)
require.NoError(ts.T(), err)
var secondarySession *models.Session
factors, err := models.FindFactorsByUser(ts.API.db, u)
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err, "error finding factors")
f := factors[0]
secondarySession, err = models.NewSession()
require.NoError(ts.T(), err, "Error creating test session")
secondarySession.UserID = u.ID
secondarySession.UserID = ts.TestUser.ID
secondarySession.FactorID = &f.ID
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

Expand All @@ -361,7 +350,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {

var buffer bytes.Buffer

token, _, err := generateAccessToken(ts.API.db, u, &s.ID, &ts.Config.JWT)
token, _, err := generateAccessToken(ts.API.db, ts.TestUser, &ts.TestSession.ID, &ts.Config.JWT)
require.NoError(ts.T(), err)
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"factor_id": f.ID,
Expand All @@ -382,14 +371,16 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {

// Integration Tests
func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {
email := "test1@example.com"
password := "test123"
token := signUpAndVerify(ts, email, password)
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

ts.Config.Security.RefreshTokenRotationEnabled = true
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": token.RefreshToken,
"refresh_token": accessTokenResp.RefreshToken,
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
Expand All @@ -408,14 +399,15 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {

// Performing MFA Verification followed by a sign in should return an AAL1 session and an AAL2 session
func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
email := "test1@example.com"
password := "test123"
token := signUpAndVerify(ts, email, password)
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

ts.Config.Security.RefreshTokenRotationEnabled = true
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": email,
"password": password,
"email": ts.TestEmail,
"password": ts.TestPassword,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
req.Header.Set("Content-Type", "application/json")
Expand All @@ -430,7 +422,7 @@ func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
ctx, err = ts.API.maybeLoadUserOrSession(ctx)
require.NoError(ts.T(), err)
require.Equal(ts.T(), models.AAL1.String(), getSession(ctx).GetAAL())
session, err := models.FindSessionByUserID(ts.API.db, token.User.ID)
session, err := models.FindSessionByUserID(ts.API.db, accessTokenResp.User.ID)
require.NoError(ts.T(), err)
require.True(ts.T(), session.IsAAL2())
}
Expand All @@ -455,43 +447,30 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes
return data
}

func signUpAndVerify(ts *MFATestSuite, email, password string) (verifyResp *AccessTokenResponse) {
func performTestSignupAndVerify(ts *MFATestSuite, email, password string) *httptest.ResponseRecorder {

signUpResp := signUp(ts, email, password)
verifyResp = enrollAndVerify(ts, signUpResp.User, signUpResp.Token)
resp := performEnrollAndVerify(ts, signUpResp.User, signUpResp.Token)

return verifyResp
return resp

}

func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyResp *AccessTokenResponse) {
func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer string, expectedCode int) *httptest.ResponseRecorder {
var buffer bytes.Buffer
w := httptest.NewRecorder()
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": "john", "factor_type": models.TOTP, "issuer": ts.TestDomain}))
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": friendlyName, "factor_type": factorType, "issuer": issuer}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/factors/", &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")

ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
enrollResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp))
factorID := enrollResp.ID

// Challenge
var challengeBuffer bytes.Buffer
x := httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), &challengeBuffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")
ts.API.handler.ServeHTTP(x, req)
require.Equal(ts.T(), http.StatusOK, x.Code)
challengeResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(x.Body).Decode(&challengeResp))
challengeID := challengeResp.ID
require.Equal(ts.T(), expectedCode, w.Code)
return w

// Verify
}
func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, expectedCode int) *httptest.ResponseRecorder {
var verifyBuffer bytes.Buffer
y := httptest.NewRecorder()

Expand All @@ -511,13 +490,42 @@ func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyR
"challenge_id": challengeID,
"code": code,
}))
req = httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), &verifyBuffer)
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), &verifyBuffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")

ts.API.handler.ServeHTTP(y, req)
require.Equal(ts.T(), http.StatusOK, y.Code)
verifyResp = &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(y.Body).Decode(&verifyResp))
return verifyResp
return y
}

func performChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token string) *httptest.ResponseRecorder {
var challengeBuffer bytes.Buffer
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), &challengeBuffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
return w

}

func performEnrollAndVerify(ts *MFATestSuite, user *models.User, token string) *httptest.ResponseRecorder {
w := performEnrollFlow(ts, token, "", models.TOTP, ts.TestDomain, http.StatusOK)
enrollResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp))
factorID := enrollResp.ID

// Challenge
w = performChallengeFlow(ts, factorID, token)

challengeResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&challengeResp))
challengeID := challengeResp.ID

// Verify
y := performVerifyFlow(ts, challengeID, factorID, token, http.StatusOK)

return y
}
Loading