Skip to content

Commit

Permalink
refactor: Add types to some tests (#2769)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>
  • Loading branch information
marcuslimdw and provinzkraut authored Sep 15, 2024
1 parent f87b106 commit 370f39f
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

if TYPE_CHECKING:
from litestar.connection import Request
from litestar.datastructures.state import State


def router_first_dependency() -> bool:
Expand All @@ -21,12 +22,12 @@ async def router_second_dependency() -> bool:
return False


def controller_first_dependency(headers: Dict[str, Any]) -> dict:
def controller_first_dependency(headers: Dict[str, Any]) -> Dict[Any, Any]:
assert headers
return {}


async def controller_second_dependency(request: "Request") -> dict:
async def controller_second_dependency(request: "Request[Any, Any, State]") -> Dict[Any, Any]:
assert request
await sleep(0)
return {}
Expand Down Expand Up @@ -60,7 +61,7 @@ class FirstController(Controller):
"first": Provide(local_method_first_dependency, sync_to_thread=False),
},
)
def test_method(self, first: int, second: dict, third: bool) -> None:
def test_method(self, first: int, second: Dict[Any, Any], third: bool) -> None:
assert isinstance(first, int)
assert isinstance(second, dict)
assert not third
Expand Down Expand Up @@ -109,7 +110,7 @@ class SecondController(Controller):
path = "/second"

@get()
def test_method(self, first: dict) -> None:
def test_method(self, first: Dict[Any, Any]) -> None:
pass

with create_test_client([first_controller, SecondController]) as client:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from litestar import Controller, websocket
from litestar.connection import WebSocket
from litestar.datastructures import State
from litestar.di import Provide
from litestar.exceptions import WebSocketDisconnect
from litestar.testing import create_test_client
Expand All @@ -19,12 +20,12 @@ async def router_second_dependency() -> bool:
return False


def controller_first_dependency(headers: Dict[str, Any]) -> dict:
def controller_first_dependency(headers: Dict[str, Any]) -> Dict[Any, Any]:
assert headers
return {}


async def controller_second_dependency(socket: WebSocket) -> dict:
async def controller_second_dependency(socket: WebSocket[Any, Any, Any]) -> Dict[Any, Any]:
assert socket
await sleep(0)
return {}
Expand Down Expand Up @@ -56,7 +57,9 @@ class FirstController(Controller):
"first": Provide(local_method_first_dependency, sync_to_thread=False),
},
)
async def test_method(self, socket: WebSocket, first: int, second: dict, third: bool) -> None:
async def test_method(
self, socket: WebSocket[Any, Any, Any], first: int, second: Dict[Any, Any], third: bool
) -> None:
await socket.accept()
msg = await socket.receive_json()
assert msg
Expand Down Expand Up @@ -87,7 +90,7 @@ def test_function_dependency_injection() -> None:
"third": Provide(local_method_second_dependency, sync_to_thread=False),
},
)
async def test_function(socket: WebSocket, first: int, second: bool, third: str) -> None:
async def test_function(socket: WebSocket[Any, Any, State], first: int, second: bool, third: str) -> None:
await socket.accept()
assert socket
msg = await socket.receive_json()
Expand All @@ -113,7 +116,7 @@ class SecondController(Controller):
path = "/second"

@websocket()
async def test_method(self, socket: WebSocket, first: dict) -> None:
async def test_method(self, socket: WebSocket[Any, Any, Any], _: Dict[Any, Any]) -> None:
await socket.accept()

client = create_test_client([FirstController, SecondController])
Expand Down
22 changes: 12 additions & 10 deletions tests/e2e/test_life_cycle_hooks/test_after_request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Dict, Optional

import pytest

Expand All @@ -8,19 +8,19 @@
from litestar.types import AfterRequestHookHandler


def sync_after_request_handler(response: Response) -> Response:
def sync_after_request_handler(response: Response[Dict[str, str]]) -> Response[Dict[str, str]]:
assert isinstance(response, Response)
response.content = {"hello": "moon"}
return response


async def async_after_request_handler(response: Response) -> Response:
async def async_after_request_handler(response: Response[Dict[str, str]]) -> Response[Dict[str, str]]:
assert isinstance(response, Response)
response.content = {"hello": "moon"}
return response


async def async_after_request_handler_with_hello_world(response: Response) -> Response:
async def async_after_request_handler_with_hello_world(response: Response[Dict[str, str]]) -> Response[Dict[str, str]]:
assert isinstance(response, Response)
response.content = {"hello": "world"}
return response
Expand All @@ -34,9 +34,11 @@ async def async_after_request_handler_with_hello_world(response: Response) -> Re
[async_after_request_handler, {"hello": "moon"}],
],
)
def test_after_request_handler_called(after_request: Optional[AfterRequestHookHandler], expected: dict) -> None:
def test_after_request_handler_called(
after_request: Optional[AfterRequestHookHandler], expected: Dict[str, str]
) -> None:
@get(after_request=after_request)
def handler() -> dict:
def handler() -> Dict[str, str]:
return {"hello": "world"}

with create_test_client(route_handlers=handler) as client:
Expand All @@ -63,15 +65,15 @@ def test_after_request_handler_resolution(
router_after_request_handler: Optional[AfterRequestHookHandler],
controller_after_request_handler: Optional[AfterRequestHookHandler],
method_after_request_handler: Optional[AfterRequestHookHandler],
expected: dict,
expected: Dict[str, str],
) -> None:
class MyController(Controller):
path = "/hello"

after_request = controller_after_request_handler

@get(after_request=method_after_request_handler)
def hello(self) -> dict:
def hello(self) -> Dict[str, str]:
return {"hello": "world"}

router = Router(path="/greetings", route_handlers=[MyController], after_request=router_after_request_handler)
Expand All @@ -82,12 +84,12 @@ def hello(self) -> dict:


def test_after_request_handles_handlers_that_return_responses() -> None:
def after_request(response: Response) -> Response:
def after_request(response: Response[Any]) -> Response[Any]:
response.headers["Custom-Header-Name"] = "Custom Header Value"
return response

@get("/")
def handler() -> Response:
def handler() -> Response[str]:
return Response("test")

with create_test_client(handler, after_request=after_request) as client:
Expand Down
23 changes: 12 additions & 11 deletions tests/e2e/test_life_cycle_hooks/test_before_request.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,38 @@
from typing import Optional
from typing import Any, Dict, Optional

import pytest

from litestar import Controller, Request, Response, Router, get
from litestar.datastructures import State
from litestar.testing import create_test_client
from litestar.types import AnyCallable, BeforeRequestHookHandler


def sync_before_request_handler_with_return_value(request: Request) -> dict:
def sync_before_request_handler_with_return_value(request: Request[Any, Any, State]) -> Dict[str, str]:
assert isinstance(request, Request)
return {"hello": "moon"}


async def async_before_request_handler_with_return_value(request: Request) -> dict:
async def async_before_request_handler_with_return_value(request: Request[Any, Any, State]) -> Dict[str, str]:
assert isinstance(request, Request)
return {"hello": "moon"}


def sync_before_request_handler_without_return_value(request: Request) -> None:
def sync_before_request_handler_without_return_value(request: Request[Any, Any, State]) -> None:
assert isinstance(request, Request)


async def async_before_request_handler_without_return_value(request: Request) -> None:
async def async_before_request_handler_without_return_value(request: Request[Any, Any, State]) -> None:
assert isinstance(request, Request)


def sync_after_request_handler(response: Response) -> Response:
def sync_after_request_handler(response: Response[Dict[str, str]]) -> Response[Dict[str, str]]:
assert isinstance(response, Response)
response.content = {"hello": "moon"}
return response


async def async_after_request_handler(response: Response) -> Response:
async def async_after_request_handler(response: Response[Dict[str, str]]) -> Response[Dict[str, str]]:
assert isinstance(response, Response)
response.content = {"hello": "moon"}
return response
Expand All @@ -47,9 +48,9 @@ async def async_after_request_handler(response: Response) -> Response:
(async_before_request_handler_without_return_value, {"hello": "world"}),
),
)
def test_before_request_handler_called(before_request: Optional[AnyCallable], expected: dict) -> None:
def test_before_request_handler_called(before_request: Optional[AnyCallable], expected: Dict[str, str]) -> None:
@get(before_request=before_request)
def handler() -> dict:
def handler() -> Dict[str, str]:
return {"hello": "world"}

with create_test_client(route_handlers=handler) as client:
Expand Down Expand Up @@ -94,15 +95,15 @@ def test_before_request_handler_resolution(
router_before_request_handler: Optional[BeforeRequestHookHandler],
controller_before_request_handler: Optional[BeforeRequestHookHandler],
method_before_request_handler: Optional[BeforeRequestHookHandler],
expected: dict,
expected: Dict[str, str],
) -> None:
class MyController(Controller):
path = "/hello"

before_request = controller_before_request_handler

@get(before_request=method_before_request_handler)
def hello(self) -> dict:
def hello(self) -> Dict[str, str]:
return {"hello": "world"}

router = Router(path="/greetings", route_handlers=[MyController], before_request=router_before_request_handler)
Expand Down
11 changes: 7 additions & 4 deletions tests/e2e/test_response_caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import gzip
import random
from datetime import timedelta
from typing import TYPE_CHECKING, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from unittest.mock import MagicMock
from uuid import uuid4

Expand All @@ -11,6 +11,7 @@
from litestar import Litestar, Request, Response, get, post
from litestar.config.compression import CompressionConfig
from litestar.config.response_cache import CACHE_FOREVER, ResponseCacheConfig
from litestar.datastructures import State
from litestar.enums import CompressionEncoding
from litestar.middleware.response_cache import ResponseCacheMiddleware
from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
Expand All @@ -22,13 +23,15 @@
if TYPE_CHECKING:
from time_machine import Coordinates

T = TypeVar("T")


@pytest.fixture()
def mock() -> MagicMock:
return MagicMock(return_value=str(random.random()))


def after_request_handler(response: "Response") -> "Response":
def after_request_handler(response: "Response[T]") -> "Response[T]":
response.headers["unique-identifier"] = str(uuid4())
return response

Expand Down Expand Up @@ -132,7 +135,7 @@ async def handler() -> None:

@pytest.mark.parametrize("sync_to_thread", (True, False))
async def test_custom_cache_key(sync_to_thread: bool, anyio_backend: str, mock: MagicMock) -> None:
def custom_cache_key_builder(request: Request) -> str:
def custom_cache_key_builder(request: Request[Any, Any, State]) -> str:
return f"{request.url.path}:::cached"

@get("/cached", sync_to_thread=sync_to_thread, cache=True, cache_key_builder=custom_cache_key_builder)
Expand Down Expand Up @@ -262,7 +265,7 @@ def test_default_do_response_cache_predicate(
mock: MagicMock, response: Union[int, Type[RuntimeError]], should_cache: bool
) -> None:
@get("/", cache=True)
def handler() -> Response:
def handler() -> Response[None]:
mock()
if isinstance(response, int):
return Response(None, status_code=response)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import fields
from typing import TYPE_CHECKING, Callable, List, Tuple
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
from unittest.mock import MagicMock, Mock, PropertyMock

import pytest
Expand Down Expand Up @@ -258,7 +258,7 @@ def test_using_custom_http_exception_handler() -> None:
@get("/{param:int}")
def my_route_handler(param: int) -> None: ...

def my_custom_handler(_: Request, __: Exception) -> Response:
def my_custom_handler(_: Request[Any, Any, State], __: Exception) -> Response[str]:
return Response(content="custom message", media_type=MediaType.TEXT, status_code=HTTP_400_BAD_REQUEST)

with create_test_client(my_route_handler, exception_handlers={NotFoundException: my_custom_handler}) as client:
Expand Down
Loading

0 comments on commit 370f39f

Please sign in to comment.