Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Do not convert async functions to Deferreds in the interactive_auth_handler #7944

Merged
merged 2 commits into from
Jul 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7944.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert the interactive_auth_handler wrapper to async/await.
47 changes: 21 additions & 26 deletions synapse/rest/client/v2_alpha/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
"""
import logging
import re

from twisted.internet import defer
from typing import Iterable, Pattern

from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
Expand All @@ -27,15 +26,23 @@
logger = logging.getLogger(__name__)


def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
def client_patterns(
path_regex: str,
releases: Iterable[int] = (0,),
unstable: bool = True,
v1: bool = False,
) -> Iterable[Pattern]:
"""Creates a regex compiled client path with the correct client path
prefix.

Args:
path_regex (str): The regex string to match. This should NOT have a ^
path_regex: The regex string to match. This should NOT have a ^
as this will be prefixed.
releases: An iterable of releases to include this endpoint under.
unstable: If true, include this endpoint under the "unstable" prefix.
v1: If true, include this endpoint under the "api/v1" prefix.
Returns:
SRE_Pattern
An iterable of patterns.
"""
patterns = []

Expand Down Expand Up @@ -73,34 +80,22 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors

Takes a on_POST method which returns a deferred (errcode, body) response
Takes a on_POST method which returns an Awaitable (errcode, body) response
and adds exception handling to turn a InteractiveAuthIncompleteError into
a 401 response.

Normal usage is:

@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
async def on_POST(self, request):
# ...
yield self.auth_handler.check_auth
"""
await self.auth_handler.check_auth
"""

def wrapped(*args, **kwargs):
res = defer.ensureDeferred(orig(*args, **kwargs))
res.addErrback(_catch_incomplete_interactive_auth)
return res
async def wrapped(*args, **kwargs):
try:
return await orig(*args, **kwargs)
except InteractiveAuthIncompleteError as e:
return 401, e.result

return wrapped


def _catch_incomplete_interactive_auth(f):
"""helper for interactive_auth_handler

Catches InteractiveAuthIncompleteErrors and turns them into 401 responses

Args:
f (failure.Failure):
"""
f.trap(InteractiveAuthIncompleteError)
return 401, f.value.result