Skip to content

Commit

Permalink
Revamp Transport so they always build a full Response object (fastapi…
Browse files Browse the repository at this point in the history
…-users#1049)

* Revamp Transport so they always build a full Response object

* Fix linting

* Add private methods to set cookies on CookieTransport

* Change on_after_login login_return parameter to response
  • Loading branch information
frankie567 authored Apr 27, 2023
1 parent 9a2515f commit 8fd097c
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 71 deletions.
2 changes: 1 addition & 1 deletion docs/configuration/user-manager.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self,
user: User,
request: Optional[Request] = None,
login_return: Optional[Any] = None,
response: Optional[Response] = None,
):
print(f"User {user.id} logged in.")
```
Expand Down
27 changes: 11 additions & 16 deletions fastapi_users/authentication/backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Generic
from typing import Generic

from fastapi import Response
from fastapi import Response, status

from fastapi_users import models
from fastapi_users.authentication.strategy import (
Expand Down Expand Up @@ -40,27 +40,22 @@ def __init__(
self.get_strategy = get_strategy

async def login(
self,
strategy: Strategy[models.UP, models.ID],
user: models.UP,
response: Response,
) -> Any:
self, strategy: Strategy[models.UP, models.ID], user: models.UP
) -> Response:
token = await strategy.write_token(user)
return await self.transport.get_login_response(token, response)
return await self.transport.get_login_response(token)

async def logout(
self,
strategy: Strategy[models.UP, models.ID],
user: models.UP,
token: str,
response: Response,
) -> Any:
self, strategy: Strategy[models.UP, models.ID], user: models.UP, token: str
) -> Response:
try:
await strategy.destroy_token(token, user)
except StrategyDestroyNotSupportedError:
pass

try:
await self.transport.get_logout_response(response)
response = await self.transport.get_logout_response()
except TransportLogoutNotSupportedError:
return None
response = Response(status_code=status.HTTP_204_NO_CONTENT)

return response
5 changes: 2 additions & 3 deletions fastapi_users/authentication/transport/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
from typing import Any

if sys.version_info < (3, 8):
from typing_extensions import Protocol # pragma: no cover
Expand All @@ -19,10 +18,10 @@ class TransportLogoutNotSupportedError(Exception):
class Transport(Protocol):
scheme: SecurityBase

async def get_login_response(self, token: str, response: Response) -> Any:
async def get_login_response(self, token: str) -> Response:
... # pragma: no cover

async def get_logout_response(self, response: Response) -> Any:
async def get_logout_response(self) -> Response:
... # pragma: no cover

@staticmethod
Expand Down
10 changes: 5 additions & 5 deletions fastapi_users/authentication/transport/bearer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any

from fastapi import Response, status
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel

Expand All @@ -22,10 +21,11 @@ class BearerTransport(Transport):
def __init__(self, tokenUrl: str):
self.scheme = OAuth2PasswordBearer(tokenUrl, auto_error=False)

async def get_login_response(self, token: str, response: Response) -> Any:
return BearerResponse(access_token=token, token_type="bearer")
async def get_login_response(self, token: str) -> Response:
bearer_response = BearerResponse(access_token=token, token_type="bearer")
return JSONResponse(bearer_response.dict())

async def get_logout_response(self, response: Response) -> Any:
async def get_logout_response(self) -> Response:
raise TransportLogoutNotSupportedError()

@staticmethod
Expand Down
24 changes: 15 additions & 9 deletions fastapi_users/authentication/transport/cookie.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Any, Optional
from typing import Optional

if sys.version_info < (3, 8):
from typing_extensions import Literal # pragma: no cover
Expand Down Expand Up @@ -35,7 +35,15 @@ def __init__(
self.cookie_samesite = cookie_samesite
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)

async def get_login_response(self, token: str, response: Response) -> Any:
async def get_login_response(self, token: str) -> Response:
response = Response(status_code=status.HTTP_204_NO_CONTENT)
return self._set_login_cookie(response, token)

async def get_logout_response(self) -> Response:
response = Response(status_code=status.HTTP_204_NO_CONTENT)
return self._set_logout_cookie(response)

def _set_login_cookie(self, response: Response, token: str) -> Response:
response.set_cookie(
self.cookie_name,
token,
Expand All @@ -46,12 +54,9 @@ async def get_login_response(self, token: str, response: Response) -> Any:
httponly=self.cookie_httponly,
samesite=self.cookie_samesite,
)
return response

# We shouldn't return directly the response
# so that FastAPI can terminate it properly
return None

async def get_logout_response(self, response: Response) -> Any:
def _set_logout_cookie(self, response: Response) -> Response:
response.set_cookie(
self.cookie_name,
"",
Expand All @@ -62,11 +67,12 @@ async def get_logout_response(self, response: Response) -> Any:
httponly=self.cookie_httponly,
samesite=self.cookie_samesite,
)
return response

@staticmethod
def get_openapi_login_responses_success() -> OpenAPIResponseType:
return {status.HTTP_200_OK: {"model": None}}
return {status.HTTP_204_NO_CONTENT: {"model": None}}

@staticmethod
def get_openapi_logout_responses_success() -> OpenAPIResponseType:
return {status.HTTP_200_OK: {"model": None}}
return {status.HTTP_204_NO_CONTENT: {"model": None}}
8 changes: 4 additions & 4 deletions fastapi_users/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, Generic, Optional, Union

import jwt
from fastapi import Request
from fastapi import Request, Response
from fastapi.security import OAuth2PasswordRequestForm

from fastapi_users import exceptions, models, schemas
Expand Down Expand Up @@ -589,7 +589,7 @@ async def on_after_login(
self,
user: models.UP,
request: Optional[Request] = None,
login_return: Optional[Any] = None,
response: Optional[Response] = None,
) -> None:
"""
Perform logic after user login.
Expand All @@ -598,8 +598,8 @@ async def on_after_login(
:param user: The user that is logging in
:param request: Optional FastAPI request
:param login_return: Optional return of the login
triggered the operation, defaults to None.
:param response: Optional response built by the transport.
Defaults to None
"""
return # pragma: no cover

Expand Down
12 changes: 5 additions & 7 deletions fastapi_users/router/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Tuple

from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordRequestForm

from fastapi_users import models
Expand Down Expand Up @@ -50,7 +50,6 @@ def get_auth_router(
)
async def login(
request: Request,
response: Response,
credentials: OAuth2PasswordRequestForm = Depends(),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
Expand All @@ -67,9 +66,9 @@ async def login(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_USER_NOT_VERIFIED,
)
login_return = await backend.login(strategy, user, response)
await user_manager.on_after_login(user, request, login_return)
return login_return
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
return response

logout_responses: OpenAPIResponseType = {
**{
Expand All @@ -84,11 +83,10 @@ async def login(
"/logout", name=f"auth:{backend.name}.logout", responses=logout_responses
)
async def logout(
response: Response,
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
):
user, token = user_token
return await backend.logout(strategy, user, token, response)
return await backend.logout(strategy, user, token)

return router
9 changes: 4 additions & 5 deletions fastapi_users/router/oauth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, List, Optional, Tuple, Type

import jwt
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token
from pydantic import BaseModel
Expand Down Expand Up @@ -100,7 +100,6 @@ async def authorize(
)
async def callback(
request: Request,
response: Response,
access_token_state: Tuple[OAuth2Token, str] = Depends(
oauth2_authorize_callback
),
Expand Down Expand Up @@ -148,9 +147,9 @@ async def callback(
)

# Authenticate
login_return = await backend.login(strategy, user, response)
await user_manager.on_after_login(user, request, login_return)
return login_return
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
return response

return router

Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,8 @@ class MockTransport(BearerTransport):
def __init__(self, tokenUrl: str):
super().__init__(tokenUrl)

async def get_logout_response(self, response: Response) -> Any:
return None
async def get_logout_response(self) -> Any:
return Response()

@staticmethod
def get_openapi_logout_responses_success() -> OpenAPIResponseType:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_authentication_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ def backend(
@pytest.mark.authentication
async def test_logout(backend: AuthenticationBackend, user: UserModel):
strategy = cast(Strategy, backend.get_strategy())
result = await backend.logout(strategy, user, "TOKEN", Response())
assert result is None
result = await backend.logout(strategy, user, "TOKEN")
assert isinstance(result, Response)
15 changes: 6 additions & 9 deletions tests/test_authentication_transport_bearer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from fastapi import Response, status
from fastapi import status
from fastapi.responses import JSONResponse

from fastapi_users.authentication.transport import (
BearerTransport,
Expand All @@ -16,21 +17,17 @@ def bearer_transport() -> BearerTransport:
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_login_response(bearer_transport: BearerTransport):
response = Response()
login_response = await bearer_transport.get_login_response("TOKEN", response)
response = await bearer_transport.get_login_response("TOKEN")

assert isinstance(login_response, BearerResponse)

assert login_response.access_token == "TOKEN"
assert login_response.token_type == "bearer"
assert isinstance(response, JSONResponse)
assert response.body == b'{"access_token":"TOKEN","token_type":"bearer"}'


@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_logout_response(bearer_transport: BearerTransport):
response = Response()
with pytest.raises(TransportLogoutNotSupportedError):
await bearer_transport.get_logout_response(response)
await bearer_transport.get_logout_response()


@pytest.mark.authentication
Expand Down
16 changes: 8 additions & 8 deletions tests/test_authentication_transport_cookie.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ async def test_get_login_response(cookie_transport: CookieTransport):
secure = cookie_transport.cookie_secure
httponly = cookie_transport.cookie_httponly

response = Response()
login_response = await cookie_transport.get_login_response("TOKEN", response)
response = await cookie_transport.get_login_response("TOKEN")

assert login_response is None
assert isinstance(response, Response)
assert response.status_code == status.HTTP_204_NO_CONTENT

cookies = [header for header in response.raw_headers if header[0] == b"set-cookie"]
assert len(cookies) == 1
Expand Down Expand Up @@ -79,10 +79,10 @@ async def test_get_login_response(cookie_transport: CookieTransport):
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_logout_response(cookie_transport: CookieTransport):
response = Response()
logout_response = await cookie_transport.get_logout_response(response)
response = await cookie_transport.get_logout_response()

assert logout_response is None
assert isinstance(response, Response)
assert response.status_code == status.HTTP_204_NO_CONTENT

cookies = [header for header in response.raw_headers if header[0] == b"set-cookie"]
assert len(cookies) == 1
Expand All @@ -96,13 +96,13 @@ async def test_get_logout_response(cookie_transport: CookieTransport):
@pytest.mark.openapi
def test_get_openapi_login_responses_success(cookie_transport: CookieTransport):
assert cookie_transport.get_openapi_login_responses_success() == {
status.HTTP_200_OK: {"model": None}
status.HTTP_204_NO_CONTENT: {"model": None}
}


@pytest.mark.authentication
@pytest.mark.openapi
def test_get_openapi_logout_responses_success(cookie_transport: CookieTransport):
assert cookie_transport.get_openapi_logout_responses_success() == {
status.HTTP_200_OK: {"model": None}
status.HTTP_204_NO_CONTENT: {"model": None}
}

0 comments on commit 8fd097c

Please sign in to comment.