diff --git a/ldap_auth_provider.py b/ldap_auth_provider.py index 1189d71..087034e 100644 --- a/ldap_auth_provider.py +++ b/ldap_auth_provider.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, threads +from twisted.internet import threads import ldap3 @@ -88,8 +88,7 @@ def __init__(self, config, account_handler): def get_supported_login_types(self): return {'m.login.password': ('password',)} - @defer.inlineCallbacks - def check_auth(self, username, login_type, login_dict): + async def check_auth(self, username, login_type, login_dict): """ Attempt to authenticate a user against an LDAP Server and register an account if none exists. @@ -103,7 +102,7 @@ def check_auth(self, username, login_type, login_dict): # an anonymous authorization state and not suitable for user # authentication. if not password: - defer.returnValue(False) + return False if username.startswith("@") and ":" in username: # username is of the form @foo:bar.com @@ -122,7 +121,7 @@ def check_auth(self, username, login_type, login_dict): uid_value = login + "@" + domain default_display_name = login except ActiveDirectoryUPNException: - defer.returnValue(False) + return False try: tls = ldap3.Tls(validate=ssl.CERT_REQUIRED) @@ -140,7 +139,7 @@ def check_auth(self, username, login_type, login_dict): value=uid_value, base=self.ldap_base ) - result, conn = yield self._ldap_simple_bind( + result, conn = await self._ldap_simple_bind( server=server, bind_dn=bind_dn, password=password ) logger.debug( @@ -150,10 +149,10 @@ def check_auth(self, username, login_type, login_dict): conn ) if not result: - defer.returnValue(False) + return False elif self.ldap_mode == LDAPMode.SEARCH: filters = [(self.ldap_attributes["uid"], uid_value)] - result, conn, _ = yield self._ldap_authenticated_search( + result, conn, _ = await self._ldap_authenticated_search( server=server, password=password, filters=filters ) logger.debug( @@ -163,7 +162,7 @@ def check_auth(self, username, login_type, login_dict): conn ) if not result: - defer.returnValue(False) + return False else: # pragma: no cover raise RuntimeError( 'Invalid LDAP mode specified: {mode}'.format( @@ -181,17 +180,17 @@ def check_auth(self, username, login_type, login_dict): "Authentication method yielded no LDAP connection, " "aborting!" ) - defer.returnValue(False) + return False # Get full user id from localpart user_id = self.account_handler.get_qualified_user_id(localpart) # check if user with user_id exists - if (yield self.account_handler.check_user_exists(user_id)): + if await self.account_handler.check_user_exists(user_id): # exists, authentication complete if hasattr(conn, "unbind"): - yield threads.deferToThread(conn.unbind) - defer.returnValue(user_id) + await threads.deferToThread(conn.unbind) + return user_id else: # does not exist, register @@ -200,7 +199,7 @@ def check_auth(self, username, login_type, login_dict): # existing ldap connection filters = [(self.ldap_attributes['uid'], uid_value)] - result, conn, response = yield self._ldap_authenticated_search( + result, conn, response = await self._ldap_authenticated_search( server=server, password=password, filters=filters, ) @@ -222,18 +221,17 @@ def check_auth(self, username, login_type, login_dict): mail = None # Register the user - user_id = yield self.register_user(localpart, display_name, mail) + user_id = await self.register_user(localpart, display_name, mail) - defer.returnValue(user_id) + return user_id - defer.returnValue(False) + return False except ldap3.core.exceptions.LDAPException as e: logger.warning("Error during ldap authentication: %s", e) - defer.returnValue(False) + return False - @defer.inlineCallbacks - def check_3pid_auth(self, medium, address, password): + async def check_3pid_auth(self, medium, address, password): """ Handle authentication against thirdparty login types, such as email Args: @@ -248,11 +246,11 @@ def check_3pid_auth(self, medium, address, password): if self.ldap_mode != LDAPMode.SEARCH: logger.debug("3PID LDAP login/register attempted but LDAP search mode " "not enabled. Bailing.") - defer.returnValue(None) + return None # We currently only support email if medium != "email": - defer.returnValue(None) + return None # Talk to LDAP and check if this email/password combo is correct try: @@ -265,7 +263,7 @@ def check_3pid_auth(self, medium, address, password): ) search_filter = [(self.ldap_attributes["mail"], address)] - result, conn, response = yield self._ldap_authenticated_search( + result, conn, response = await self._ldap_authenticated_search( server=server, password=password, filters=search_filter, ) @@ -279,10 +277,10 @@ def check_3pid_auth(self, medium, address, password): # Close connection if hasattr(conn, "unbind"): - yield threads.deferToThread(conn.unbind) + await threads.deferToThread(conn.unbind) if not result: - defer.returnValue(None) + return None # Extract the username from the search response from the LDAP server localpart = response["attributes"].get( @@ -306,16 +304,15 @@ def check_3pid_auth(self, medium, address, password): givenName = givenName[0] if len(givenName) == 1 else localpart # Register the user - user_id = yield self.register_user(localpart, givenName, address) + user_id = await self.register_user(localpart, givenName, address) - defer.returnValue(user_id) + return user_id except ldap3.core.exceptions.LDAPException as e: logger.warning("Error during ldap authentication: %s", e) raise - @defer.inlineCallbacks - def register_user(self, localpart, name, email_address): + async def register_user(self, localpart, name, email_address): """Register a Synapse user, first checking if they exist. Args: @@ -329,9 +326,9 @@ def register_user(self, localpart, name, email_address): # Get full user id from localpart user_id = self.account_handler.get_qualified_user_id(localpart) - if (yield self.account_handler.check_user_exists(user_id)): + if await self.account_handler.check_user_exists(user_id): # exists, authentication complete - defer.returnValue(user_id) + return user_id # register an email address if one exists emails = [email_address] if email_address is not None else [] @@ -341,14 +338,14 @@ def register_user(self, localpart, name, email_address): # from password providers if parse_version(synapse.__version__) <= parse_version("0.99.3"): user_id, access_token = ( - yield self.account_handler.register( + await self.account_handler.register( localpart=localpart, displayname=name, ) ) else: # If Synapse has support, bind emails user_id, access_token = ( - yield self.account_handler.register( + await self.account_handler.register( localpart=localpart, displayname=name, emails=emails, ) ) @@ -358,7 +355,7 @@ def register_user(self, localpart, name, email_address): user_id, ) - defer.returnValue(user_id) + return user_id @staticmethod def parse_config(config): @@ -407,8 +404,7 @@ class _LdapConfig(object): return ldap_config - @defer.inlineCallbacks - def _ldap_simple_bind(self, server, bind_dn, password): + async def _ldap_simple_bind(self, server, bind_dn, password): """ Attempt a simple bind with the credentials given by the user against the LDAP server. @@ -420,7 +416,7 @@ def _ldap_simple_bind(self, server, bind_dn, password): try: # bind with the the local user's ldap credentials - conn = yield threads.deferToThread( + conn = await threads.deferToThread( ldap3.Connection, server, bind_dn, password, authentication=LDAP_AUTH_SIMPLE, @@ -432,33 +428,32 @@ def _ldap_simple_bind(self, server, bind_dn, password): ) if self.ldap_start_tls: - yield threads.deferToThread(conn.open) - yield threads.deferToThread(conn.start_tls) + await threads.deferToThread(conn.open) + await threads.deferToThread(conn.start_tls) logger.debug( "Upgraded LDAP connection in simple bind mode through " "StartTLS: %s", conn ) - if (yield threads.deferToThread(conn.bind)): + if await threads.deferToThread(conn.bind): # GOOD: bind okay logger.debug("LDAP Bind successful in simple bind mode.") - defer.returnValue((True, conn)) + return (True, conn) # BAD: bind failed logger.info( "Binding against LDAP failed for '%s' failed: %s", bind_dn, conn.result['description'] ) - yield threads.deferToThread(conn.unbind) - defer.returnValue((False, None)) + await threads.deferToThread(conn.unbind) + return (False, None) except ldap3.core.exceptions.LDAPException as e: logger.warning("Error during LDAP authentication: %s", e) raise - @defer.inlineCallbacks - def _ldap_authenticated_search(self, server, password, filters): + async def _ldap_authenticated_search(self, server, password, filters): """Attempt to login with the preconfigured bind_dn and then continue searching and filtering within the base_dn. @@ -480,7 +475,7 @@ def _ldap_authenticated_search(self, server, password, filters): """ try: - conn = yield threads.deferToThread( + conn = await threads.deferToThread( ldap3.Connection, server, self.ldap_bind_dn, @@ -493,21 +488,21 @@ def _ldap_authenticated_search(self, server, password, filters): ) if self.ldap_start_tls: - yield threads.deferToThread(conn.open) - yield threads.deferToThread(conn.start_tls) + await threads.deferToThread(conn.open) + await threads.deferToThread(conn.start_tls) logger.debug( "Upgraded LDAP connection in search mode through " "StartTLS: %s", conn ) - if not (yield threads.deferToThread(conn.bind)): + if not await threads.deferToThread(conn.bind): logger.warning( "Binding against LDAP with `bind_dn` failed: %s", conn.result['description'] ) - yield threads.deferToThread(conn.unbind) - defer.returnValue((False, None, None)) + await threads.deferToThread(conn.unbind) + return (False, None, None) # Construct search filter query = "" @@ -529,7 +524,7 @@ def _ldap_authenticated_search(self, server, password, filters): "LDAP search filter: %s", query ) - yield threads.deferToThread( + await threads.deferToThread( conn.search, search_base=self.ldap_base, search_filter=query, @@ -555,12 +550,12 @@ def _ldap_authenticated_search(self, server, password, filters): # unbind and simple bind with user_dn to verify the password # Note: do not use rebind(), for some reason it did not verify # the password for me! - yield threads.deferToThread(conn.unbind) - result, conn = yield self._ldap_simple_bind( + await threads.deferToThread(conn.unbind) + result, conn = await self._ldap_simple_bind( server=server, bind_dn=user_dn, password=password ) - defer.returnValue((result, conn, responses[0])) + return (result, conn, responses[0]) else: # BAD: found 0 or > 1 results, abort! if len(responses) == 0: @@ -573,9 +568,9 @@ def _ldap_authenticated_search(self, server, password, filters): "LDAP search returned too many (%s) results for '%s'", len(responses), filters ) - yield threads.deferToThread(conn.unbind) + await threads.deferToThread(conn.unbind) - defer.returnValue((False, None, None)) + return (False, None, None) except ldap3.core.exceptions.LDAPException as e: logger.warning("Error during LDAP authentication: %s", e) diff --git a/tests/__init__.py b/tests/__init__.py index 10e8599..b3ca683 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,9 @@ +from asyncio.futures import Future +from typing import Any, Awaitable + from twisted.internet.endpoints import serverFromString from twisted.internet.protocol import ServerFactory -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.python.components import registerAdapter from ldaptor.inmemory import fromLDIFFile from ldaptor.interfaces import IConnectedLDAPEntry @@ -106,12 +109,11 @@ """ -@defer.inlineCallbacks -def _create_db(): +async def _create_db(): f = BytesIO(LDIF) - db = yield fromLDIFFile(f) + db = await fromLDIFFile(f) f.close() - defer.returnValue(db) + return db class _LDAPServerFactory(ServerFactory): @@ -156,20 +158,19 @@ def close(self): ) -@defer.inlineCallbacks -def create_ldap_server(): +async def create_ldap_server(): "Returns a context manager that represents the LDAP server." - db = yield _create_db() + db = await _create_db() factory = _LDAPServerFactory(db) factory.debug = True # We just pick an arbitrary port to listen on. serverEndpointStr = "tcp:0" e = serverFromString(reactor, serverEndpointStr) - listener = yield e.listen(factory) + listener = await e.listen(factory) - defer.returnValue(_LdapServer(listener)) + return _LdapServer(listener) def create_auth_provider(server, account_handler, config=None): @@ -192,6 +193,17 @@ def create_auth_provider(server, account_handler, config=None): return LdapAuthProvider(config, account_handler=account_handler) +def make_awaitable(result: Any) -> Awaitable[Any]: + """ + Makes an awaitable, suitable for mocking an `async` function. + This uses Futures as they can be awaited multiple times so can be returned + to multiple callers. + """ + future = Future() + future.set_result(result) + return future + + def get_qualified_user_id(username): if not username.startswith('@'): return "@%s:test" % username diff --git a/tests/test_ad.py b/tests/test_ad.py index 6e44335..dadddab 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py @@ -12,13 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from twisted.internet.defer import ensureDeferred from twisted.trial import unittest from twisted.internet import defer from mock import Mock -from . import create_ldap_server, create_auth_provider, get_qualified_user_id +from . import ( + create_ldap_server, + create_auth_provider, + get_qualified_user_id, + make_awaitable, +) import logging logging.basicConfig() @@ -27,9 +32,9 @@ class AbstractLdapActiveDirectoryTestCase(): @defer.inlineCallbacks def setUp(self): - self.ldap_server = yield create_ldap_server() + self.ldap_server = yield ensureDeferred(create_ldap_server()) account_handler = Mock(spec_set=["check_user_exists", "get_qualified_user_id"]) - account_handler.check_user_exists.return_value = True + account_handler.check_user_exists.return_value = make_awaitable(True) account_handler.get_qualified_user_id = get_qualified_user_id self.auth_provider = create_auth_provider( @@ -62,59 +67,59 @@ def tearDown(self): class LdapActiveDirectoryTestCase(AbstractLdapActiveDirectoryTestCase, unittest.TestCase): @defer.inlineCallbacks def test_correct_pwd(self): - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "main.example.org\\mainuser", 'm.login.password', {"password": "abracadabra"} - ) + )) self.assertEqual(result, "@mainuser/main.example.org:test") - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "subsidiary.example.org\\nonmainuser", 'm.login.password', {"password": "simsalabim"} - ) + )) self.assertEqual(result, "@nonmainuser/subsidiary.example.org:test") - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "subsidiary.example.org\\mainuser", 'm.login.password', {"password": "changeit"} - ) + )) self.assertEqual(result, "@mainuser/subsidiary.example.org:test") @defer.inlineCallbacks def test_single_email(self): - result = yield self.auth_provider.check_3pid_auth( + result = yield ensureDeferred(self.auth_provider.check_3pid_auth( "email", "mainuser@main.example.org", "abracadabra", - ) + )) self.assertFalse(result) @defer.inlineCallbacks def test_incorrect_pwd(self): - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "main.example.org\\mainuser", 'm.login.password', {"password": "bruteforce"} - ) + )) self.assertFalse(result) - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "subsidiary.example.org\\mainuser", 'm.login.password', {"password": "abracadabra"} - ) + )) self.assertFalse(result) @defer.inlineCallbacks def test_email_login(self): - result = yield self.auth_provider.check_3pid_auth( + result = yield ensureDeferred(self.auth_provider.check_3pid_auth( "email", "uniqueuser@main.example.org", "nothing", - ) + )) self.assertEqual(result, "@uniqueuser/main.example.org:test") @@ -130,55 +135,55 @@ def getConfig(self): @defer.inlineCallbacks def test_correct_pwd(self): - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "mainuser", 'm.login.password', {"password": "abracadabra"} - ) + )) self.assertEqual(result, "@mainuser:test") - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "main.example.org\\mainuser", 'm.login.password', {"password": "abracadabra"} - ) + )) self.assertEqual(result, "@mainuser:test") - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "subsidiary.example.org\\nonmainuser", 'm.login.password', {"password": "simsalabim"} - ) + )) self.assertEqual(result, "@nonmainuser/subsidiary.example.org:test") - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "subsidiary.example.org\\mainuser", 'm.login.password', {"password": "changeit"} - ) + )) self.assertEqual(result, "@mainuser/subsidiary.example.org:test") @defer.inlineCallbacks def test_incorrect_pwd(self): - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "mainuser", 'm.login.password', {"password": "changeit"} - ) + )) self.assertFalse(result) - result = yield self.auth_provider.check_auth( + result = yield ensureDeferred(self.auth_provider.check_auth( "subsidiary.example.org\\mainuser", 'm.login.password', {"password": "abracadabra"} - ) + )) self.assertFalse(result) @defer.inlineCallbacks def test_email_login(self): - result = yield self.auth_provider.check_3pid_auth( + result = yield ensureDeferred(self.auth_provider.check_3pid_auth( "email", "uniqueuser@main.example.org", "nothing", - ) + )) self.assertEqual(result, "@uniqueuser:test") diff --git a/tests/test_simple.py b/tests/test_simple.py index c9d8872..efc404a 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -18,7 +18,12 @@ from mock import Mock -from . import create_ldap_server, create_auth_provider, get_qualified_user_id +from . import ( + create_ldap_server, + create_auth_provider, + get_qualified_user_id, + make_awaitable, +) import logging logging.basicConfig() @@ -27,9 +32,9 @@ class LdapSimpleTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.ldap_server = yield create_ldap_server() + self.ldap_server = yield defer.ensureDeferred(create_ldap_server()) account_handler = Mock(spec_set=["check_user_exists", "get_qualified_user_id"]) - account_handler.check_user_exists.return_value = True + account_handler.check_user_exists.return_value = make_awaitable(True) account_handler.get_qualified_user_id = get_qualified_user_id self.auth_provider = create_auth_provider( @@ -51,47 +56,47 @@ def tearDown(self): @defer.inlineCallbacks def test_unknown_user(self): - result = yield self.auth_provider.check_auth( + result = yield defer.ensureDeferred(self.auth_provider.check_auth( "non_existent", 'm.login.password', {"password": "password"} - ) + )) self.assertFalse(result) @defer.inlineCallbacks def test_incorrect_pwd(self): - result = yield self.auth_provider.check_auth( + result = yield defer.ensureDeferred(self.auth_provider.check_auth( "bob", 'm.login.password', {"password": "wrong_password"} - ) + )) self.assertFalse(result) @defer.inlineCallbacks def test_correct_pwd(self): - result = yield self.auth_provider.check_auth( + result = yield defer.ensureDeferred(self.auth_provider.check_auth( "bob", 'm.login.password', {"password": "secret"} - ) + )) self.assertEqual(result, "@bob:test") @defer.inlineCallbacks def test_no_pwd(self): - result = yield self.auth_provider.check_auth( + result = yield defer.ensureDeferred(self.auth_provider.check_auth( "bob", 'm.login.password', {"password": ""} - ) + )) self.assertFalse(result) class LdapSearchTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.ldap_server = yield create_ldap_server() + self.ldap_server = yield defer.ensureDeferred(create_ldap_server()) account_handler = Mock(spec_set=["check_user_exists", "get_qualified_user_id"]) - account_handler.check_user_exists.return_value = True + account_handler.check_user_exists.return_value = make_awaitable(True) account_handler.get_qualified_user_id = get_qualified_user_id self.auth_provider = create_auth_provider( @@ -115,27 +120,27 @@ def tearDown(self): @defer.inlineCallbacks def test_correct_pwd_search_mode(self): - result = yield self.auth_provider.check_auth( + result = yield defer.ensureDeferred(self.auth_provider.check_auth( "bob", 'm.login.password', {"password": "secret"} - ) + )) self.assertEqual(result, "@bob:test") @defer.inlineCallbacks def test_incorrect_pwd_search_mode(self): - result = yield self.auth_provider.check_auth( + result = yield defer.ensureDeferred(self.auth_provider.check_auth( "bob", 'm.login.password', {"password": "wrong_password"} - ) + )) self.assertFalse(result) @defer.inlineCallbacks def test_unknown_user_search_mode(self): - result = yield self.auth_provider.check_auth( + result = yield defer.ensureDeferred(self.auth_provider.check_auth( "foobar", 'm.login.password', {"password": "some_password"} - ) + )) self.assertFalse(result)