Skip to content

Commit

Permalink
chore: refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Teqed committed Sep 18, 2023
1 parent 8d73b08 commit 83d4998
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 76 deletions.
50 changes: 42 additions & 8 deletions fedifetcher/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ def __init__(
"""Initialize the API."""
raise NotImplementedError

async def cleanup(self) -> None:
    """Close the session held by this API's client."""
    session = self.client.session
    await session.close()

async def __aenter__(self):
    """Enter the async context manager and return this instance unchanged."""
    return self

async def __aexit__(self, *args) -> None:
    """Exit the async context manager by awaiting ``self.cleanup()``.

    The exception details passed in ``args`` are accepted but not inspected;
    the method returns ``None``, so an exception raised inside the ``async with``
    body is not suppressed.
    """
    await self.cleanup()

@abstractmethod
def get(self, uri: str) -> Coroutine[Any, Any, dict]:
"""Get an object by URI."""
Expand Down Expand Up @@ -169,7 +178,11 @@ class FederationInterface:

_equipped_api: API

def __init__(self, url: str) -> None:
def __init__(self,
domain: str,
token: str | None = None,
pgupdater: Any | None = None,
) -> None:
"""Initialize the API."""
def _normalize_url(_url: str) -> str:
"""Normalize a URL."""
Expand All @@ -181,48 +194,52 @@ def _normalize_url(_url: str) -> str:
if _url.count("/") > slashes_in_protocol_prefix:
_url = _url[:_url.find("/", 8)]
return _url
url = _normalize_url(url)
domain = _normalize_url(domain)
temporary_client_session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=60),
headers={
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 +https://github.com/Teqed Meowstodon/1.0.0", # noqa: E501
},
)
temporary_client = HttpMethod(
url,
domain,
temporary_client_session,
)
nodeinfo = asyncio.run(temporary_client.get("/nodeinfo/2.0"))
if not nodeinfo:
wellknown_nodeinfo = asyncio.run(temporary_client.get("/.well-known/nodeinfo"))
if not wellknown_nodeinfo:
raise NotImplementedError
url = _normalize_url(wellknown_nodeinfo["links"][0]["href"])
domain = _normalize_url(wellknown_nodeinfo["links"][0]["href"])
temporary_client = HttpMethod(
url,
domain,
temporary_client_session,
)
nodeinfo = asyncio.run(temporary_client.get("/nodeinfo/2.0"))
if not nodeinfo:
wellknown_hostmeta = asyncio.run(temporary_client.get("/.well-known/host-meta"))
if wellknown_hostmeta and wellknown_hostmeta["server"]:
url = _normalize_url(wellknown_hostmeta["server"])
domain = _normalize_url(wellknown_hostmeta["server"])
temporary_client = HttpMethod(
url,
domain,
temporary_client_session,
)
nodeinfo = asyncio.run(temporary_client.get("/nodeinfo/2.0"))
software_name = nodeinfo["software"]["name"] if nodeinfo else None
if software_name == "mastodon":
from fedifetcher.api.mastodon import Mastodon
equippable_api = Mastodon(url)
equippable_api = Mastodon(domain, token, pgupdater)
elif software_name == "misskey":
raise NotImplementedError
else:
msg = f"Unknown software name: {software_name}"
raise NotImplementedError(msg)
self._equipped_api: API = equippable_api

async def cleanup(self) -> None:
    """Clean up by awaiting the equipped API's ``cleanup()``."""
    await self._equipped_api.cleanup()

def get(self, uri: str) -> Coroutine[Any, Any, dict]:
    """Get an object by URI.

    Delegates to the equipped API's ``get`` and returns its result as-is;
    nothing is awaited here, so the caller awaits the returned coroutine.
    """
    return self._equipped_api.get(uri)
Expand Down Expand Up @@ -340,3 +357,20 @@ def get_home_timeline(
"""Get the home timeline."""
return self._equipped_api.get_home_timeline(limit)

class FederationInterfaceManager:
    """Manage federation interfaces.

    Keeps one ``FederationInterface`` per requested domain string so repeated
    lookups for the same domain reuse the same object.
    """

    def __init__(self) -> None:
        """Initialize the manager with an empty interface cache."""
        self._interfaces: dict[str, FederationInterface] = {}

    def get_interface(self, domain: str) -> FederationInterface:
        """Get the interface for ``domain``, creating and caching it on first use.

        New interfaces are constructed from the domain alone (no token or
        PostgreSQL updater is passed). The cache key is the ``domain`` string
        exactly as given by the caller.
        """
        if domain not in self._interfaces:
            self._interfaces[domain] = FederationInterface(domain)
        return self._interfaces[domain]

    async def cleanup(self) -> None:
        """Clean up every cached interface and empty the cache.

        Every interface gets a cleanup attempt even if an earlier one fails,
        so a single failing interface cannot leak the others. The cache is
        emptied first so an already-closed interface is never handed out
        again by ``get_interface``.

        Raises:
        ------
        Exception: the first error raised by an interface's ``cleanup()``,
            re-raised after all interfaces have been attempted.
        """
        # Snapshot and clear up front: the cache must not keep closed interfaces.
        interfaces = list(self._interfaces.values())
        self._interfaces.clear()
        first_error: Exception | None = None
        for interface in interfaces:
            try:
                await interface.cleanup()
            except Exception as error:  # noqa: BLE001 - re-raised once all are attempted
                if first_error is None:
                    first_error = error
        if first_error is not None:
            raise first_error
52 changes: 20 additions & 32 deletions fedifetcher/find_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from typing import TYPE_CHECKING

from fedifetcher import getter_wrappers, parsers
from fedifetcher.api.api import ApiError
from fedifetcher.api.mastodon import api_mastodon
from fedifetcher.api.api import API, ApiError
from fedifetcher.api.postgresql import PostgreSQLUpdater

if TYPE_CHECKING:
Expand All @@ -16,10 +15,11 @@

async def add_post_with_context(
post: dict[str, str],
home_server: str,
access_token: str,
api: API,
# home_server: str,
# access_token: str,
external_tokens: dict[str, str],
pgupdater: PostgreSQLUpdater,
# pgupdater: PostgreSQLUpdater,
arguments: Namespace,
) -> bool:
"""Add the given post to the server.
Expand All @@ -38,11 +38,7 @@ async def add_post_with_context(
bool: True if the post was added successfully, False otherwise.
"""
try:
await api_mastodon.Mastodon(
home_server,
access_token,
pgupdater,
).get(post["url"])
await api.get(post["url"])
if ("replies_count" in post or "in_reply_to_id" in post) and getattr(
arguments,
"backfill_with_context",
Expand All @@ -51,42 +47,37 @@ async def add_post_with_context(
parsed_urls: dict[str, tuple[str | None, str | None]] = {}
parsed = parsers.post(post["url"], parsed_urls)
if parsed is not None and parsed[0] is not None:
if api.client.pgupdater is None or api.client.token is None:
msg = "No PostgreSQL updater, or no token"
raise ApiError(msg)
known_context_urls = await getter_wrappers.get_all_known_context_urls(
home_server,
api,
[post],
parsed_urls,
external_tokens,
pgupdater,
access_token,
)
(
await add_context_urls_wrapper(
home_server,
access_token,
api,
known_context_urls,
pgupdater,
)
)
return True
except ApiError:
logging.debug(f"Failed to add {post['url']} to {home_server}")
logging.debug(f"Failed to add {post['url']} to {api.client.api_base_url}")
return False


async def add_context_urls_wrapper(
home_server: str,
access_token: str,
api: API,
context_urls: Iterable[str],
pgupdater: PostgreSQLUpdater,
) -> None:
"""Add the given toot URLs to the server.
Args:
----
home_server: The server to add the toots to.
access_token: The access token to use to add the toots.
api: The API to use to add the statuses.
context_urls: The list of toot URLs to add.
pgupdater: The PostgreSQL updater.
"""
list_of_context_urls = list(context_urls)
if len(list_of_context_urls) == 0:
Expand All @@ -96,7 +87,10 @@ async def add_context_urls_wrapper(
failed = 0
already_added = 0
posts_to_fetch = []
cached_posts: dict[str, Status | None] = pgupdater.get_dict_from_cache(
if api.client.pgupdater is None:
msg = "No PostgreSQL updater"
raise ApiError(msg)
cached_posts: dict[str, Status | None] = api.client.pgupdater.get_dict_from_cache(
list_of_context_urls,
)
logging.debug(f"Got {len(cached_posts)} cached posts")
Expand All @@ -116,24 +110,18 @@ async def add_context_urls_wrapper(

if posts_to_fetch:
logging.debug(f"Fetching {len(posts_to_fetch)} posts")
max_concurrent_tasks = 10
semaphore = asyncio.Semaphore(max_concurrent_tasks)
tasks = []
for url in posts_to_fetch:
logging.debug(f"Adding {url} to home server")
tasks.append(
api_mastodon.Mastodon(
home_server,
access_token,
pgupdater,
).get(url, semaphore),
api.get(url),
)
for task in asyncio.as_completed(tasks):
try:
result: Status = await task
logging.debug(f"Got {result}")
count += 1
pgupdater.cache_status(result)
api.client.pgupdater.cache_status(result)
except ApiError:
failed += 1

Expand Down
60 changes: 31 additions & 29 deletions fedifetcher/find_trending_posts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Callable

from fedifetcher import parsers
from fedifetcher.api.api import API
from fedifetcher.api.mastodon import api_mastodon
from fedifetcher.api.postgresql import PostgreSQLUpdater

Expand Down Expand Up @@ -144,11 +145,9 @@ def get_domains_to_fetch(self) -> list[str]:


async def find_trending_posts( # noqa: C901
home_server: str,
home_token: str,
api: API,
external_feeds: list[str],
external_tokens: dict[str, str],
pgupdater: PostgreSQLUpdater,
) -> list[dict[str, str]]:
"""Pull trending posts from a list of Mastodon servers, using tokens."""
msg = f"Finding trending posts from {len(external_feeds)} domains:"
Expand Down Expand Up @@ -220,51 +219,54 @@ async def find_trending_posts( # noqa: C901
)
logging.info(f"\033[1;34m{msg}\033[0m")
max_concurrent_requests = 10
semaphore = asyncio.Semaphore(max_concurrent_requests)
promises_container = {}
promises = []
for status_id in remember_to_find_me[fetch_domain]:
if (
str(status_id) not in trending_posts_dict
or "original" not in trending_posts_dict[str(status_id)]
):
promise = asyncio.ensure_future(
api_mastodon.Mastodon(
async with asyncio.Semaphore(max_concurrent_requests) as semaphore:
for status_id in remember_to_find_me[fetch_domain]:
if (
str(status_id) not in trending_posts_dict
or "original" not in trending_posts_dict[str(status_id)]
):
async with api_mastodon.Mastodon(
fetch_domain,
external_tokens.get(fetch_domain),
).get_status_by_id(status_id, semaphore),
)
promises_container[status_id] = promise
promises.append(promise)
await asyncio.gather(*promises)
for _status_id, future in promises_container.items():
original_post = future.result()
if original_post:
add_post_to_dict(original_post, fetch_domain, trending_posts_dict)
else:
logging.warning(f"Couldn't find {_status_id} from {fetch_domain}")
) as api:
promise = asyncio.create_task(api.get_status_by_id(status_id, semaphore))
promises_container[status_id] = promise
promises.append(promise)
await asyncio.gather(*promises)
for _status_id, future in promises_container.items():
original_post = future.result()
if original_post:
add_post_to_dict(original_post, fetch_domain, trending_posts_dict)
else:
logging.warning(f"Couldn't find {_status_id} from {fetch_domain}")

logging.info(f"Fetching aux posts from {len(trending_posts_dict.keys())} domains")
await aux_domain_fetcher.do_aux_fetches(trending_posts_dict, pgupdater)
if api.client.token is None or api.client.pgupdater is None:
msg = "No token or PostgreSQL updater"
raise api_mastodon.ApiError(msg)

await aux_domain_fetcher.do_aux_fetches(trending_posts_dict, api.client.pgupdater)

updated_trending_posts_dict = await update_local_status_ids(
trending_posts_dict,
home_server,
home_token,
pgupdater,
api.client.api_base_url,
api.client.token,
api.client.pgupdater,
)

"""Update the status stats with the trending posts."""
if pgupdater:
if api.client.pgupdater is not None:
for post in updated_trending_posts_dict.values():
local_status_id = post.get("local_status_id")
if local_status_id:
pgupdater.queue_status_update(
api.client.pgupdater.queue_status_update(
local_status_id,
int(post["reblogs_count"]),
int(post["favourites_count"]),
)
pgupdater.commit_status_updates()
api.client.pgupdater.commit_status_updates()

return list(
api_mastodon.filter_language(updated_trending_posts_dict.values(), "en"),
Expand Down
17 changes: 10 additions & 7 deletions fedifetcher/getter_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, cast

import requests
from fedifetcher.api.api import API

from fedifetcher.api.mastodon import api_mastodon
from fedifetcher.api.postgresql import PostgreSQLUpdater
Expand Down Expand Up @@ -45,7 +46,8 @@ async def get_notification_users(
since = datetime.now(datetime.now(UTC).astimezone().tzinfo) - timedelta(
hours=max_age,
)
notifications = await api_mastodon.Mastodon(server, access_token).get_notifications(
async with api_mastodon.Mastodon(server, access_token) as api:
notifications = await api.get_notifications(
since_id=str(int(since.timestamp())),
)
notification_users = []
Expand Down Expand Up @@ -233,12 +235,10 @@ async def get_all_reply_toots(


async def get_all_known_context_urls(
home_server: str,
api: API,
reply_toots: list[dict[str, str]],
parsed_urls: dict[str, tuple[str | None, str | None]],
external_tokens: dict[str, str],
pgupdater: PostgreSQLUpdater,
home_server_token: str,
) -> Iterable[str]:
known_context_urls = []
if reply_toots is not None:
Expand All @@ -255,16 +255,19 @@ async def get_all_known_context_urls(
logging.debug(f"Getting context for {len(toots_to_get_context_for)} toots")
max_concurrent_tasks = 10
semaphore = asyncio.Semaphore(max_concurrent_tasks)
if api.client.pgupdater is None or api.client.token is None:
msg = "No PostgreSQL updater, or no token"
raise api_mastodon.ApiError(msg)
tasks = [
asyncio.ensure_future(
post_content(
x[0][0],
x[0][1],
x[1],
external_tokens,
pgupdater,
home_server,
home_server_token,
api.client.pgupdater,
api.client.api_base_url,
api.client.token,
semaphore,
),
)
Expand Down

0 comments on commit 83d4998

Please sign in to comment.