From 37fb198e5fdde91dc8b0406e07f5bb0395d6847a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 22 Oct 2020 14:17:50 -0400 Subject: [PATCH 1/8] Fix some types in the CAS code. --- changelog.d/8784.misc | 1 + synapse/handlers/cas_handler.py | 18 ++++++++++++++---- synapse/handlers/register.py | 2 +- synapse/handlers/saml_handler.py | 4 ++-- 4 files changed, 18 insertions(+), 7 deletions(-) create mode 100644 changelog.d/8784.misc diff --git a/changelog.d/8784.misc b/changelog.d/8784.misc new file mode 100644 index 000000000000..7e651155502e --- /dev/null +++ b/changelog.d/8784.misc @@ -0,0 +1 @@ +Fix type hints in CAS handler. diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 048a3b3c0bba..a16d9e311b86 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib -from typing import Dict, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple from xml.etree import ElementTree as ET from twisted.web.client import PartialDownloadError @@ -23,6 +23,9 @@ from synapse.http.site import SynapseRequest from synapse.types import UserID, map_username_to_mxid_localpart +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -31,10 +34,10 @@ class CasHandler: Utility class for to handle the response from a CAS SSO service. Args: - hs (synapse.server.HomeServer) + hs """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() @@ -205,11 +208,16 @@ async def handle_ticket( registered_user_id = await self._auth_handler.check_user_exists(user_id) if session: + # If there's a session then the user must already exist. + assert registered_user_id + await self._auth_handler.complete_sso_ui_auth( registered_user_id, session, request, ) - else: + # If this not a UI auth request than there must be a redirect URL. + assert client_redirect_url + if not registered_user_id: # Pull out the user-agent and IP from the request. user_agent = request.get_user_agent("") @@ -221,6 +229,8 @@ async def handle_ticket( user_agent_ips=(user_agent, ip_address), ) + assert registered_user_id + await self._auth_handler.complete_sso_login( registered_user_id, request, client_redirect_url ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 252f7007865b..cfa07eec9e38 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -152,7 +152,7 @@ async def register_user( bind_emails=[], by_admin=False, user_agent_ips=None, - ): + ) -> str: """Registers a new client on the server. Args: diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 5d9b555b131c..50561ed39957 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -39,7 +39,7 @@ from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: - import synapse.server + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class Saml2SessionData: class SamlHandler(BaseHandler): - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_idp_entityid = hs.config.saml2_idp_entityid From 28ad9bf62eb32458cb7c02099f8b6af67b9d2388 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 19 Nov 2020 15:05:44 -0500 Subject: [PATCH 2/8] Add additional types to registration handler. --- mypy.ini | 1 + synapse/handlers/cas_handler.py | 2 +- synapse/handlers/oidc_handler.py | 2 +- synapse/handlers/register.py | 207 +++++++++++++++++-------------- synapse/handlers/saml_handler.py | 2 +- 5 files changed, 117 insertions(+), 97 deletions(-) diff --git a/mypy.ini b/mypy.ini index fc9f8d805069..0cf7c93f457e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -37,6 +37,7 @@ files = synapse/handlers/presence.py, synapse/handlers/profile.py, synapse/handlers/read_marker.py, + synapse/handlers/register.py, synapse/handlers/room.py, synapse/handlers/room_member.py, synapse/handlers/room_member_worker.py, diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index a16d9e311b86..2b05c702880c 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -226,7 +226,7 @@ async def handle_ticket( registered_user_id = await self._registration_handler.register_user( localpart=localpart, default_display_name=user_display_name, - user_agent_ips=(user_agent, ip_address), + user_agent_ips=[(user_agent, ip_address)], ) assert registered_user_id diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 4bfd8d5617be..34de9109ea9a 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -925,7 +925,7 @@ async def _map_userinfo_to_user( registered_user_id = await self._registration_handler.register_user( localpart=localpart, default_display_name=attributes["display_name"], - user_agent_ips=(user_agent, ip_address), + user_agent_ips=[(user_agent, ip_address)], ) await self.store.record_user_external_id( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cfa07eec9e38..aaf64a973772 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,10 +15,12 @@ """Contains functions for registering clients.""" import logging +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet @@ -32,16 +34,14 @@ from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class RegistrationHandler(BaseHandler): - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.auth = hs.get_auth() @@ -71,7 +71,10 @@ def __init__(self, hs): self.session_lifetime = hs.config.session_lifetime async def check_username( - self, localpart, guest_access_token=None, assigned_user_id=None + self, + localpart: str, + guest_access_token: Optional[str] = None, + assigned_user_id: Optional[str] = None, ): if types.contains_invalid_mxid_characters(localpart): raise SynapseError( @@ -140,39 +143,45 @@ async def check_username( async def register_user( self, - localpart=None, - password_hash=None, - guest_access_token=None, - make_guest=False, - admin=False, - threepid=None, - user_type=None, - default_display_name=None, - address=None, - bind_emails=[], - by_admin=False, - user_agent_ips=None, + localpart: Optional[str] = None, + password_hash: Optional[str] = None, + guest_access_token: Optional[str] = None, + make_guest: bool = False, + admin: bool = False, + threepid: Optional[dict] = None, + user_type: Optional[str] = None, + default_display_name: Optional[str] = None, + address: Optional[str] = None, + bind_emails: List[str] = [], + by_admin: bool = False, + user_agent_ips: Optional[List[Tuple[str, str]]] = None, ) -> str: """Registers a new client on the server. Args: localpart: The local part of the user ID to register. If None, one will be generated. - password_hash (str|None): The hashed password to assign to this user so they can + password_hash: The hashed password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). - user_type (str|None): type of user. One of the values from + guest_access_token: The access token used when this was a guest + account. + make_guest: True if the the new user should be guest, + false to add a regular user account. + admin: True if the user should be registered as a server admin. + threepid: The threepid used for registering, if any. + user_type: type of user. One of the values from api.constants.UserTypes, or None for a normal user. - default_display_name (unicode|None): if set, the new user's displayname + default_display_name: if set, the new user's displayname will be set to this. Defaults to 'localpart'. - address (str|None): the IP address used to perform the registration. - bind_emails (List[str]): list of emails to bind to this account. - by_admin (bool): True if this registration is being made via the + address: the IP address used to perform the registration. + bind_emails: list of emails to bind to this account. + by_admin: True if this registration is being made via the admin api, otherwise False. - user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + user_agent_ips: Tuples of IP addresses and user-agents used during the registration process. Returns: - str: user_id + The registere user_id. Raises: SynapseError if there was a problem registering. """ @@ -236,8 +245,8 @@ async def register_user( else: # autogen a sequential user ID fail_count = 0 - user = None - while not user: + # This breaks on successful registration *or* errors after 10 failures. + while True: # Fail after being unable to find a suitable ID a few times if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") @@ -246,6 +255,8 @@ async def register_user( user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) + # TODO This looks incorrect since default_display_name can only + # ever be None for the first iteration. if default_display_name is None: default_display_name = localpart try: @@ -262,8 +273,6 @@ async def register_user( break except SynapseError: # if user id is taken, just generate another - user = None - user_id = None fail_count += 1 if not self.hs.config.user_consent_at_registration: @@ -295,7 +304,7 @@ async def register_user( return user_id - async def _create_and_join_rooms(self, user_id: str): + async def _create_and_join_rooms(self, user_id: str) -> None: """ Create the auto-join rooms and join or invite the user to them. @@ -379,7 +388,7 @@ async def _create_and_join_rooms(self, user_id: str): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - async def _join_rooms(self, user_id: str): + async def _join_rooms(self, user_id: str) -> None: """ Join or invite the user to the auto-join rooms. @@ -425,6 +434,9 @@ async def _join_rooms(self, user_id: str): # Send the invite, if necessary. if requires_invite: + # If an invite is required, there must be a auto-join user ID. + assert self.hs.config.registration.auto_join_user_id + await room_member_handler.update_membership( requester=create_requester( self.hs.config.registration.auto_join_user_id, @@ -456,7 +468,7 @@ async def _join_rooms(self, user_id: str): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - async def _auto_join_rooms(self, user_id: str): + async def _auto_join_rooms(self, user_id: str) -> None: """Automatically joins users to auto join rooms - creating the room in the first place if the user is the first to be created. @@ -479,16 +491,16 @@ async def _auto_join_rooms(self, user_id: str): else: await self._join_rooms(user_id) - async def post_consent_actions(self, user_id): + async def post_consent_actions(self, user_id: str) -> None: """A series of registration actions that can only be carried out once consent has been granted Args: - user_id (str): The user to join + user_id: The user to join """ await self._auto_join_rooms(user_id) - async def appservice_register(self, user_localpart, as_token): + async def appservice_register(self, user_localpart: str, as_token: str) -> str: user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -513,7 +525,9 @@ async def appservice_register(self, user_localpart, as_token): ) return user_id - def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): + def check_user_id_not_appservice_exclusive( + self, user_id: str, allowed_appservice: Optional[ApplicationService] = None + ) -> None: # don't allow people to register the server notices mxid if self._server_notices_mxid is not None: if user_id == self._server_notices_mxid: @@ -537,12 +551,12 @@ def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=Non errcode=Codes.EXCLUSIVE, ) - def check_registration_ratelimit(self, address): + def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address Args: - address (str|None): the IP address used to perform the registration. If this is + address: the IP address used to perform the registration. If this is None, no ratelimiting will be performed. Raises: @@ -553,42 +567,42 @@ def check_registration_ratelimit(self, address): self.ratelimiter.ratelimit(address) - def register_with_store( + async def register_with_store( self, - user_id, - password_hash=None, - was_guest=False, - make_guest=False, - appservice_id=None, - create_profile_with_displayname=None, - admin=False, - user_type=None, - address=None, - shadow_banned=False, - ): + user_id: str, + password_hash: Optional[str] = None, + was_guest: bool = False, + make_guest: bool = False, + appservice_id: Optional[str] = None, + create_profile_with_displayname: Optional[str] = None, + admin: bool = False, + user_type: Optional[str] = None, + address: Optional[str] = None, + shadow_banned: bool = False, + ) -> None: """Register user in the datastore. Args: - user_id (str): The desired user ID to register. - password_hash (str|None): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being + user_id: The desired user ID to register. + password_hash: Optional. The password hash for this user. + was_guest: Optional. Whether this is a guest account being upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, + make_guest: True if the the new user should be guest, false to add a regular user account. - appservice_id (str|None): The ID of the appservice registering the user. - create_profile_with_displayname (unicode|None): Optionally create a + appservice_id: The ID of the appservice registering the user. + create_profile_with_displayname: Optionally create a profile for the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from + admin: is an admin user? + user_type: type of user. One of the values from api.constants.UserTypes, or None for a normal user. - address (str|None): the IP address used to perform the registration. - shadow_banned (bool): Whether to shadow-ban the user + address: the IP address used to perform the registration. + shadow_banned: Whether to shadow-ban the user Returns: Awaitable """ if self.hs.config.worker_app: - return self._register_client( + await self._register_client( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -601,7 +615,7 @@ def register_with_store( shadow_banned=shadow_banned, ) else: - return self.store.register_user( + await self.store.register_user( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -614,22 +628,24 @@ def register_with_store( ) async def register_device( - self, user_id, device_id, initial_display_name, is_guest=False - ): + self, + user_id: str, + device_id: Optional[str], + initial_display_name: Optional[str], + is_guest: bool = False, + ) -> Tuple[str, str]: """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. Args: - user_id (str): full canonical @user:id - device_id (str|None): The device ID to check, or None to generate - a new one. - initial_display_name (str|None): An optional display name for the - device. - is_guest (bool): Whether this is a guest account + user_id: full canonical @user:id + device_id: The device ID to check, or None to generate a new one. + initial_display_name: An optional display name for the device. + is_guest: Whether this is a guest account Returns: - tuple[str, str]: Tuple of device ID and access token + Tuple of device ID and access token """ if self.hs.config.worker_app: @@ -649,7 +665,7 @@ async def register_device( ) valid_until_ms = self.clock.time_msec() + self.session_lifetime - device_id = await self.device_handler.check_device_registered( + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) if is_guest: @@ -659,20 +675,21 @@ async def register_device( ) else: access_token = await self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms + user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms ) - return (device_id, access_token) + return (registered_device_id, access_token) - async def post_registration_actions(self, user_id, auth_result, access_token): + async def post_registration_actions( + self, user_id: str, auth_result: dict, access_token: Optional[str] + ) -> None: """A user has completed registration Args: - user_id (str): The user ID that consented - auth_result (dict): The authenticated credentials of the newly - registered user. - access_token (str|None): The access token of the newly logged in - device, or None if `inhibit_login` enabled. + user_id: The user ID that consented + auth_result: The authenticated credentials of the newly registered user. + access_token: The access token of the newly logged in device, or + None if `inhibit_login` enabled. """ if self.hs.config.worker_app: await self._post_registration_client( @@ -698,19 +715,20 @@ async def post_registration_actions(self, user_id, auth_result, access_token): if auth_result and LoginType.TERMS in auth_result: await self._on_user_consented(user_id, self.hs.config.user_consent_version) - async def _on_user_consented(self, user_id, consent_version): + async def _on_user_consented(self, user_id: str, consent_version: str) -> None: """A user consented to the terms on registration Args: - user_id (str): The user ID that consented. - consent_version (str): version of the policy the user has - consented to. + user_id: The user ID that consented. + consent_version: version of the policy the user has consented to. """ logger.info("%s has consented to the privacy policy", user_id) await self.store.user_set_consent_version(user_id, consent_version) await self.post_consent_actions(user_id) - async def _register_email_threepid(self, user_id, threepid, token): + async def _register_email_threepid( + self, user_id: str, threepid: dict, token: Optional[str] + ) -> None: """Add an email address as a 3pid identifier Also adds an email pusher for the email address, if configured in the @@ -719,10 +737,9 @@ async def _register_email_threepid(self, user_id, threepid, token): Must be called on master. Args: - user_id (str): id of user - threepid (object): m.login.email.identity auth response - token (str|None): access_token for the user, or None if not logged - in. + user_id: id of user + threepid: m.login.email.identity auth response + token: access_token for the user, or None if not logged in. """ reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): @@ -748,6 +765,8 @@ async def _register_email_threepid(self, user_id, threepid, token): # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. user_tuple = await self.store.get_user_by_access_token(token) + # The token better still exist. + assert user_tuple token_id = user_tuple.token_id await self.pusher_pool.add_pusher( @@ -762,14 +781,14 @@ async def _register_email_threepid(self, user_id, threepid, token): data={}, ) - async def _register_msisdn_threepid(self, user_id, threepid): + async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None: """Add a phone number as a 3pid identifier Must be called on master. Args: - user_id (str): id of user - threepid (object): m.login.msisdn auth response + user_id: id of user + threepid: m.login.msisdn auth response """ try: assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 50561ed39957..aab22b4471b3 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -329,7 +329,7 @@ async def _map_saml_response_to_user( localpart=localpart, default_display_name=displayname, bind_emails=emails, - user_agent_ips=(user_agent, ip_address), + user_agent_ips=[(user_agent, ip_address)], ) await self.store.record_user_external_id( From 2e83da8ba68b0a70718ea5fcb656f8ced26f1e56 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 19 Nov 2020 15:07:58 -0500 Subject: [PATCH 3/8] Fix a bug with generating display names. --- synapse/handlers/register.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index aaf64a973772..5aac7d2a5031 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -245,6 +245,8 @@ async def register_user( else: # autogen a sequential user ID fail_count = 0 + # If a default display name is not given, generate one. + generate_display_name = default_display_name is None # This breaks on successful registration *or* errors after 10 failures. while True: # Fail after being unable to find a suitable ID a few times @@ -255,9 +257,7 @@ async def register_user( user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) - # TODO This looks incorrect since default_display_name can only - # ever be None for the first iteration. - if default_display_name is None: + if generate_display_name: default_display_name = localpart try: await self.register_with_store( From 7ade91483eabe92373caef300add2232ac69356b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 19 Nov 2020 15:09:10 -0500 Subject: [PATCH 4/8] Remove unnecessary assertion. --- synapse/handlers/cas_handler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 2b05c702880c..85304c4f4346 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -229,8 +229,6 @@ async def handle_ticket( user_agent_ips=[(user_agent, ip_address)], ) - assert registered_user_id - await self._auth_handler.complete_sso_login( registered_user_id, request, client_redirect_url ) From e62f4dcb78a9cfb24130450f18b871f9674df904 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 19 Nov 2020 15:09:36 -0500 Subject: [PATCH 5/8] Update newsfragment. --- changelog.d/8784.misc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/8784.misc b/changelog.d/8784.misc index 7e651155502e..a6715532fcec 100644 --- a/changelog.d/8784.misc +++ b/changelog.d/8784.misc @@ -1 +1 @@ -Fix type hints in CAS handler. +Add additional type hints in the registration and SSO handlers. From 2819f837c96c75960c4e5a21bbd27f93c50b1525 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 19 Nov 2020 15:14:16 -0500 Subject: [PATCH 6/8] Update comment. --- synapse/handlers/register.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 5aac7d2a5031..0d85fd08684f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -597,9 +597,6 @@ async def register_with_store( api.constants.UserTypes, or None for a normal user. address: the IP address used to perform the registration. shadow_banned: Whether to shadow-ban the user - - Returns: - Awaitable """ if self.hs.config.worker_app: await self._register_client( From 3b9ce2a608113d4040c2c7a358352694d43b87a1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 20 Nov 2020 13:03:57 -0500 Subject: [PATCH 7/8] Re-structure CAS code to be more similar to SAML/OIDC. --- synapse/handlers/cas_handler.py | 61 +++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 85304c4f4346..f4ea0a976749 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -203,32 +203,57 @@ async def handle_ticket( args["session"] = session username, user_display_name = await self._validate_ticket(ticket, args) - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) + # Pull out the user-agent and IP from the request. + user_agent = request.get_user_agent("") + ip_address = self.hs.get_ip_from_request(request) - if session: - # If there's a session then the user must already exist. - assert registered_user_id + # Get the matrix ID from the CAS username. + user_id = await self._map_cas_user_to_matrix_user( + username, user_display_name, user_agent, ip_address + ) + if session: await self._auth_handler.complete_sso_ui_auth( - registered_user_id, session, request, + user_id, session, request, ) else: # If this not a UI auth request than there must be a redirect URL. assert client_redirect_url - if not registered_user_id: - # Pull out the user-agent and IP from the request. - user_agent = request.get_user_agent("") - ip_address = self.hs.get_ip_from_request(request) + await self._auth_handler.complete_sso_login( + user_id, request, client_redirect_url + ) - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=user_display_name, - user_agent_ips=[(user_agent, ip_address)], - ) + async def _map_cas_user_to_matrix_user( + self, + remote_user_id: str, + display_name: Optional[str], + user_agent: str, + ip_address: str, + ) -> str: + """ + Given a CAS username, retrieve the user ID for it and possibly register the user. - await self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url + Args: + remote_user_id: The username from the CAS response. + display_name: The display name from the CAS response. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. + + Returns: + The user ID associated with this response. + """ + + localpart = map_username_to_mxid_localpart(remote_user_id) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) + + # If the user does not exist, register it. + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=display_name, + user_agent_ips=[(user_agent, ip_address)], ) + + return registered_user_id From 781e280d0a3ebd0d30b8edb28312e30f9d1000b7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 20 Nov 2020 13:10:31 -0500 Subject: [PATCH 8/8] Update the changelog. --- changelog.d/8784.misc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/8784.misc b/changelog.d/8784.misc index a6715532fcec..18a4263398bd 100644 --- a/changelog.d/8784.misc +++ b/changelog.d/8784.misc @@ -1 +1 @@ -Add additional type hints in the registration and SSO handlers. +Fix a bug introduced in v1.20.0 where the user-agent and IP address reported during user registration for CAS, OpenID Connect, and SAML were of the wrong form.