Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to the receipts and user directory handlers. #8976

Merged
merged 5 commits into from
Jan 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8976.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to the receipts and user directory handlers.
2 changes: 2 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ files =
synapse/handlers/presence.py,
synapse/handlers/profile.py,
synapse/handlers/read_marker.py,
synapse/handlers/receipts.py,
synapse/handlers/register.py,
synapse/handlers/room.py,
synapse/handlers/room_list.py,
Expand All @@ -52,6 +53,7 @@ files =
synapse/handlers/saml_handler.py,
synapse/handlers/sso.py,
synapse/handlers/sync.py,
synapse/handlers/user_directory.py,
synapse/handlers/ui_auth,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
Expand Down
30 changes: 19 additions & 11 deletions synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple

from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


class ReceiptsHandler(BaseHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.server_name = hs.config.server_name
Expand All @@ -36,7 +39,7 @@ def __init__(self, hs):
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()

async def _received_remote_receipt(self, origin, content):
async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = []
Expand All @@ -63,11 +66,11 @@ async def _received_remote_receipt(self, origin, content):

await self._handle_new_receipts(receipts)

async def _handle_new_receipts(self, receipts):
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier.
"""
min_batch_id = None
max_batch_id = None
min_batch_id = None # type: Optional[int]
max_batch_id = None # type: Optional[int]

for receipt in receipts:
res = await self.store.insert_receipt(
Expand All @@ -89,7 +92,8 @@ async def _handle_new_receipts(self, receipts):
if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id

if min_batch_id is None:
# Either both of these should be None or neither.
if min_batch_id is None or max_batch_id is None:
# no new receipts
return False

Expand All @@ -103,7 +107,9 @@ async def _handle_new_receipts(self, receipts):

return True

async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
async def received_client_receipt(
self, room_id: str, receipt_type: str, user_id: str, event_id: str
) -> None:
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
Expand All @@ -123,10 +129,12 @@ async def received_client_receipt(self, room_id, receipt_type, user_id, event_id


class ReceiptEventSource:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()

async def get_new_events(self, from_key, room_ids, **kwargs):
async def get_new_events(
self, from_key: int, room_ids: List[str], **kwargs
) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()

Expand Down Expand Up @@ -171,5 +179,5 @@ async def get_new_events_as(

return (events, to_key)

def get_current_key(self, direction="f"):
def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_receipt_stream_id()
63 changes: 42 additions & 21 deletions synapse/handlers/user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import synapse.metrics
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


Expand All @@ -36,7 +41,7 @@ class UserDirectoryHandler(StateDeltasHandler):
be in the directory or not when necessary.
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.store = hs.get_datastore()
Expand All @@ -49,7 +54,7 @@ def __init__(self, hs):
self.search_all_users = hs.config.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
self.pos = None
self.pos = None # type: Optional[int]

# Guard to ensure we only process deltas one at a time
self._is_processing = False
Expand All @@ -61,7 +66,9 @@ def __init__(self, hs):
# we start populating the user directory
self.clock.call_later(0, self.notify_new_event)

async def search_users(self, user_id, search_term, limit):
async def search_users(
self, user_id: str, search_term: str, limit: int
) -> JsonDict:
"""Searches for users in directory
Returns:
Expand Down Expand Up @@ -89,7 +96,7 @@ async def search_users(self, user_id, search_term, limit):

return results

def notify_new_event(self):
def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
if not self.update_user_directory:
Expand All @@ -107,7 +114,9 @@ async def process():
self._is_processing = True
run_as_background_process("user_directory.notify_new_event", process)

async def handle_local_profile_change(self, user_id, profile):
async def handle_local_profile_change(
self, user_id: str, profile: ProfileInfo
) -> None:
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
Expand All @@ -124,14 +133,14 @@ async def handle_local_profile_change(self, user_id, profile):
user_id, profile.display_name, profile.avatar_url
)

async def handle_user_deactivated(self, user_id):
async def handle_user_deactivated(self, user_id: str) -> None:
"""Called when a user ID is deactivated
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
await self.store.remove_from_user_dir(user_id)

async def _unsafe_process(self):
async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
Expand Down Expand Up @@ -166,7 +175,7 @@ async def _unsafe_process(self):

await self.store.update_user_directory_stream_pos(max_pos)

async def _handle_deltas(self, deltas):
async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
"""Called with the state deltas to process
"""
for delta in deltas:
Expand Down Expand Up @@ -236,16 +245,20 @@ async def _handle_deltas(self, deltas):
logger.debug("Ignoring irrelevant type: %r", typ)

async def _handle_room_publicity_change(
self, room_id, prev_event_id, event_id, typ
):
self,
room_id: str,
prev_event_id: Optional[str],
event_id: Optional[str],
typ: str,
) -> None:
"""Handle a room having potentially changed from/to world_readable/publicly
joinable.
Args:
room_id (str)
prev_event_id (str|None): The previous event before the state change
event_id (str|None): The new event after the state change
typ (str): Type of the event
room_id: The ID of the room which changed.
prev_event_id: The previous event before the state change
event_id: The new event after the state change
typ: Type of the event
"""
logger.debug("Handling change for %s: %s", typ, room_id)

Expand Down Expand Up @@ -303,12 +316,14 @@ async def _handle_room_publicity_change(
for user_id, profile in users_with_profile.items():
await self._handle_new_user(room_id, user_id, profile)

async def _handle_new_user(self, room_id, user_id, profile):
async def _handle_new_user(
self, room_id: str, user_id: str, profile: ProfileInfo
) -> None:
"""Called when we might need to add user to directory
Args:
room_id (str): room_id that user joined or started being public
user_id (str)
room_id: The room ID that user joined or started being public
user_id
"""
logger.debug("Adding new user to dir, %r", user_id)

Expand Down Expand Up @@ -356,12 +371,12 @@ async def _handle_new_user(self, room_id, user_id, profile):
if to_insert:
await self.store.add_users_who_share_private_room(room_id, to_insert)

async def _handle_remove_user(self, room_id, user_id):
async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
"""Called when we might need to remove user from directory
Args:
room_id (str): room_id that user left or stopped being public that
user_id (str)
room_id: The room ID that user left or stopped being public that
user_id
"""
logger.debug("Removing user %r", user_id)

Expand All @@ -374,7 +389,13 @@ async def _handle_remove_user(self, room_id, user_id):
if len(rooms_user_is_in) == 0:
await self.store.remove_from_user_dir(user_id)

async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
async def _handle_profile_change(
self,
user_id: str,
room_id: str,
prev_event_id: Optional[str],
event_id: Optional[str],
) -> None:
"""Check member event changes for any profile changes and update the
database if there are.
"""
Expand Down