Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type-hints to profile and base handlers. #8609

Merged
merged 4 commits into from
Oct 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8609.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to profile and base handler.
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ files =
synapse/events/builder.py,
synapse/events/spamcheck.py,
synapse/federation,
synapse/handlers/appservice.py,
synapse/handlers/_base.py,
synapse/handlers/account_data.py,
synapse/handlers/appservice.py,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
synapse/handlers/deactivate_account.py,
Expand All @@ -32,6 +33,7 @@ files =
synapse/handlers/pagination.py,
synapse/handlers/password_policy.py,
synapse/handlers/presence.py,
synapse/handlers/profile.py,
synapse/handlers/read_marker.py,
synapse/handlers/room.py,
synapse/handlers/room_member.py,
Expand Down
20 changes: 10 additions & 10 deletions synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Optional

import synapse.state
import synapse.storage
Expand All @@ -22,6 +23,9 @@
from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


Expand All @@ -30,11 +34,7 @@ class BaseHandler:
Common base class for the event handlers.
"""

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
Expand All @@ -56,7 +56,7 @@ def __init__(self, hs):
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
)
) # type: Optional[Ratelimiter]
else:
self.admin_redaction_ratelimiter = None

Expand Down Expand Up @@ -127,15 +127,15 @@ async def maybe_kick_guest_users(self, event, context=None):
if guest_access != "can_join":
if context:
current_state_ids = await context.get_current_state_ids()
current_state = await self.store.get_events(
current_state_dict = await self.store.get_events(
list(current_state_ids.values())
)
current_state = list(current_state_dict.values())
else:
current_state = await self.state_handler.get_current_state(
current_state_map = await self.state_handler.get_current_state(
event.room_id
)

current_state = list(current_state.values())
current_state = list(current_state_map.values())

logger.info("maybe_kick_guest_users %r", current_state)
await self.kick_guest_users(current_state)
Expand Down
8 changes: 6 additions & 2 deletions synapse/handlers/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ async def room_initial_sync(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
# The member_event_id will always be available if membership is set
# to leave.
assert member_event_id

result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
Expand All @@ -315,7 +319,7 @@ async def _room_initial_sync_parted(
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
membership: Membership,
membership: str,
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
Expand Down Expand Up @@ -367,7 +371,7 @@ async def _room_initial_sync_joined(
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
membership: Membership,
membership: str,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)
Expand Down
74 changes: 49 additions & 25 deletions synapse/handlers/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import (
AuthError,
Expand All @@ -25,10 +25,19 @@
SynapseError,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID, create_requester, get_domain_from_id
from synapse.types import (
JsonDict,
Requester,
UserID,
create_requester,
get_domain_from_id,
)

from ._base import BaseHandler

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

MAX_DISPLAYNAME_LEN = 256
Expand All @@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.federation = hs.get_federation_client()
Expand All @@ -60,7 +69,7 @@ def __init__(self, hs):
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)

async def get_profile(self, user_id):
async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)

if self.hs.is_mine(target_user):
Expand Down Expand Up @@ -91,7 +100,7 @@ async def get_profile(self, user_id):
except HttpResponseException as e:
raise e.to_synapse_error()

async def get_profile_from_cache(self, user_id):
async def get_profile_from_cache(self, user_id: str) -> JsonDict:
"""Get the profile information from our local cache. If the user is
ours then the profile information will always be corect. Otherwise,
it may be out of date/missing.
Expand All @@ -115,7 +124,7 @@ async def get_profile_from_cache(self, user_id):
profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {}

async def get_displayname(self, target_user):
async def get_displayname(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user):
try:
displayname = await self.store.get_profile_displayname(
Expand Down Expand Up @@ -143,15 +152,19 @@ async def get_displayname(self, target_user):
return result["displayname"]

async def set_displayname(
self, target_user, requester, new_displayname, by_admin=False
):
self,
target_user: UserID,
requester: Requester,
new_displayname: str,
by_admin: bool = False,
) -> None:
"""Set the displayname of a user

Args:
target_user (UserID): the user whose displayname is to be changed.
requester (Requester): The user attempting to make this change.
new_displayname (str): The displayname to give this user.
by_admin (bool): Whether this change was made by an administrator.
target_user: the user whose displayname is to be changed.
requester: The user attempting to make this change.
new_displayname: The displayname to give this user.
by_admin: Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
Expand All @@ -176,16 +189,19 @@ async def set_displayname(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)

displayname_to_set = new_displayname # type: Optional[str]
if new_displayname == "":
new_displayname = None
displayname_to_set = None

# If the admin changes the display name of a user, the requesting user cannot send
# the join event to update the displayname in the rooms.
# This must be done by the target user himself.
if by_admin:
requester = create_requester(target_user)

await self.store.set_profile_displayname(target_user.localpart, new_displayname)
await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set
)

if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
Expand All @@ -195,7 +211,7 @@ async def set_displayname(

await self._update_join_states(requester, target_user)

async def get_avatar_url(self, target_user):
async def get_avatar_url(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user):
try:
avatar_url = await self.store.get_profile_avatar_url(
Expand All @@ -222,15 +238,19 @@ async def get_avatar_url(self, target_user):
return result["avatar_url"]

async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False
self,
target_user: UserID,
requester: Requester,
new_avatar_url: str,
by_admin: bool = False,
):
"""Set a new avatar URL for a user.

Args:
target_user (UserID): the user whose avatar URL is to be changed.
requester (Requester): The user attempting to make this change.
new_avatar_url (str): The avatar URL to give this user.
by_admin (bool): Whether this change was made by an administrator.
target_user: the user whose avatar URL is to be changed.
requester: The user attempting to make this change.
new_avatar_url: The avatar URL to give this user.
by_admin: Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
Expand Down Expand Up @@ -267,7 +287,7 @@ async def set_avatar_url(

await self._update_join_states(requester, target_user)

async def on_profile_query(self, args):
async def on_profile_query(self, args: JsonDict) -> JsonDict:
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
Expand All @@ -292,7 +312,9 @@ async def on_profile_query(self, args):

return response

async def _update_join_states(self, requester, target_user):
async def _update_join_states(
self, requester: Requester, target_user: UserID
) -> None:
if not self.hs.is_mine(target_user):
return

Expand Down Expand Up @@ -323,15 +345,17 @@ async def _update_join_states(self, requester, target_user):
"Failed to update join event for room %s - %s", room_id, str(e)
)

async def check_profile_query_allowed(self, target_user, requester=None):
async def check_profile_query_allowed(
self, target_user: UserID, requester: Optional[UserID] = None
) -> None:
"""Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
share a room.

Args:
target_user (UserID): The owner of the queried profile.
requester (None|UserID): The user querying for the profile.
target_user: The owner of the queried profile.
requester: The user querying for the profile.

Raises:
SynapseError(403): The two users share no room, or ne user couldn't
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
Expand Down Expand Up @@ -72,7 +72,7 @@ async def create_profile(self, user_localpart: str) -> None:
)

async def set_profile_displayname(
self, user_localpart: str, new_displayname: str
self, user_localpart: str, new_displayname: Optional[str]
) -> None:
await self.db_pool.simple_update_one(
table="profiles",
Expand Down Expand Up @@ -144,7 +144,7 @@ async def is_subscribed_remote_profile_for_user(self, user_id):

async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> Dict[str, str]:
) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`
"""

Expand Down