diff --git a/fedifetcher/api/api.py b/fedifetcher/api/api.py index c6ce7083..828c6ce5 100644 --- a/fedifetcher/api/api.py +++ b/fedifetcher/api/api.py @@ -33,6 +33,15 @@ def __init__( """Initialize the API.""" raise NotImplementedError + async def cleanup(self) -> None: + await self.client.session.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, *args) -> None: + await self.cleanup() + @abstractmethod def get(self, uri: str) -> Coroutine[Any, Any, dict]: """Get an object by URI.""" @@ -169,7 +178,11 @@ class FederationInterface: _equipped_api: API - def __init__(self, url: str) -> None: + def __init__(self, + domain: str, + token: str | None = None, + pgupdater: Any | None = None, + ) -> None: """Initialize the API.""" def _normalize_url(_url: str) -> str: """Normalize a URL.""" @@ -181,7 +194,7 @@ def _normalize_url(_url: str) -> str: if _url.count("/") > slashes_in_protocol_prefix: _url = _url[:_url.find("/", 8)] return _url - url = _normalize_url(url) + domain = _normalize_url(domain) temporary_client_session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=60), headers={ @@ -189,7 +202,7 @@ def _normalize_url(_url: str) -> str: }, ) temporary_client = HttpMethod( - url, + domain, temporary_client_session, ) nodeinfo = asyncio.run(temporary_client.get("/nodeinfo/2.0")) @@ -197,25 +210,25 @@ def _normalize_url(_url: str) -> str: wellknown_nodeinfo = asyncio.run(temporary_client.get("/.well-known/nodeinfo")) if not wellknown_nodeinfo: raise NotImplementedError - url = _normalize_url(wellknown_nodeinfo["links"][0]["href"]) + domain = _normalize_url(wellknown_nodeinfo["links"][0]["href"]) temporary_client = HttpMethod( - url, + domain, temporary_client_session, ) nodeinfo = asyncio.run(temporary_client.get("/nodeinfo/2.0")) if not nodeinfo: wellknown_hostmeta = asyncio.run(temporary_client.get("/.well-known/host-meta")) if wellknown_hostmeta and wellknown_hostmeta["server"]: - url = _normalize_url(wellknown_hostmeta["server"]) + domain = _normalize_url(wellknown_hostmeta["server"]) temporary_client = HttpMethod( - url, + domain, temporary_client_session, ) nodeinfo = asyncio.run(temporary_client.get("/nodeinfo/2.0")) software_name = nodeinfo["software"]["name"] if nodeinfo else None if software_name == "mastodon": from fedifetcher.api.mastodon import Mastodon - equippable_api = Mastodon(url) + equippable_api = Mastodon(domain, token, pgupdater) elif software_name == "misskey": raise NotImplementedError else: @@ -223,6 +236,10 @@ def _normalize_url(_url: str) -> str: raise NotImplementedError(msg) self._equipped_api: API = equippable_api + async def cleanup(self) -> None: + """Clean up.""" + await self._equipped_api.cleanup() + def get(self, uri: str) -> Coroutine[Any, Any, dict]: """Get an object by URI.""" return self._equipped_api.get(uri) @@ -340,3 +357,20 @@ def get_home_timeline( """Get the home timeline.""" return self._equipped_api.get_home_timeline(limit) +class FederationInterfaceManager: + """Manage federation interfaces.""" + + def __init__(self) -> None: + """Initialize the manager.""" + self._interfaces: dict[str, FederationInterface] = {} + + def get_interface(self, domain: str) -> FederationInterface: + """Get an interface.""" + if domain not in self._interfaces: + self._interfaces[domain] = FederationInterface(domain) + return self._interfaces[domain] + + async def cleanup(self) -> None: + """Clean up.""" + for interface in self._interfaces.values(): + await interface.cleanup() diff --git a/fedifetcher/find_context.py b/fedifetcher/find_context.py index 0adeb97d..fbe7d751 100644 --- a/fedifetcher/find_context.py +++ b/fedifetcher/find_context.py @@ -6,8 +6,7 @@ from typing import TYPE_CHECKING from fedifetcher import getter_wrappers, parsers -from fedifetcher.api.api import ApiError -from fedifetcher.api.mastodon import api_mastodon +from fedifetcher.api.api import API, ApiError from fedifetcher.api.postgresql import PostgreSQLUpdater if TYPE_CHECKING: @@ -16,10 +15,11 @@ async def add_post_with_context( post: dict[str, str], - home_server: str, - access_token: str, + api: API, + # home_server: str, + # access_token: str, external_tokens: dict[str, str], - pgupdater: PostgreSQLUpdater, + # pgupdater: PostgreSQLUpdater, arguments: Namespace, ) -> bool: """Add the given post to the server. @@ -38,11 +38,7 @@ async def add_post_with_context( bool: True if the post was added successfully, False otherwise. """ try: - await api_mastodon.Mastodon( - home_server, - access_token, - pgupdater, - ).get(post["url"]) + await api.get(post["url"]) if ("replies_count" in post or "in_reply_to_id" in post) and getattr( arguments, "backfill_with_context", @@ -51,42 +47,37 @@ async def add_post_with_context( parsed_urls: dict[str, tuple[str | None, str | None]] = {} parsed = parsers.post(post["url"], parsed_urls) if parsed is not None and parsed[0] is not None: + if api.client.pgupdater is None or api.client.token is None: + msg = "No PostgreSQL updater, or no token" + raise ApiError(msg) known_context_urls = await getter_wrappers.get_all_known_context_urls( - home_server, + api, [post], parsed_urls, external_tokens, - pgupdater, - access_token, ) ( await add_context_urls_wrapper( - home_server, - access_token, + api, known_context_urls, - pgupdater, ) ) return True except ApiError: - logging.debug(f"Failed to add {post['url']} to {home_server}") + logging.debug(f"Failed to add {post['url']} to {api.client.api_base_url}") return False async def add_context_urls_wrapper( - home_server: str, - access_token: str, + api: API, context_urls: Iterable[str], - pgupdater: PostgreSQLUpdater, ) -> None: """Add the given toot URLs to the server. Args: ---- - home_server: The server to add the toots to. - access_token: The access token to use to add the toots. + api: The API to use to add the statuses. context_urls: The list of toot URLs to add. - pgupdater: The PostgreSQL updater. """ list_of_context_urls = list(context_urls) if len(list_of_context_urls) == 0: @@ -96,7 +87,10 @@ async def add_context_urls_wrapper( failed = 0 already_added = 0 posts_to_fetch = [] - cached_posts: dict[str, Status | None] = pgupdater.get_dict_from_cache( + if api.client.pgupdater is None: + msg = "No PostgreSQL updater" + raise ApiError(msg) + cached_posts: dict[str, Status | None] = api.client.pgupdater.get_dict_from_cache( list_of_context_urls, ) logging.debug(f"Got {len(cached_posts)} cached posts") @@ -116,24 +110,18 @@ async def add_context_urls_wrapper( if posts_to_fetch: logging.debug(f"Fetching {len(posts_to_fetch)} posts") - max_concurrent_tasks = 10 - semaphore = asyncio.Semaphore(max_concurrent_tasks) tasks = [] for url in posts_to_fetch: logging.debug(f"Adding {url} to home server") tasks.append( - api_mastodon.Mastodon( - home_server, - access_token, - pgupdater, - ).get(url, semaphore), + api.get(url), ) for task in asyncio.as_completed(tasks): try: result: Status = await task logging.debug(f"Got {result}") count += 1 - pgupdater.cache_status(result) + api.client.pgupdater.cache_status(result) except ApiError: failed += 1 diff --git a/fedifetcher/find_trending_posts.py b/fedifetcher/find_trending_posts.py index 8cbb1f89..a911c869 100644 --- a/fedifetcher/find_trending_posts.py +++ b/fedifetcher/find_trending_posts.py @@ -7,6 +7,7 @@ from collections.abc import Callable from fedifetcher import parsers +from fedifetcher.api.api import API from fedifetcher.api.mastodon import api_mastodon from fedifetcher.api.postgresql import PostgreSQLUpdater @@ -144,11 +145,9 @@ def get_domains_to_fetch(self) -> list[str]: async def find_trending_posts( # noqa: C901 - home_server: str, - home_token: str, + api: API, external_feeds: list[str], external_tokens: dict[str, str], - pgupdater: PostgreSQLUpdater, ) -> list[dict[str, str]]: """Pull trending posts from a list of Mastodon servers, using tokens.""" msg = f"Finding trending posts from {len(external_feeds)} domains:" @@ -220,51 +219,54 @@ async def find_trending_posts( # noqa: C901 ) logging.info(f"\033[1;34m{msg}\033[0m") max_concurrent_requests = 10 - semaphore = asyncio.Semaphore(max_concurrent_requests) promises_container = {} promises = [] - for status_id in remember_to_find_me[fetch_domain]: - if ( - str(status_id) not in trending_posts_dict - or "original" not in trending_posts_dict[str(status_id)] - ): - promise = asyncio.ensure_future( - api_mastodon.Mastodon( + async with asyncio.Semaphore(max_concurrent_requests) as semaphore: + for status_id in remember_to_find_me[fetch_domain]: + if ( + str(status_id) not in trending_posts_dict + or "original" not in trending_posts_dict[str(status_id)] + ): + async with api_mastodon.Mastodon( fetch_domain, external_tokens.get(fetch_domain), - ).get_status_by_id(status_id, semaphore), - ) - promises_container[status_id] = promise - promises.append(promise) - await asyncio.gather(*promises) - for _status_id, future in promises_container.items(): - original_post = future.result() - if original_post: - add_post_to_dict(original_post, fetch_domain, trending_posts_dict) - else: - logging.warning(f"Couldn't find {_status_id} from {fetch_domain}") + ) as api: + promise = asyncio.create_task(api.get_status_by_id(status_id, semaphore)) + promises_container[status_id] = promise + promises.append(promise) + await asyncio.gather(*promises) + for _status_id, future in promises_container.items(): + original_post = future.result() + if original_post: + add_post_to_dict(original_post, fetch_domain, trending_posts_dict) + else: + logging.warning(f"Couldn't find {_status_id} from {fetch_domain}") logging.info(f"Fetching aux posts from {len(trending_posts_dict.keys())} domains") - await aux_domain_fetcher.do_aux_fetches(trending_posts_dict, pgupdater) + if api.client.token is None or api.client.pgupdater is None: + msg = "No token or PostgreSQL updater" + raise api_mastodon.ApiError(msg) + + await aux_domain_fetcher.do_aux_fetches(trending_posts_dict, api.client.pgupdater) updated_trending_posts_dict = await update_local_status_ids( trending_posts_dict, - home_server, - home_token, - pgupdater, + api.client.api_base_url, + api.client.token, + api.client.pgupdater, ) """Update the status stats with the trending posts.""" - if pgupdater: + if api.client.pgupdater is not None: for post in updated_trending_posts_dict.values(): local_status_id = post.get("local_status_id") if local_status_id: - pgupdater.queue_status_update( + api.client.pgupdater.queue_status_update( local_status_id, int(post["reblogs_count"]), int(post["favourites_count"]), ) - pgupdater.commit_status_updates() + api.client.pgupdater.commit_status_updates() return list( api_mastodon.filter_language(updated_trending_posts_dict.values(), "en"), diff --git a/fedifetcher/getter_wrappers.py b/fedifetcher/getter_wrappers.py index f2166956..242e3f6e 100644 --- a/fedifetcher/getter_wrappers.py +++ b/fedifetcher/getter_wrappers.py @@ -8,6 +8,7 @@ from typing import Any, cast import requests +from fedifetcher.api.api import API from fedifetcher.api.mastodon import api_mastodon from fedifetcher.api.postgresql import PostgreSQLUpdater @@ -45,7 +46,8 @@ async def get_notification_users( since = datetime.now(datetime.now(UTC).astimezone().tzinfo) - timedelta( hours=max_age, ) - notifications = await api_mastodon.Mastodon(server, access_token).get_notifications( + async with api_mastodon.Mastodon(server, access_token) as api: + notifications = await api.get_notifications( since_id=str(int(since.timestamp())), ) notification_users = [] @@ -233,12 +235,10 @@ async def get_all_reply_toots( async def get_all_known_context_urls( - home_server: str, + api: API, reply_toots: list[dict[str, str]], parsed_urls: dict[str, tuple[str | None, str | None]], external_tokens: dict[str, str], - pgupdater: PostgreSQLUpdater, - home_server_token: str, ) -> Iterable[str]: known_context_urls = [] if reply_toots is not None: @@ -255,6 +255,9 @@ async def get_all_known_context_urls( logging.debug(f"Getting context for {len(toots_to_get_context_for)} toots") max_concurrent_tasks = 10 semaphore = asyncio.Semaphore(max_concurrent_tasks) + if api.client.pgupdater is None or api.client.token is None: + msg = "No PostgreSQL updater, or no token" + raise api_mastodon.ApiError(msg) tasks = [ asyncio.ensure_future( post_content( @@ -262,9 +265,9 @@ async def get_all_known_context_urls( x[0][1], x[1], external_tokens, - pgupdater, - home_server, - home_server_token, + api.client.pgupdater, + api.client.api_base_url, + api.client.token, semaphore, ), )