Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add types to http.site #10867

Merged
merged 3 commits into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10867.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `synapse.http.site`.
38 changes: 21 additions & 17 deletions synapse/http/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
from twisted.web.resource import IResource
from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site

from synapse.config.server import ListenerConfig
Expand Down Expand Up @@ -61,7 +61,7 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""

def __init__(self, channel, *args, max_request_body_size=1024, **kw):
def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add one for channel while we're here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OMG, channel is the worst I hate it. It's actually of type _GenericHTTPChannelProtocol, which isn't exposed anywhere

Request.__init__(self, channel, *args, **kw)
self._max_request_body_size = max_request_body_size
self.site: SynapseSite = channel.site
Expand All @@ -83,13 +83,13 @@ def __init__(self, channel, *args, max_request_body_size=1024, **kw):
self._is_processing = False

# the time when the asynchronous request handler completed its processing
self._processing_finished_time = None
self._processing_finished_time: Optional[float] = None

# what time we finished sending the response to the client (or the connection
# dropped)
self.finish_time = None
self.finish_time: Optional[float] = None

def __repr__(self):
def __repr__(self) -> str:
# We overwrite this so that we don't log ``access_token``
return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
Expand All @@ -100,7 +100,7 @@ def __repr__(self):
self.site.site_tag,
)

def handleContentChunk(self, data):
def handleContentChunk(self, data: bytes) -> None:
# we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size:
Expand Down Expand Up @@ -139,7 +139,7 @@ def requester(self, value: Union[Requester, str]) -> None:
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester

def get_request_id(self):
def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)

def get_redacted_uri(self) -> str:
Expand Down Expand Up @@ -205,7 +205,7 @@ def get_authenticated_entity(self) -> Tuple[Optional[str], Optional[str]]:

return None, None

def render(self, resrc):
def render(self, resrc: Resource) -> None:
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.

Expand Down Expand Up @@ -282,7 +282,7 @@ async def handle_request(request):
if self.finish_time is not None:
self._finished_processing()

def finish(self):
def finish(self) -> None:
"""Called when all response data has been written to this Request.

Overrides twisted.web.server.Request.finish to record the finish time and do
Expand All @@ -295,7 +295,7 @@ def finish(self):
with PreserveLoggingContext(self.logcontext):
self._finished_processing()

def connectionLost(self, reason):
def connectionLost(self, reason: Union[Failure, Exception]) -> None:
"""Called when the client connection is closed before the response is written.

Overrides twisted.web.server.Request.connectionLost to record the finish time and
Expand Down Expand Up @@ -327,7 +327,7 @@ def connectionLost(self, reason):
if not self._is_processing:
self._finished_processing()

def _started_processing(self, servlet_name):
def _started_processing(self, servlet_name: str) -> None:
"""Record the fact that we are processing this request.

This will log the request's arrival. Once the request completes,
Expand All @@ -354,9 +354,11 @@ def _started_processing(self, servlet_name):
self.get_redacted_uri(),
)

def _finished_processing(self):
def _finished_processing(self) -> None:
"""Log the completion of this request and update the metrics"""
assert self.logcontext is not None
assert self.finish_time is not None

usage = self.logcontext.get_resource_usage()

if self._processing_finished_time is None:
Expand Down Expand Up @@ -437,15 +439,15 @@ class XForwardedForRequest(SynapseRequest):
_forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https: bool = False

def requestReceived(self, command, path, version):
def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
# this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the
# headers.
self._process_forwarded_headers()
return super().requestReceived(command, path, version)

def _process_forwarded_headers(self):
def _process_forwarded_headers(self) -> None:
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers:
return
Expand All @@ -470,7 +472,7 @@ def _process_forwarded_headers(self):
)
self._forwarded_https = True

def isSecure(self):
def isSecure(self) -> bool:
if self._forwarded_https:
return True
return super().isSecure()
Expand Down Expand Up @@ -547,12 +549,14 @@ def __init__(

def request_factory(channel, queued) -> Request:
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
return request_class(
channel, max_request_body_size=max_request_body_size, queued=queued
channel,
max_request_body_size=max_request_body_size,
queued=queued,
)

self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")

def log(self, request):
def log(self, request: SynapseRequest) -> None:
pass