Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Support trying multiple localparts for OpenID Connect. (#8801)
Browse files Browse the repository at this point in the history
Abstracts the SAML and OpenID Connect code which attempts to regenerate
the localpart of a matrix ID if it is already in use.
  • Loading branch information
clokep authored Nov 25, 2020
1 parent f38676d commit 4fd222a
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 136 deletions.
1 change: 1 addition & 0 deletions changelog.d/8801.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for re-trying generation of a localpart for OpenID Connect mapping providers.
11 changes: 10 additions & 1 deletion docs/sso_mapping_providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,22 @@ A custom mapping provider must specify the following methods:
information from.
- This method must return a string, which is the unique identifier for the
user. Commonly the ``sub`` claim of the response.
* `map_user_attributes(self, userinfo, token)`
* `map_user_attributes(self, userinfo, token, failures)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- `token` - A dictionary which includes information necessary to make
further requests to the OpenID provider.
- `failures` - An `int` that represents the amount of times the returned
mxid localpart mapping has failed. This should be used
to create a deduplicated mxid localpart which should be
returned instead. For example, if this method returns
`john.doe` as the value of `localpart` in the returned
dict, and that is already taken on the homeserver, this
method will be called again with the same parameters but
with failures=1. The method should then return a different
`localpart` value, such as `john.doe1`.
- Returns a dictionary with two keys:
- localpart: A required string, used to generate the Matrix ID.
- displayname: An optional string, the display name for the user.
Expand Down
120 changes: 50 additions & 70 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
Expand All @@ -35,15 +36,10 @@

from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import (
JsonDict,
UserID,
contains_invalid_mxid_characters,
map_username_to_mxid_localpart,
)
from synapse.types import JsonDict, map_username_to_mxid_localpart
from synapse.util import json_decoder

if TYPE_CHECKING:
Expand Down Expand Up @@ -869,73 +865,51 @@ async def _map_userinfo_to_user(
# to be strings.
remote_user_id = str(remote_user_id)

# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id, remote_user_id,
# Older mapping providers don't accept the `failures` argument, so we
# try and detect support.
mapper_signature = inspect.signature(
self._user_mapping_provider.map_user_attributes
)
if previously_registered_user_id:
return previously_registered_user_id
supports_failures = "failures" in mapper_signature.parameters

# Otherwise, generate a new user.
try:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token
)
except Exception as e:
raise MappingException(
"Could not extract user attributes from OIDC response: " + str(e)
)
async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
"""
Call the mapping provider to map the OIDC userinfo and token to user attributes.
logger.debug(
"Retrieved user attributes from user mapping provider: %r", attributes
)
This is backwards compatibility for abstraction for the SSO handler.
"""
if supports_failures:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token, failures
)
else:
# If the mapping provider does not support processing failures,
# do not continually generate the same Matrix ID since it will
# continue to already be in use. Note that the error raised is
# arbitrary and will get turned into a MappingException.
if failures:
raise RuntimeError(
"Mapping provider does not support de-duplicating Matrix IDs"
)

localpart = attributes["localpart"]
if not localpart:
raise MappingException(
"Error parsing OIDC response: OIDC mapping provider plugin "
"did not return a localpart value"
)
attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
userinfo, token
)

user_id = UserID(localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if self._allow_existing_users:
if len(users) == 1:
registered_user_id = next(iter(users))
elif user_id in users:
registered_user_id = user_id
else:
raise MappingException(
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
user_id, list(users.keys())
)
)
else:
# This mxid is taken
raise MappingException("mxid '{}' is already taken".format(user_id))
else:
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(localpart):
raise MappingException("localpart is invalid: %s" % (localpart,))

# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=attributes["display_name"],
user_agent_ips=[(user_agent, ip_address)],
)
return UserAttributes(**attributes)

await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
oidc_response_to_user_attributes,
self._allow_existing_users,
)
return registered_user_id


UserAttribute = TypedDict(
"UserAttribute", {"localpart": str, "display_name": Optional[str]}
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")

Expand Down Expand Up @@ -978,13 +952,15 @@ def get_remote_user_id(self, userinfo: UserInfo) -> str:
raise NotImplementedError()

async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
failures: How many times a call to this function with this
UserInfo has resulted in a failure.
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
Expand Down Expand Up @@ -1084,13 +1060,17 @@ def get_remote_user_id(self, userinfo: UserInfo) -> str:
return userinfo[self._config.subject_claim]

async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()

# Ensure only valid characters are included in the MXID.
localpart = map_username_to_mxid_localpart(localpart)

# Append suffix integer if last call to this function failed to produce
# a usable mxid.
localpart += str(failures) if failures else ""

display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
Expand All @@ -1100,7 +1080,7 @@ async def map_user_attributes(
if display_name == "":
display_name = None

return UserAttribute(localpart=localpart, display_name=display_name)
return UserAttributeDict(localpart=localpart, display_name=display_name)

async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
Expand Down
91 changes: 28 additions & 63 deletions synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
from synapse.types import (
UserID,
contains_invalid_mxid_characters,
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
Expand Down Expand Up @@ -250,14 +249,26 @@ async def _map_saml_response_to_user(
"Failed to extract remote user id from SAML response"
)

with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id, remote_user_id,
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
"""
Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
This is backwards compatibility for abstraction for the SSO handler.
"""
# Call the mapping provider.
result = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, failures, client_redirect_url
)
# Remap some of the results.
return UserAttributes(
localpart=result.get("mxid_localpart"),
display_name=result.get("displayname"),
emails=result.get("emails"),
)
if previously_registered_user_id:
return previously_registered_user_id

with (await self._mapping_lock.queue(self._auth_provider_id)):
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
Expand All @@ -284,59 +295,13 @@ async def _map_saml_response_to_user(
)
return registered_user_id

# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, i, client_redirect_url=client_redirect_url,
)

logger.debug(
"Retrieved SAML attributes from user mapping provider: %s "
"(attempt %d)",
attribute_dict,
i,
)

localpart = attribute_dict.get("mxid_localpart")
if not localpart:
raise MappingException(
"Error parsing SAML2 response: SAML mapping provider plugin "
"did not return a mxid_localpart value"
)

displayname = attribute_dict.get("displayname")
emails = attribute_dict.get("emails", [])

# Check if this mxid already exists
if not await self.store.get_users_by_id_case_insensitive(
UserID(localpart, self.server_name).to_string()
):
# This mxid is free
break
else:
# Unable to generate a username in 1000 iterations
# Break and return error to the user
raise MappingException(
"Unable to generate a Matrix ID from the SAML response"
)

# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(localpart):
raise MappingException("localpart is invalid: %s" % (localpart,))

logger.debug("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
user_agent_ips=[(user_agent, ip_address)],
)

await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
saml_response_to_remapped_user_attributes,
)
return registered_user_id

def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
Expand Down Expand Up @@ -451,11 +416,11 @@ def saml_response_to_user_attributes(
)

# Use the configured mapper for this mxid_source
base_mxid_localpart = self._mxid_mapper(mxid_source)
localpart = self._mxid_mapper(mxid_source)

# Append suffix integer if last call to this function failed to produce
# a usable mxid
localpart = base_mxid_localpart + (str(failures) if failures else "")
# a usable mxid.
localpart += str(failures) if failures else ""

# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
Expand Down
Loading

0 comments on commit 4fd222a

Please sign in to comment.