Skip to content

Commit

Permalink
Move tests to mockoidc_test package
Browse files Browse the repository at this point in the history
  • Loading branch information
Nick Meves committed Feb 5, 2021
1 parent 2ab0940 commit ba3bff8
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 216 deletions.
49 changes: 24 additions & 25 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@ import (
)

const (
IssuerBase = "/oidc"
AuthorizeEndpoint = "/oidc/authorize"
TokenEndpoint = "/oidc/token"
UserinfoEndpoint = "/oidc/userinfo"
JWKSEndpoint = "/oidc/.well-known/jwks.json"
DiscoveryEndpoint = "/oidc/.well-known/openid-configuration"

issuerBase = "/oidc"

invalidRequest = "invalid_request"
invalidClient = "invalid_client"
invalidGrant = "invalid_grant"
unsupportedGrantType = "unsupported_grant_type"
InvalidRequest = "invalid_request"
InvalidClient = "invalid_client"
InvalidGrant = "invalid_grant"
UnsupportedGrantType = "unsupported_grant_type"
//invalidScope = "invalid_scope"
//unauthorizedClient = "unauthorized_client"

Expand Down Expand Up @@ -85,12 +84,12 @@ func (m *MockOIDC) Authorize(rw http.ResponseWriter, req *http.Request) {
}

validClient := assertEqual("client_id", m.ClientID,
invalidClient, "Invalid client id", rw, req)
InvalidClient, "Invalid client id", rw, req)
if !validClient {
return
}
validType := assertEqual("response_type", "code",
unsupportedGrantType, "Invalid response type", rw, req)
UnsupportedGrantType, "Invalid response type", rw, req)
if !validType {
return
}
Expand Down Expand Up @@ -163,7 +162,7 @@ func (m *MockOIDC) Token(rw http.ResponseWriter, req *http.Request) {
}
session = s
default:
errorResponse(rw, invalidRequest, "Invalid grant_type", http.StatusBadRequest)
errorResponse(rw, InvalidRequest, "Invalid grant_type", http.StatusBadRequest)
return
}

Expand Down Expand Up @@ -196,11 +195,11 @@ func (m *MockOIDC) sharedParamsValidator(rw http.ResponseWriter, req *http.Reque
if !ok {
return false
}
equal := assertEqual("client_id", m.ClientID, invalidClient, "Invalid client id", rw, req)
equal := assertEqual("client_id", m.ClientID, InvalidClient, "Invalid client id", rw, req)
if !equal {
return false
}
equal = assertEqual("client_secret", m.ClientSecret, invalidClient, "Invalid client secret", rw, req)
equal = assertEqual("client_secret", m.ClientSecret, InvalidClient, "Invalid client secret", rw, req)
if !equal {
return false
}
Expand All @@ -214,7 +213,7 @@ func (m *MockOIDC) accessRequestValidator(rw http.ResponseWriter, req *http.Requ
code := req.Form.Get("code")
session, err := m.SessionStore.GetSessionByID(code)
if err != nil {
errorResponse(rw, invalidGrant, fmt.Sprintf("Invalid code: %s", code),
errorResponse(rw, InvalidGrant, fmt.Sprintf("Invalid code: %s", code),
http.StatusUnauthorized)
return nil, false
}
Expand All @@ -226,13 +225,13 @@ func (m *MockOIDC) refreshRequestValidator(rw http.ResponseWriter, req *http.Req
return nil, false
}
refreshToken := req.Form.Get("refresh_token")
token, ok := m.authorizeTokenString(refreshToken, rw, req)
token, ok := m.authorizeTokenString(refreshToken, rw)
if !ok {
return nil, false
}
session, err := m.SessionStore.GetSessionByToken(token)
if err != nil {
errorResponse(rw, invalidGrant, "Invalid refresh token",
errorResponse(rw, InvalidGrant, "Invalid refresh token",
http.StatusUnauthorized)
return nil, false
}
Expand All @@ -248,7 +247,7 @@ func (m *MockOIDC) validateAccessParams(rw http.ResponseWriter, req *http.Reques
}

equal := assertEqual("grant_type", "authorization_code",
unsupportedGrantType, "Invalid grant type", rw, req)
UnsupportedGrantType, "Invalid grant type", rw, req)
if !equal {
return false
}
Expand All @@ -264,7 +263,7 @@ func (m *MockOIDC) validateRefreshParams(rw http.ResponseWriter, req *http.Reque
}

equal := assertEqual("grant_type", "refresh_token",
unsupportedGrantType, "Invalid grant type", rw, req)
UnsupportedGrantType, "Invalid grant type", rw, req)
if !equal {
return false
}
Expand Down Expand Up @@ -371,23 +370,23 @@ func (m *MockOIDC) validateBearerToken(rw http.ResponseWriter, req *http.Request
authz := req.Header.Get("Authorization")
parts := strings.Split(authz, " ")
if len(parts) < 2 || parts[0] != "Bearer" {
errorResponse(rw, invalidRequest, "Invalid authorization header",
errorResponse(rw, InvalidRequest, "Invalid authorization header",
http.StatusUnauthorized)
return nil
}

token, ok := m.authorizeTokenString(parts[1], rw, req)
token, ok := m.authorizeTokenString(parts[1], rw)
if ok {
return token
}

return nil
}

func (m *MockOIDC) authorizeTokenString(tokenString string, rw http.ResponseWriter, req *http.Request) (*jwt.Token, bool) {
func (m *MockOIDC) authorizeTokenString(tokenString string, rw http.ResponseWriter) (*jwt.Token, bool) {
token, err := m.Keypair.VerifyJWT(tokenString)
if err != nil {
errorResponse(rw, invalidRequest, fmt.Sprintf("Invalid token: %v", err), http.StatusUnauthorized)
errorResponse(rw, InvalidRequest, fmt.Sprintf("Invalid token: %v", err), http.StatusUnauthorized)
return nil, false
}

Expand All @@ -396,13 +395,13 @@ func (m *MockOIDC) authorizeTokenString(tokenString string, rw http.ResponseWrit
internalServerError(rw, "Unable to extract token claims")
return nil, false
}
exp := claims["exp"].(float64)
if err != nil {
internalServerError(rw, err.Error())
exp, ok := claims["exp"].(float64)
if !ok {
internalServerError(rw, "Unable to extract token expiration")
return nil, false
}
if m.Now().Unix() > int64(exp) {
errorResponse(rw, invalidRequest, "The token is expired", http.StatusUnauthorized)
errorResponse(rw, InvalidRequest, "The token is expired", http.StatusUnauthorized)
return nil, false
}
return token, true
Expand All @@ -416,7 +415,7 @@ func assertPresence(params []string, rw http.ResponseWriter, req *http.Request)

errorResponse(
rw,
invalidRequest,
InvalidRequest,
fmt.Sprintf("The request is missing the required parameter: %s", param),
http.StatusBadRequest,
)
Expand Down
Loading

0 comments on commit ba3bff8

Please sign in to comment.