Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to synapse.api. #11109

Merged
merged 10 commits into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11109.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to `synapse.api` module.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

[mypy-synapse.api.*]
disallow_untyped_defs = True

[mypy-synapse.events.*]
disallow_untyped_defs = True

Expand Down
14 changes: 11 additions & 3 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ async def get_user_by_req(

async def validate_appservice_can_control_user_id(
self, app_service: ApplicationService, user_id: str
):
) -> None:
"""Validates that the app service is allowed to control
the given user.

Expand Down Expand Up @@ -618,5 +618,13 @@ async def check_user_in_room_or_world_readable(
% (user_id, room_id),
)

async def check_auth_blocking(self, *args, **kwargs) -> None:
await self._auth_blocking.check_auth_blocking(*args, **kwargs)
async def check_auth_blocking(
self,
user_id: Optional[str] = None,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
) -> None:
await self._auth_blocking.check_auth_blocking(
user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
)
Comment on lines +621 to +630
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: looks like these are lifted from AuthBlocking.check_auth_blocking

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they are. It is annoying to duplicate them, but couldn't come up with a better way of having the type hints be correct.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess ParamSpec, but that's not yet supported by mypy. I don't think this is the worst duplication though---at least mypy will spot any incompatibilities here!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That could help, I really want to be able to say "this just proxies through and using the same types as over there". But I don't really expect the signature of this method to change often so I think it is OK as is.

69 changes: 23 additions & 46 deletions synapse/api/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import typing
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from twisted.web import http

Expand Down Expand Up @@ -143,7 +143,7 @@ def __init__(self, code: int, msg: str, errcode: str = Codes.UNKNOWN):
super().__init__(code, msg)
self.errcode = errcode

def error_dict(self):
def error_dict(self) -> "JsonDict":
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
return cs_error(self.msg, self.errcode)


Expand Down Expand Up @@ -175,7 +175,7 @@ def __init__(
else:
self._additional_fields = dict(additional_fields)

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, **self._additional_fields)


Expand All @@ -196,7 +196,7 @@ def __init__(self, msg: str, consent_uri: str):
)
self._consent_uri = consent_uri

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)


Expand Down Expand Up @@ -262,14 +262,10 @@ def __init__(self, session_id: str, result: "JsonDict"):
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""

def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
if len(args) == 0:
message = "Unrecognized request"
else:
message = args[0]
super().__init__(400, message, **kwargs)
def __init__(
self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED
):
super().__init__(400, msg, errcode)


class NotFoundError(SynapseError):
Expand All @@ -284,10 +280,8 @@ class AuthError(SynapseError):
other poorly-defined times.
"""

def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN
super().__init__(*args, **kwargs)
def __init__(self, code: int, msg: str, errcode: str = Codes.FORBIDDEN):
super().__init__(code, msg, errcode)


class InvalidClientCredentialsError(SynapseError):
Expand Down Expand Up @@ -321,7 +315,7 @@ def __init__(
super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
self._soft_logout = soft_logout

def error_dict(self):
def error_dict(self) -> "JsonDict":
d = super().error_dict()
d["soft_logout"] = self._soft_logout
return d
Expand All @@ -345,7 +339,7 @@ def __init__(
self.limit_type = limit_type
super().__init__(code, msg, errcode=errcode)

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(
self.msg,
self.errcode,
Expand All @@ -357,32 +351,17 @@ def error_dict(self):
class EventSizeError(SynapseError):
"""An error raised when an event is too big."""

def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.TOO_LARGE
super().__init__(413, *args, **kwargs)


class EventStreamError(SynapseError):
"""An error raised when there a problem with the event stream."""

def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.BAD_PAGINATION
super().__init__(*args, **kwargs)
def __init__(self, msg: str):
super().__init__(413, msg, Codes.TOO_LARGE)


class LoginError(SynapseError):
"""An error raised when there was a problem logging in."""

pass


class StoreError(SynapseError):
"""An error raised when there was a problem storing some data."""

pass


class InvalidCaptchaError(SynapseError):
def __init__(
Expand All @@ -395,7 +374,7 @@ def __init__(
super().__init__(code, msg, errcode)
self.error_url = error_url

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, error_url=self.error_url)


Expand All @@ -412,7 +391,7 @@ def __init__(
super().__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)


Expand Down Expand Up @@ -443,10 +422,8 @@ def __init__(self, msg: str = "Homeserver does not support this room version"):
class ThreepidValidationError(SynapseError):
"""An error raised when there was a problem authorising an event."""

def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN
super().__init__(*args, **kwargs)
def __init__(self, msg: str, errcode: str = Codes.FORBIDDEN):
super().__init__(400, msg, errcode)


class IncompatibleRoomVersionError(SynapseError):
Expand All @@ -466,7 +443,7 @@ def __init__(self, room_version: str):

self._room_version = room_version

def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, room_version=self._room_version)


Expand Down Expand Up @@ -494,7 +471,7 @@ class RequestSendFailed(RuntimeError):
errors (like programming errors).
"""

def __init__(self, inner_exception, can_retry):
def __init__(self, inner_exception: BaseException, can_retry: bool):
super().__init__(
"Failed to send request: %s: %s"
% (type(inner_exception).__name__, inner_exception)
Expand All @@ -503,7 +480,7 @@ def __init__(self, inner_exception, can_retry):
self.can_retry = can_retry


def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
"""Utility method for constructing an error response for client-server
interactions.

Expand Down Expand Up @@ -551,7 +528,7 @@ def __init__(
msg = "%s %s: %s" % (level, code, reason)
super().__init__(msg)

def get_dict(self):
def get_dict(self) -> "JsonDict":
return {
"level": self.level,
"code": self.code,
Expand Down Expand Up @@ -580,7 +557,7 @@ def __init__(self, code: int, msg: str, response: bytes):
super().__init__(code, msg)
self.response = response

def to_synapse_error(self):
def to_synapse_error(self) -> SynapseError:
"""Make a SynapseError based on an HTTPResponseException

This is useful when a proxied request has failed, and we need to
Expand Down
18 changes: 9 additions & 9 deletions synapse/api/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,24 +231,24 @@ def lazy_load_members(self) -> bool:
def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members()

def filter_presence(self, events):
def filter_presence(
self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
return self._presence_filter.filter(events)

def filter_account_data(self, events):
def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._account_data.filter(events)

def filter_room_state(self, events):
def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_state_filter.filter(self._room_filter.filter(events))

def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_timeline_filter.filter(self._room_filter.filter(events))

def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))

def filter_room_account_data(
self, events: Iterable[FilterEvent]
) -> List[FilterEvent]:
def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_account_data.filter(self._room_filter.filter(events))

def blocks_all_presence(self) -> bool:
Expand Down Expand Up @@ -309,7 +309,7 @@ def check(self, event: FilterEvent) -> bool:
# except for presence which actually gets passed around as its own
# namedtuple type.
if isinstance(event, UserPresenceState):
sender = event.user_id
sender: Optional[str] = event.user_id
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
room_id = None
ev_type = "m.presence"
contains_url = False
Expand Down
51 changes: 25 additions & 26 deletions synapse/api/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,49 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
from typing import Any, Optional

import attr

from synapse.api.constants import PresenceState
from synapse.types import JsonDict


class UserPresenceState(
namedtuple(
"UserPresenceState",
(
"user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
),
)
):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class UserPresenceState:
"""Represents the current presence state of the user.

user_id (str)
last_active (int): Time in msec that the user last interacted with server.
last_federation_update (int): Time in msec since either a) we sent a presence
user_id
last_active: Time in msec that the user last interacted with server.
last_federation_update: Time in msec since either a) we sent a presence
update to other servers or b) we received a presence update, depending
on if is a local user or not.
last_user_sync (int): Time in msec that the user last *completed* a sync
last_user_sync: Time in msec that the user last *completed* a sync
(or event stream).
status_msg (str): User set status message.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
status_msg: User set status message.
"""

def as_dict(self):
return dict(self._asdict())
user_id: str
state: str
last_active_ts: int
last_federation_update_ts: int
last_user_sync_ts: int
status_msg: Optional[str]
currently_active: bool

def as_dict(self) -> JsonDict:
return attr.asdict(self)

@staticmethod
def from_dict(d):
def from_dict(d: JsonDict) -> "UserPresenceState":
return UserPresenceState(**d)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

def copy_and_replace(self, **kwargs):
return self._replace(**kwargs)
def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState":
return attr.evolve(self, **kwargs)

@classmethod
def default(cls, user_id):
def default(cls, user_id: str) -> "UserPresenceState":
"""Returns a default presence state."""
return cls(
user_id=user_id,
Expand Down
4 changes: 2 additions & 2 deletions synapse/api/ratelimiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ async def can_do_action(

return allowed, time_allowed

def _prune_message_counts(self, time_now_s: float):
def _prune_message_counts(self, time_now_s: float) -> None:
"""Remove message count entries that have not exceeded their defined
rate_hz limit

Expand Down Expand Up @@ -190,7 +190,7 @@ async def ratelimit(
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
):
) -> None:
"""Checks if an action can be performed. If not, raises a LimitExceededError

Checks if the user has ratelimiting disabled in the database by looking
Expand Down
Loading