Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve SAML error messages #8248

Merged
merged 12 commits into from
Sep 14, 2020
140 changes: 89 additions & 51 deletions synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
import saml2.response
from saml2.client import Saml2Client

from synapse.api.errors import AuthError, SynapseError
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.server import respond_with_html
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(self, hs: "synapse.server.HomeServer"):
hs.config.saml2_grandfathered_mxid_source_attribute
)
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
self._error_template = hs.config.sso_error_template

# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
Expand All @@ -84,6 +86,25 @@ def __init__(self, hs: "synapse.server.HomeServer"):
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)

def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Renders the error template and respond with it.
clokep marked this conversation as resolved.
Show resolved Hide resolved

This is used to show errors to the user. The template of this page can
be found under ``synapse/res/templates/sso_error.html``.
clokep marked this conversation as resolved.
Show resolved Hide resolved

Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)

def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
) -> bytes:
Expand Down Expand Up @@ -134,13 +155,68 @@ async def handle_saml_response(self, request: SynapseRequest) -> None:
# the dict.
self.expire_sessions()

try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
saml2.BINDING_HTTP_POST,
outstanding=self._outstanding_requests_dict,
)
except saml2.response.UnsolicitedResponse as e:
# the pysaml2 library helpfully logs an ERROR here, but neglects to log
# the session ID. I don't really want to put the full text of the exception
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
self._render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
self._render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
)
return

if saml2_auth.not_signed:
self._render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return

logger.debug("SAML2 response: %s", saml2_auth.origxml)
for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.
count = 0
for part in chunk_seq(str(assertion), 10000):
logger.info(
"SAML2 assertion: %s%s", "(%i)..." % (count,) if count else "", part
)
count += 1

logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)

current_session = self._outstanding_requests_dict.pop(
saml2_auth.in_response_to, None
)

for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement):
self._render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return

# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)

user_id, current_session = await self._map_saml_response_to_user(
# Call the mapper to register/login the user
user_id = await self._map_saml_response_to_user(
resp_bytes, relay_state, user_agent, ip_address
)

Expand All @@ -155,66 +231,28 @@ async def handle_saml_response(self, request: SynapseRequest) -> None:

async def _map_saml_response_to_user(
self,
resp_bytes: str,
saml2_auth: saml2.response.AuthnResponse,
client_redirect_url: str,
user_agent: str,
ip_address: str,
) -> Tuple[str, Optional[Saml2SessionData]]:
) -> str:
"""
Given a sample response, retrieve the cached session and user for it.
Given a SAML response, retrieve the user ID for it and possibly register the user.

Args:
resp_bytes: The SAML response.
saml2_auth: The parsed SAML2 response.
client_redirect_url: The redirect URL passed in by the client.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.

Returns:
Tuple of the user ID and SAML session associated with this response.
The user ID associated with this response.

Raises:
SynapseError if there was a problem with the response.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
saml2.BINDING_HTTP_POST,
outstanding=self._outstanding_requests_dict,
)
except saml2.response.UnsolicitedResponse as e:
# the pysaml2 library helpfully logs an ERROR here, but neglects to log
# the session ID. I don't really want to put the full text of the exception
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
raise SynapseError(400, "Unexpected SAML2 login.")
except Exception as e:
raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))

if saml2_auth.not_signed:
raise SynapseError(400, "SAML2 response was not signed.")

logger.debug("SAML2 response: %s", saml2_auth.origxml)
for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.
count = 0
for part in chunk_seq(str(assertion), 10000):
logger.info(
"SAML2 assertion: %s%s", "(%i)..." % (count,) if count else "", part
)
count += 1

logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)

current_session = self._outstanding_requests_dict.pop(
saml2_auth.in_response_to, None
)

for requirement in self._saml2_attribute_requirements:
_check_attribute_requirement(saml2_auth.ava, requirement)

remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
Expand All @@ -235,7 +273,7 @@ async def _map_saml_response_to_user(
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id, current_session
return registered_user_id

# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
Expand All @@ -260,7 +298,7 @@ async def _map_saml_response_to_user(
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id, current_session
return registered_user_id

# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
Expand Down Expand Up @@ -310,7 +348,7 @@ async def _map_saml_response_to_user(
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id, current_session
return registered_user_id

def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
Expand All @@ -323,19 +361,19 @@ def expire_sessions(self):
del self._outstanding_requests_dict[reqid]


def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
values = ava.get(req.attribute, [])
for v in values:
if v == req.value:
return
return True

logger.info(
"SAML2 attribute %s did not match required value '%s' (was '%s')",
req.attribute,
req.value,
values,
)
raise AuthError(403, "You are not authorized to log in here.")
return False


DOT_REPLACE_PATTERN = re.compile(
Expand Down