From cd32e155fcf819dcf02694a87cb73a26aafcf707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89loi=20Rivard?= Date: Tue, 5 Sep 2023 14:05:46 +0200 Subject: [PATCH] feat: several endpoint types can be registered AuthorizationServer.register_endpoint can be called several times for one kind of endpoint. --- authlib/common/errors.py | 4 ++++ .../oauth2/rfc6749/authorization_server.py | 19 ++++++++++++------- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/authlib/common/errors.py b/authlib/common/errors.py index 084f4217..56515bab 100644 --- a/authlib/common/errors.py +++ b/authlib/common/errors.py @@ -57,3 +57,7 @@ def __call__(self, uri=None): body = dict(self.get_body()) headers = self.get_headers() return self.status_code, body, headers + + +class ContinueIteration(AuthlibBaseError): + pass diff --git a/authlib/oauth2/rfc6749/authorization_server.py b/authlib/oauth2/rfc6749/authorization_server.py index e5d4a67a..8b886a04 100644 --- a/authlib/oauth2/rfc6749/authorization_server.py +++ b/authlib/oauth2/rfc6749/authorization_server.py @@ -1,3 +1,4 @@ +from authlib.common.errors import ContinueIteration from .authenticate_client import ClientAuthentication from .requests import OAuth2Request, JsonRequest from .errors import ( @@ -186,7 +187,8 @@ def register_endpoint(self, endpoint_cls): :param endpoint_cls: A endpoint class """ - self._endpoints[endpoint_cls.ENDPOINT_NAME] = endpoint_cls(self) + endpoints = self._endpoints.setdefault(endpoint_cls.ENDPOINT_NAME, []) + endpoints.append(endpoint_cls(self)) def get_authorization_grant(self, request): """Find the authorization grant for current request. @@ -231,12 +233,15 @@ def create_endpoint_response(self, name, request=None): if name not in self._endpoints: raise RuntimeError(f'There is no "{name}" endpoint.') - endpoint = self._endpoints[name] - request = endpoint.create_endpoint_request(request) - try: - return self.handle_response(*endpoint(request)) - except OAuth2Error as error: - return self.handle_error_response(request, error) + endpoints = self._endpoints[name] + for endpoint in endpoints: + request = endpoint.create_endpoint_request(request) + try: + return self.handle_response(*endpoint(request)) + except ContinueIteration: + continue + except OAuth2Error as error: + return self.handle_error_response(request, error) def create_authorization_response(self, request=None, grant_user=None): """Validate authorization request and create authorization response.