From 1bd210b3df56813f8716a819c65a326d1c38c051 Mon Sep 17 00:00:00 2001 From: Jonathan Merlevede Date: Fri, 13 Sep 2024 19:27:30 +0200 Subject: [PATCH] Fix and test for x_token_server error bug --- tests/unit/oauth_test_utils.py | 24 +++++++++++++++++++----- trino/auth.py | 2 +- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/unit/oauth_test_utils.py b/tests/unit/oauth_test_utils.py index 5a496eaf..416e6025 100644 --- a/tests/unit/oauth_test_utils.py +++ b/tests/unit/oauth_test_utils.py @@ -54,11 +54,25 @@ def __call__(self, request, uri, response_headers): if authorization and authorization.replace("Bearer ", "") in self.tokens: return [200, response_headers, json.dumps(self.sample_post_response_data)] elif self.redirect_server is None and self.token_server is not None: - return [401, {'Www-Authenticate': f'Bearer x_token_server="{self.token_server}"', - 'Basic realm': '"Trino"'}, ""] - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", ' - f'x_token_server="{self.token_server}"', - 'Basic realm': '"Trino"'}, ""] + return [401, + { + 'Www-Authenticate': ( + 'Bearer realm="Trino", token_type="JWT", ' + f'Bearer x_token_server="{self.token_server}"' + ), + 'Basic realm': '"Trino"' + }, + ""] + return [401, + { + 'Www-Authenticate': ( + 'Bearer realm="Trino", token_type="JWT", ' + f'Bearer x_redirect_server="{self.redirect_server}", ' + f'x_token_server="{self.token_server}"' + ), + 'Basic realm': '"Trino"' + }, + ""] class GetTokenCallback: diff --git a/trino/auth.py b/trino/auth.py index e5939403..8a24ecd7 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -459,7 +459,7 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None: auth_info_headers = self._parse_authenticate_header(auth_info) auth_server = auth_info_headers.get('bearer x_redirect_server', auth_info_headers.get('x_redirect_server')) - token_server = auth_info_headers.get('x_token_server') + token_server = auth_info_headers.get('bearer x_token_server', auth_info_headers.get('x_token_server')) if token_server is None: raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server")