diff --git a/pyromod/listen/listen.py b/pyromod/listen/listen.py index 724384a..39bf40d 100644 --- a/pyromod/listen/listen.py +++ b/pyromod/listen/listen.py @@ -19,13 +19,16 @@ """ import asyncio -from typing import Optional, Callable, Union import pyrogram +import logging + +from inspect import iscoroutinefunction +from typing import Optional, Callable, Union from enum import Enum -from ..utils import patch, patchable, PyromodConfig -loop = asyncio.get_event_loop() +from ..utils import patch, patchable, PyromodConfig +logger = logging.getLogger(__name__) class ListenerStopped(Exception): pass @@ -42,12 +45,12 @@ class ListenerTypes(Enum): @patch(pyrogram.client.Client) class Client: - @patchable + @patchable() def __init__(self, *args, **kwargs): self.listeners = {listener_type: {} for listener_type in ListenerTypes} self.old__init__(*args, **kwargs) - @patchable + @patchable() async def listen( self, identifier: tuple, @@ -62,7 +65,7 @@ async def listen( " value from pyromod.listen.ListenerTypes" ) - future = loop.create_future() + future = self.loop.create_future() future.add_done_callback( lambda f: self.stop_listening(identifier, listener_type) ) @@ -85,7 +88,7 @@ async def listen( elif PyromodConfig.throw_exceptions: raise ListenerTimeout(timeout) - @patchable + @patchable() async def ask( self, text, @@ -94,7 +97,7 @@ async def ask( listener_type=ListenerTypes.MESSAGE, timeout=None, *args, - **kwargs + **kwargs, ): if text.strip() != "": sent_message = await self.send_message(identifier[0], text, *args, **kwargs) @@ -111,7 +114,7 @@ async def ask( user_id is null, and to take precedence """ - @patchable + @patchable() def match_listener( self, data: Optional[tuple] = None, @@ -159,7 +162,7 @@ def match_identifier(pattern, identifier): return listener, identifier return None, None - @patchable + @patchable() def stop_listening( self, data: Optional[tuple] = None, @@ -186,53 +189,69 @@ def stop_listening( @patch(pyrogram.handlers.message_handler.MessageHandler) class MessageHandler: - @patchable + @patchable() def __init__(self, callback: Callable, filters=None): self.registered_handler = callback self.old__init__(self.resolve_future, filters) - @patchable + @patchable() async def check(self, client, message): - data = ( - message.chat.id, - message.from_user.id, - getattr(message, "id", getattr(message, "message_id", None)), - ) - listener = client.match_listener( - data, - ListenerTypes.MESSAGE, - )[0] - ################################# + if user := getattr(message, "from_user", None): + user = user.id + try: + listener = client.match_listener( + (message.chat.id, user, getattr(message, "id", getattr(message, "message_id", None))), + ListenerTypes.MESSAGE, + )[0] + except AttributeError as err: + logger.warning(f"Get : {err}\n\n{message}") + raise err listener_does_match = handler_does_match = False if listener: filters = listener["filters"] - listener_does_match = ( - await filters(client, message) if callable(filters) else True - ) - handler_does_match = ( - await self.filters(client, message) - if callable(self.filters) - else True - ) + if callable(filters): + if iscoroutinefunction(filters.__call__): + listener_does_match = await filters(client, message) + else: + listener_does_match = await client.loop.run_in_executor( + None, filters, client, message + ) + else: + listener_does_match = True + + if callable(self.filters): + if iscoroutinefunction(self.filters.__call__): + handler_does_match = await self.filters(client, message) + else: + handler_does_match = await client.loop.run_in_executor( + None, self.filters, client, message + ) + else: + handler_does_match = True # let handler get the chance to handle if listener # exists but its filters doesn't match return listener_does_match or handler_does_match - @patchable + @patchable() async def resolve_future(self, client, message, *args): - data = ( - message.chat.id, - message.from_user.id, - getattr(message, "id", getattr(message, "message_id", None)), - ) listener_type = ListenerTypes.MESSAGE + if not message.from_user: message_from_user_id = None else: message_from_user_id = message.from_user.id + + + data = ( + message.chat.id, + message_from_user_id, + getattr(message, "id", getattr(message, "message_id", None)), + ) + + listener, identifier = client.match_listener( data, listener_type, @@ -240,9 +259,15 @@ async def resolve_future(self, client, message, *args): listener_does_match = False if listener: filters = listener["filters"] - listener_does_match = ( - await filters(client, message) if callable(filters) else True - ) + if callable(filters): + if iscoroutinefunction(filters.__call__): + listener_does_match = await filters(client, message) + else: + listener_does_match = await client.loop.run_in_executor( + None, filters, client, message + ) + else: + listener_does_match = True if listener_does_match: if not listener["future"].done(): @@ -255,31 +280,32 @@ async def resolve_future(self, client, message, *args): @patch(pyrogram.handlers.callback_query_handler.CallbackQueryHandler) class CallbackQueryHandler: - @patchable + @patchable() def __init__(self, callback: Callable, filters=None): self.registered_handler = callback self.old__init__(self.resolve_future, filters) - @patchable + @patchable() async def check(self, client, query): - data = ( - query.message.chat.id, - query.from_user.id, - getattr(query.message, "id", getattr(query.message, "message_id", None)), - ) - listener_type = ListenerTypes.CALLBACK_QUERY - listener = client.match_listener( - data, - listener_type, - )[0] + chatID, mID = None, None + if message := getattr(query, "message", None): + chatID, mID = message.chat.id, getattr(query.message, "id", getattr(query.message, "message_id", None)) + try: + listener = client.match_listener( + (chatID, query.from_user.id, mID), + ListenerTypes.CALLBACK_QUERY, + )[0] + except AttributeError as err: + logger.warning(f"Get : {err}\n\n{message}") + raise err # managing unallowed user clicks if PyromodConfig.unallowed_click_alert: permissive_listener = client.match_listener( identifier_pattern=( - data[0], + chatID, None, - data[2], + mID, ), listener_type=listener_type, )[0] @@ -298,9 +324,17 @@ async def check(self, client, query): filters = listener["filters"] if listener else self.filters - return await filters(client, query) if callable(filters) else True + if callable(filters): + if iscoroutinefunction(filters.__call__): + return await filters(client, query) + else: + return await client.loop.run_in_executor( + None, filters, client, query + ) + else: + return True - @patchable + @patchable() async def resolve_future(self, client, query, *args): data = ( query.message.chat.id, @@ -308,8 +342,11 @@ async def resolve_future(self, client, query, *args): getattr(query.message, "id", getattr(query.message, "message_id", None)), ) listener_type = ListenerTypes.CALLBACK_QUERY + chatID, mID = None, None + if message := getattr(query, "message", None): + chatID, mID = message.chat.id, getattr(query.message, "id", getattr(query.message, "message_id", None)) listener, identifier = client.match_listener( - data, + (chatID, query.from_user.id, mID), listener_type, ) @@ -322,7 +359,7 @@ async def resolve_future(self, client, query, *args): @patch(pyrogram.types.messages_and_media.message.Message) class Message(pyrogram.types.messages_and_media.message.Message): - @patchable + @patchable() async def wait_for_click( self, from_user_id: Optional[int] = None, @@ -342,15 +379,15 @@ async def wait_for_click( @patch(pyrogram.types.user_and_chats.chat.Chat) class Chat(pyrogram.types.Chat): - @patchable + @patchable() def listen(self, *args, **kwargs): return self._client.listen((self.id, None, None), *args, **kwargs) - @patchable + @patchable() def ask(self, text, *args, **kwargs): return self._client.ask(text, (self.id, None, None), *args, **kwargs) - @patchable + @patchable() def stop_listening(self, *args, **kwargs): return self._client.stop_listening( *args, identifier_pattern=(self.id, None, None), **kwargs @@ -359,17 +396,17 @@ def stop_listening(self, *args, **kwargs): @patch(pyrogram.types.user_and_chats.user.User) class User(pyrogram.types.User): - @patchable + @patchable() def listen(self, *args, **kwargs): return self._client.listen((None, self.id, None), *args, **kwargs) - @patchable + @patchable() def ask(self, text, *args, **kwargs): return self._client.ask( text, (self.id, self.id, None), *args, **kwargs ) - @patchable + @patchable() def stop_listening(self, *args, **kwargs): return self._client.stop_listening( *args, identifier_pattern=(None, self.id, None), **kwargs diff --git a/pyromod/utils/utils.py b/pyromod/utils/utils.py index 293ff25..eb1eb3e 100644 --- a/pyromod/utils/utils.py +++ b/pyromod/utils/utils.py @@ -17,7 +17,14 @@ You should have received a copy of the GNU General Public License along with pyromod. If not, see . """ +from typing import Callable +from logging import getLogger +from inspect import iscoroutinefunction +from contextlib import contextmanager, asynccontextmanager +from pyrogram.sync import async_to_sync + +logger = getLogger(__name__) class PyromodConfig: timeout_handler = None @@ -31,18 +38,81 @@ class PyromodConfig: def patch(obj): def is_patchable(item): + # item = (name, value) + # item[0] = name + # item[1] = value return getattr(item[1], "patchable", False) def wrapper(container): for name, func in filter(is_patchable, container.__dict__.items()): old = getattr(obj, name, None) - setattr(obj, "old" + name, old) + if old is not None: # Not adding 'old' to new func + setattr(obj, "old" + name, old) + + # Worse Code + tempConf = {i: getattr(func, i, False) for i in ["is_property", "is_static", "is_context"]} + + async_to_sync(container, name) + func = getattr(container, name) + + for tKey, tValue in tempConf.items(): + setattr(func, tKey, tValue) + + if func.is_property: + func = property(func) + elif func.is_static: + func = staticmethod(func) + elif func.is_context: + if iscoroutinefunction(func.__call__): + func = asynccontextmanager(func) + else: + func = contextmanager(func) + + logger.info(f"Patch Attribute To {obj.__name__} From {container.__name__} : {name}") setattr(obj, name, func) return container return wrapper -def patchable(func): - func.patchable = True - return func +def patchable(is_property: bool = False, is_static: bool = False, is_context: bool = False) -> Callable: + """ + A decorator that marks a function as patchable. + + Usage: + + @patchable(is_property=True) + def my_property(): + ... + + @patchable(is_static=True) + def my_static_method(): + ... + + @patchable(is_context=True) + def my_context_manager(): + ... + + @patchable(is_property=False, is_static=False, is_context=False) + def my_function(): + ... + + @patchable() + def default_usage(): + ... + + Parameters: + - is_property (bool): whether the function is a property. Default is False. + - is_static (bool): whether the function is a static method. Default is False. + - is_context (bool): whether the function is a context manager. Default is False. + + Returns: + - A callable object that marks the function as patchable. + """ + def wrapper(func: Callable) -> Callable: + func.patchable = True + func.is_property = is_property + func.is_static = is_static + func.is_context = is_context + return func + return wrapper