From 370f39fde1a9d1b18892bfc723414b0e107d09b1 Mon Sep 17 00:00:00 2001 From: Marcus Lim <42759889+marcuslimdw@users.noreply.github.com> Date: Sun, 15 Sep 2024 18:12:14 +0800 Subject: [PATCH] refactor: Add types to some tests (#2769) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --------- Co-authored-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --- .../test_http_handler_dependency_injection.py | 9 ++- ..._websocket_handler_dependency_injection.py | 13 +-- .../test_after_request.py | 22 ++--- .../test_before_request.py | 23 +++--- tests/e2e/test_response_caching.py | 11 ++- tests/unit/test_app.py | 4 +- tests/unit/test_connection/test_request.py | 80 +++++++++---------- tests/unit/test_connection/test_websocket.py | 63 ++++++++------- tests/unit/test_openapi/test_responses.py | 6 +- 9 files changed, 122 insertions(+), 109 deletions(-) diff --git a/tests/e2e/test_dependency_injection/test_http_handler_dependency_injection.py b/tests/e2e/test_dependency_injection/test_http_handler_dependency_injection.py index b5015f92b0..86803f0a81 100644 --- a/tests/e2e/test_dependency_injection/test_http_handler_dependency_injection.py +++ b/tests/e2e/test_dependency_injection/test_http_handler_dependency_injection.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from litestar.connection import Request + from litestar.datastructures.state import State def router_first_dependency() -> bool: @@ -21,12 +22,12 @@ async def router_second_dependency() -> bool: return False -def controller_first_dependency(headers: Dict[str, Any]) -> dict: +def controller_first_dependency(headers: Dict[str, Any]) -> Dict[Any, Any]: assert headers return {} -async def controller_second_dependency(request: "Request") -> dict: +async def controller_second_dependency(request: "Request[Any, Any, State]") -> Dict[Any, Any]: assert request await sleep(0) return {} @@ -60,7 +61,7 @@ class FirstController(Controller): "first": Provide(local_method_first_dependency, sync_to_thread=False), }, ) - def test_method(self, first: int, second: dict, third: bool) -> None: + def test_method(self, first: int, second: Dict[Any, Any], third: bool) -> None: assert isinstance(first, int) assert isinstance(second, dict) assert not third @@ -109,7 +110,7 @@ class SecondController(Controller): path = "/second" @get() - def test_method(self, first: dict) -> None: + def test_method(self, first: Dict[Any, Any]) -> None: pass with create_test_client([first_controller, SecondController]) as client: diff --git a/tests/e2e/test_dependency_injection/test_websocket_handler_dependency_injection.py b/tests/e2e/test_dependency_injection/test_websocket_handler_dependency_injection.py index 86cdcc2fc5..a95437c7fe 100644 --- a/tests/e2e/test_dependency_injection/test_websocket_handler_dependency_injection.py +++ b/tests/e2e/test_dependency_injection/test_websocket_handler_dependency_injection.py @@ -5,6 +5,7 @@ from litestar import Controller, websocket from litestar.connection import WebSocket +from litestar.datastructures import State from litestar.di import Provide from litestar.exceptions import WebSocketDisconnect from litestar.testing import create_test_client @@ -19,12 +20,12 @@ async def router_second_dependency() -> bool: return False -def controller_first_dependency(headers: Dict[str, Any]) -> dict: +def controller_first_dependency(headers: Dict[str, Any]) -> Dict[Any, Any]: assert headers return {} -async def controller_second_dependency(socket: WebSocket) -> dict: +async def controller_second_dependency(socket: WebSocket[Any, Any, Any]) -> Dict[Any, Any]: assert socket await sleep(0) return {} @@ -56,7 +57,9 @@ class FirstController(Controller): "first": Provide(local_method_first_dependency, sync_to_thread=False), }, ) - async def test_method(self, socket: WebSocket, first: int, second: dict, third: bool) -> None: + async def test_method( + self, socket: WebSocket[Any, Any, Any], first: int, second: Dict[Any, Any], third: bool + ) -> None: await socket.accept() msg = await socket.receive_json() assert msg @@ -87,7 +90,7 @@ def test_function_dependency_injection() -> None: "third": Provide(local_method_second_dependency, sync_to_thread=False), }, ) - async def test_function(socket: WebSocket, first: int, second: bool, third: str) -> None: + async def test_function(socket: WebSocket[Any, Any, State], first: int, second: bool, third: str) -> None: await socket.accept() assert socket msg = await socket.receive_json() @@ -113,7 +116,7 @@ class SecondController(Controller): path = "/second" @websocket() - async def test_method(self, socket: WebSocket, first: dict) -> None: + async def test_method(self, socket: WebSocket[Any, Any, Any], _: Dict[Any, Any]) -> None: await socket.accept() client = create_test_client([FirstController, SecondController]) diff --git a/tests/e2e/test_life_cycle_hooks/test_after_request.py b/tests/e2e/test_life_cycle_hooks/test_after_request.py index 3fec69d13e..7be3d6d522 100644 --- a/tests/e2e/test_life_cycle_hooks/test_after_request.py +++ b/tests/e2e/test_life_cycle_hooks/test_after_request.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Optional import pytest @@ -8,19 +8,19 @@ from litestar.types import AfterRequestHookHandler -def sync_after_request_handler(response: Response) -> Response: +def sync_after_request_handler(response: Response[Dict[str, str]]) -> Response[Dict[str, str]]: assert isinstance(response, Response) response.content = {"hello": "moon"} return response -async def async_after_request_handler(response: Response) -> Response: +async def async_after_request_handler(response: Response[Dict[str, str]]) -> Response[Dict[str, str]]: assert isinstance(response, Response) response.content = {"hello": "moon"} return response -async def async_after_request_handler_with_hello_world(response: Response) -> Response: +async def async_after_request_handler_with_hello_world(response: Response[Dict[str, str]]) -> Response[Dict[str, str]]: assert isinstance(response, Response) response.content = {"hello": "world"} return response @@ -34,9 +34,11 @@ async def async_after_request_handler_with_hello_world(response: Response) -> Re [async_after_request_handler, {"hello": "moon"}], ], ) -def test_after_request_handler_called(after_request: Optional[AfterRequestHookHandler], expected: dict) -> None: +def test_after_request_handler_called( + after_request: Optional[AfterRequestHookHandler], expected: Dict[str, str] +) -> None: @get(after_request=after_request) - def handler() -> dict: + def handler() -> Dict[str, str]: return {"hello": "world"} with create_test_client(route_handlers=handler) as client: @@ -63,7 +65,7 @@ def test_after_request_handler_resolution( router_after_request_handler: Optional[AfterRequestHookHandler], controller_after_request_handler: Optional[AfterRequestHookHandler], method_after_request_handler: Optional[AfterRequestHookHandler], - expected: dict, + expected: Dict[str, str], ) -> None: class MyController(Controller): path = "/hello" @@ -71,7 +73,7 @@ class MyController(Controller): after_request = controller_after_request_handler @get(after_request=method_after_request_handler) - def hello(self) -> dict: + def hello(self) -> Dict[str, str]: return {"hello": "world"} router = Router(path="/greetings", route_handlers=[MyController], after_request=router_after_request_handler) @@ -82,12 +84,12 @@ def hello(self) -> dict: def test_after_request_handles_handlers_that_return_responses() -> None: - def after_request(response: Response) -> Response: + def after_request(response: Response[Any]) -> Response[Any]: response.headers["Custom-Header-Name"] = "Custom Header Value" return response @get("/") - def handler() -> Response: + def handler() -> Response[str]: return Response("test") with create_test_client(handler, after_request=after_request) as client: diff --git a/tests/e2e/test_life_cycle_hooks/test_before_request.py b/tests/e2e/test_life_cycle_hooks/test_before_request.py index 5efea5516c..47de0a58a8 100644 --- a/tests/e2e/test_life_cycle_hooks/test_before_request.py +++ b/tests/e2e/test_life_cycle_hooks/test_before_request.py @@ -1,37 +1,38 @@ -from typing import Optional +from typing import Any, Dict, Optional import pytest from litestar import Controller, Request, Response, Router, get +from litestar.datastructures import State from litestar.testing import create_test_client from litestar.types import AnyCallable, BeforeRequestHookHandler -def sync_before_request_handler_with_return_value(request: Request) -> dict: +def sync_before_request_handler_with_return_value(request: Request[Any, Any, State]) -> Dict[str, str]: assert isinstance(request, Request) return {"hello": "moon"} -async def async_before_request_handler_with_return_value(request: Request) -> dict: +async def async_before_request_handler_with_return_value(request: Request[Any, Any, State]) -> Dict[str, str]: assert isinstance(request, Request) return {"hello": "moon"} -def sync_before_request_handler_without_return_value(request: Request) -> None: +def sync_before_request_handler_without_return_value(request: Request[Any, Any, State]) -> None: assert isinstance(request, Request) -async def async_before_request_handler_without_return_value(request: Request) -> None: +async def async_before_request_handler_without_return_value(request: Request[Any, Any, State]) -> None: assert isinstance(request, Request) -def sync_after_request_handler(response: Response) -> Response: +def sync_after_request_handler(response: Response[Dict[str, str]]) -> Response[Dict[str, str]]: assert isinstance(response, Response) response.content = {"hello": "moon"} return response -async def async_after_request_handler(response: Response) -> Response: +async def async_after_request_handler(response: Response[Dict[str, str]]) -> Response[Dict[str, str]]: assert isinstance(response, Response) response.content = {"hello": "moon"} return response @@ -47,9 +48,9 @@ async def async_after_request_handler(response: Response) -> Response: (async_before_request_handler_without_return_value, {"hello": "world"}), ), ) -def test_before_request_handler_called(before_request: Optional[AnyCallable], expected: dict) -> None: +def test_before_request_handler_called(before_request: Optional[AnyCallable], expected: Dict[str, str]) -> None: @get(before_request=before_request) - def handler() -> dict: + def handler() -> Dict[str, str]: return {"hello": "world"} with create_test_client(route_handlers=handler) as client: @@ -94,7 +95,7 @@ def test_before_request_handler_resolution( router_before_request_handler: Optional[BeforeRequestHookHandler], controller_before_request_handler: Optional[BeforeRequestHookHandler], method_before_request_handler: Optional[BeforeRequestHookHandler], - expected: dict, + expected: Dict[str, str], ) -> None: class MyController(Controller): path = "/hello" @@ -102,7 +103,7 @@ class MyController(Controller): before_request = controller_before_request_handler @get(before_request=method_before_request_handler) - def hello(self) -> dict: + def hello(self) -> Dict[str, str]: return {"hello": "world"} router = Router(path="/greetings", route_handlers=[MyController], before_request=router_before_request_handler) diff --git a/tests/e2e/test_response_caching.py b/tests/e2e/test_response_caching.py index 7573178154..4623028431 100644 --- a/tests/e2e/test_response_caching.py +++ b/tests/e2e/test_response_caching.py @@ -1,7 +1,7 @@ import gzip import random from datetime import timedelta -from typing import TYPE_CHECKING, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union from unittest.mock import MagicMock from uuid import uuid4 @@ -11,6 +11,7 @@ from litestar import Litestar, Request, Response, get, post from litestar.config.compression import CompressionConfig from litestar.config.response_cache import CACHE_FOREVER, ResponseCacheConfig +from litestar.datastructures import State from litestar.enums import CompressionEncoding from litestar.middleware.response_cache import ResponseCacheMiddleware from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR @@ -22,13 +23,15 @@ if TYPE_CHECKING: from time_machine import Coordinates +T = TypeVar("T") + @pytest.fixture() def mock() -> MagicMock: return MagicMock(return_value=str(random.random())) -def after_request_handler(response: "Response") -> "Response": +def after_request_handler(response: "Response[T]") -> "Response[T]": response.headers["unique-identifier"] = str(uuid4()) return response @@ -132,7 +135,7 @@ async def handler() -> None: @pytest.mark.parametrize("sync_to_thread", (True, False)) async def test_custom_cache_key(sync_to_thread: bool, anyio_backend: str, mock: MagicMock) -> None: - def custom_cache_key_builder(request: Request) -> str: + def custom_cache_key_builder(request: Request[Any, Any, State]) -> str: return f"{request.url.path}:::cached" @get("/cached", sync_to_thread=sync_to_thread, cache=True, cache_key_builder=custom_cache_key_builder) @@ -262,7 +265,7 @@ def test_default_do_response_cache_predicate( mock: MagicMock, response: Union[int, Type[RuntimeError]], should_cache: bool ) -> None: @get("/", cache=True) - def handler() -> Response: + def handler() -> Response[None]: mock() if isinstance(response, int): return Response(None, status_code=response) diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index 20b9a4a116..87a779628c 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from dataclasses import fields -from typing import TYPE_CHECKING, Callable, List, Tuple +from typing import TYPE_CHECKING, Any, Callable, List, Tuple from unittest.mock import MagicMock, Mock, PropertyMock import pytest @@ -258,7 +258,7 @@ def test_using_custom_http_exception_handler() -> None: @get("/{param:int}") def my_route_handler(param: int) -> None: ... - def my_custom_handler(_: Request, __: Exception) -> Response: + def my_custom_handler(_: Request[Any, Any, State], __: Exception) -> Response[str]: return Response(content="custom message", media_type=MediaType.TEXT, status_code=HTTP_400_BAD_REQUEST) with create_test_client(my_route_handler, exception_handlers={NotFoundException: my_custom_handler}) as client: diff --git a/tests/unit/test_connection/test_request.py b/tests/unit/test_connection/test_request.py index 7393211647..ec532852d7 100644 --- a/tests/unit/test_connection/test_request.py +++ b/tests/unit/test_connection/test_request.py @@ -12,8 +12,8 @@ import pytest from litestar import MediaType, Request, asgi, get, post -from litestar.connection.base import empty_send -from litestar.datastructures import Address, Cookie +from litestar.connection.base import AuthT, StateT, UserT, empty_send +from litestar.datastructures import Address, Cookie, State from litestar.exceptions import ( InternalServerException, LitestarException, @@ -44,40 +44,40 @@ def scope_fixture(create_scope: Callable[..., Scope]) -> Scope: async def test_request_empty_body_to_json(anyio_backend: str, scope: Scope) -> None: with patch.object(Request, "body", return_value=b""): - request_empty_payload: Request = Request(scope=scope) + request_empty_payload: Request[Any, Any, State] = Request(scope=scope) request_json = await request_empty_payload.json() assert request_json is None async def test_request_invalid_body_to_json(anyio_backend: str, scope: Scope) -> None: with patch.object(Request, "body", return_value=b"invalid"), pytest.raises(SerializationException): - request_empty_payload: Request = Request(scope=scope) + request_empty_payload: Request[Any, Any, State] = Request(scope=scope) await request_empty_payload.json() async def test_request_valid_body_to_json(anyio_backend: str, scope: Scope) -> None: with patch.object(Request, "body", return_value=b'{"test": "valid"}'): - request_empty_payload: Request = Request(scope=scope) + request_empty_payload: Request[Any, Any, State] = Request(scope=scope) request_json = await request_empty_payload.json() assert request_json == {"test": "valid"} async def test_request_empty_body_to_msgpack(anyio_backend: str, scope: Scope) -> None: with patch.object(Request, "body", return_value=b""): - request_empty_payload: Request = Request(scope=scope) + request_empty_payload: Request[Any, Any, State] = Request(scope=scope) request_msgpack = await request_empty_payload.msgpack() assert request_msgpack is None async def test_request_invalid_body_to_msgpack(anyio_backend: str, scope: Scope) -> None: with patch.object(Request, "body", return_value=b"invalid"), pytest.raises(SerializationException): - request_empty_payload: Request = Request(scope=scope) + request_empty_payload: Request[Any, Any, State] = Request(scope=scope) await request_empty_payload.msgpack() async def test_request_valid_body_to_msgpack(anyio_backend: str, scope: Scope) -> None: with patch.object(Request, "body", return_value=encode_msgpack({"test": "valid"})): - request_empty_payload: Request = Request(scope=scope) + request_empty_payload: Request[Any, Any, State] = Request(scope=scope) request_msgpack = await request_empty_payload.msgpack() assert request_msgpack == {"test": "valid"} @@ -88,11 +88,11 @@ def proxy() -> None: pass @get(path="/test", signature_namespace={"dict": Dict}) - def root(request: Request) -> dict[str, str]: + def root(request: Request[Any, Any, State]) -> dict[str, str]: return {"url": request.url_for("proxy")} @get(path="/test-none", signature_namespace={"dict": Dict}) - def test_none(request: Request) -> dict[str, str]: + def test_none(request: Request[Any, Any, State]) -> dict[str, str]: return {"url": request.url_for("none")} with create_test_client(route_handlers=[proxy, root, test_none]) as client: @@ -105,11 +105,11 @@ def test_none(request: Request) -> dict[str, str]: def test_request_asset_url(tmp_path: Path) -> None: @get(path="/resolver", signature_namespace={"dict": Dict}) - def resolver(request: Request) -> dict[str, str]: + def resolver(request: Request[Any, Any, State]) -> dict[str, str]: return {"url": request.url_for_static_asset("js", "main.js")} @get(path="/resolver-none", signature_namespace={"dict": Dict}) - def resolver_none(request: Request) -> dict[str, str]: + def resolver_none(request: Request[Any, Any, State]) -> dict[str, str]: return {"url": request.url_for_static_asset("none", "main.js")} with create_test_client( @@ -127,7 +127,7 @@ def test_route_handler_property() -> None: value: Any = {} @get("/") - def handler(request: Request) -> None: + def handler(request: Request[Any, Any, State]) -> None: value["handler"] = request.route_handler with create_test_client(route_handlers=[handler]) as client: @@ -138,13 +138,13 @@ def handler(request: Request) -> None: def test_custom_request_class() -> None: value: Any = {} - class MyRequest(Request): + class MyRequest(Request[UserT, AuthT, StateT]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.scope["called"] = True # type: ignore[typeddict-unknown-key] @get("/", signature_types=[MyRequest]) - def handler(request: MyRequest) -> None: + def handler(request: MyRequest[Any, Any, State]) -> None: value["called"] = request.scope.get("called") with create_test_client(route_handlers=[handler], request_class=MyRequest) as client: @@ -154,7 +154,7 @@ def handler(request: MyRequest) -> None: def test_request_url() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) data = {"method": request.method, "url": str(request.url)} response = ASGIResponse(body=encode_json(data)) await response(scope, receive, send) @@ -169,7 +169,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_request_query_params() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) params = dict(request.query_params) response = ASGIResponse(body=encode_json({"params": params})) await response(scope, receive, send) @@ -181,7 +181,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_request_headers() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) headers = dict(request.headers) response = ASGIResponse(body=encode_json({"headers": headers})) await response(scope, receive, send) @@ -201,7 +201,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_request_accept_header() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) response = ASGIResponse(body=encode_json({"accepted_types": list(request.accept)})) await response(scope, receive, send) @@ -225,13 +225,13 @@ def test_request_client( scope.update(scope_values) # type: ignore[typeddict-item] if "client" not in scope_values: del scope["client"] # type: ignore[misc] - client = Request[Any, Any, Any](scope).client + client = Request[Any, Any, State](scope).client assert client == expected_client def test_request_body() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) body = await request.body() response = ASGIResponse(body=encode_json({"body": body.decode()})) await response(scope, receive, send) @@ -250,7 +250,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_request_stream() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) body = b"" async for chunk in request.stream(): body += chunk @@ -271,7 +271,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_request_form_urlencoded() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) form = await request.form() response = ASGIResponse(body=encode_json({"form": dict(form)})) await response(scope, receive, send) @@ -302,7 +302,7 @@ async def handler(request: Request) -> int: def test_request_body_then_stream() -> None: async def app(scope: Any, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) body = await request.body() chunks = b"" async for chunk in request.stream(): @@ -318,7 +318,7 @@ async def app(scope: Any, receive: Receive, send: Send) -> None: def test_request_stream_then_body() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) chunks = b"" async for chunk in request.stream(): chunks += chunk @@ -338,7 +338,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_request_json() -> None: @asgi("/") async def handler(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) data = await request.json() response = ASGIResponse(body=encode_json({"json": data})) await response(scope, receive, send) @@ -350,7 +350,7 @@ async def handler(scope: Scope, receive: Receive, send: Send) -> None: def test_request_raw_path() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) path = str(request.scope["path"]) raw_path = str(request.scope["raw_path"]) response = ASGIResponse(body=f"{path}, {raw_path}".encode(), media_type=MediaType.TEXT) @@ -365,7 +365,7 @@ def test_request_without_setting_receive() -> None: """If Request is instantiated without the 'receive' channel, then .body() is not available.""" async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope) + request = Request[Any, Any, State](scope) try: data = await request.json() except RuntimeError: @@ -382,10 +382,10 @@ async def test_request_disconnect(create_scope: Callable[..., Scope]) -> None: """If a client disconnect occurs while reading request body then InternalServerException should be raised.""" async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) await request.body() - async def receiver() -> dict: + async def receiver() -> dict[str, str]: return {"type": "http.disconnect"} with pytest.raises(InternalServerException): @@ -398,10 +398,10 @@ async def receiver() -> dict: def test_request_state() -> None: @get("/", signature_namespace={"dict": Dict}) - def handler(request: Request[Any, Any, Any]) -> dict[Any, Any]: + def handler(request: Request[Any, Any, State]) -> dict[Any, Any]: request.state.test = 1 assert request.state.test == 1 - return request.state.dict() # type: ignore[no-any-return] + return request.state.dict() with create_test_client(handler) as client: response = client.get("/") @@ -410,7 +410,7 @@ def handler(request: Request[Any, Any, Any]) -> dict[Any, Any]: def test_request_cookies() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) mycookie = request.cookies.get("mycookie") if mycookie: asgi_response = ASGIResponse(body=mycookie.encode("utf-8"), media_type="text/plain") @@ -432,7 +432,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_chunked_encoding() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope, receive) + request = Request[Any, Any, State](scope, receive) body = await request.body() response = ASGIResponse(body=encode_json({"body": body.decode()})) await response(scope, receive, send) @@ -452,7 +452,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: # the server is push-enabled scope["extensions"]["http.response.push"] = {} # type: ignore[index] - request = Request[Any, Any, Any](scope, receive, send) + request = Request[Any, Any, State](scope, receive, send) await request.send_push_promise("/style.css") response = ASGIResponse(body=encode_json({"json": "OK"})) @@ -470,7 +470,7 @@ def test_request_send_push_promise_without_push_extension() -> None: """ async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope) + request = Request[Any, Any, State](scope) with pytest.warns(LitestarWarning, match="Attempted to send a push promise"): await request.send_push_promise("/style.css") @@ -490,7 +490,7 @@ def test_request_send_push_promise_without_push_extension_raises() -> None: """ async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, Any](scope) + request = Request[Any, Any, State](scope) with pytest.raises(LitestarException, match="Attempted to send a push promise"): await request.send_push_promise("/style.css", raise_if_unavailable=True) @@ -512,7 +512,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: scope["extensions"]["http.response.push"] = {} # type: ignore[index] data = "OK" - request = Request[Any, Any, Any](scope) + request = Request[Any, Any, State](scope) try: await request.send_push_promise("/style.css") except RuntimeError: @@ -535,12 +535,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: def test_state() -> None: - def before_request(request: Request) -> None: + def before_request(request: Request[Any, Any, State]) -> None: assert request.state.main == 1 request.state.main = 2 @get(path="/", signature_namespace={"dict": Dict}) - async def get_state(request: Request) -> dict[str, str]: + async def get_state(request: Request[Any, Any, State]) -> dict[str, str]: return {"state": request.state.main} with create_test_client( diff --git a/tests/unit/test_connection/test_websocket.py b/tests/unit/test_connection/test_websocket.py index 5e7a2e2d38..2fe96f9eb9 100644 --- a/tests/unit/test_connection/test_websocket.py +++ b/tests/unit/test_connection/test_websocket.py @@ -12,6 +12,7 @@ import pytest from litestar.connection import WebSocket +from litestar.datastructures import State from litestar.datastructures.headers import Headers from litestar.exceptions import WebSocketDisconnect, WebSocketException from litestar.handlers.websocket_handlers import websocket @@ -27,7 +28,7 @@ @pytest.mark.parametrize("mode", ["text", "binary"]) def test_websocket_send_receive_json(mode: Literal["text", "binary"]) -> None: @websocket(path="/") - async def websocket_handler(socket: WebSocket) -> None: + async def websocket_handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() recv = await socket.receive_json(mode=mode) await socket.send_json({"message": recv}, mode=mode) @@ -43,7 +44,7 @@ def test_route_handler_property() -> None: value: Any = {} @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() value["handler"] = socket.route_handler await socket.close() @@ -57,7 +58,7 @@ async def handler(socket: WebSocket) -> None: ) async def test_accept_set_headers(headers: Any) -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept(headers=headers) await socket.send_text("abc") await socket.close() @@ -69,7 +70,7 @@ async def handler(socket: WebSocket) -> None: async def test_custom_request_class() -> None: value: Any = {} - class MyWebSocket(WebSocket[Any, Any, Any]): + class MyWebSocket(WebSocket[Any, Any, State]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.scope["called"] = True # type: ignore[typeddict-unknown-key] @@ -86,7 +87,7 @@ async def handler(socket: MyWebSocket) -> None: def test_websocket_url() -> None: @websocket("/123") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() await socket.send_json({"url": str(socket.url)}) await socket.close() @@ -108,7 +109,7 @@ async def handler(socket: WebSocket) -> None: def test_websocket_binary_json() -> None: @websocket("/123") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() message = await socket.receive_json(mode="binary") await socket.send_json(message, mode="binary") @@ -121,7 +122,7 @@ async def handler(socket: WebSocket) -> None: def test_websocket_query_params() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: query_params = dict(socket.query_params) await socket.accept() await socket.send_json({"params": query_params}) @@ -133,7 +134,7 @@ async def handler(socket: WebSocket) -> None: def test_websocket_headers() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: headers = dict(socket.headers) await socket.accept() await socket.send_json({"headers": headers}) @@ -154,7 +155,7 @@ async def handler(socket: WebSocket) -> None: def test_websocket_port() -> None: @websocket("/123") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() await socket.send_json({"port": socket.url.port}) await socket.close() @@ -165,7 +166,7 @@ async def handler(socket: WebSocket) -> None: def test_websocket_send_and_receive_text() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() data = await socket.receive_text() await socket.send_text(f"Message was: {data}") @@ -178,7 +179,7 @@ async def handler(socket: WebSocket) -> None: def test_websocket_send_and_receive_bytes() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() data = await socket.receive_bytes() await socket.send_bytes(b"Message was: " + data) @@ -191,7 +192,7 @@ async def handler(socket: WebSocket) -> None: def test_websocket_send_and_receive_json() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() data = await socket.receive_json() await socket.send_json({"message": data}) @@ -206,7 +207,7 @@ def test_send_msgpack() -> None: test_data = {"message": "hello, world"} @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() await socket.send_msgpack(test_data) await socket.close() @@ -221,7 +222,7 @@ def test_receive_msgpack() -> None: callback = MagicMock() @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() data = await socket.receive_msgpack() callback(data) @@ -249,7 +250,7 @@ def test_iter_data(mode: WebSocketMode, data: list[str | bytes]) -> None: values = [] @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() values.extend(await consume_gen(socket.iter_data(mode=mode), 2)) await socket.close() @@ -267,7 +268,7 @@ def test_iter_json(mode: WebSocketMode) -> None: values = [] @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() values.extend(await consume_gen(socket.iter_json(mode=mode), 2)) await socket.close() @@ -284,7 +285,7 @@ def test_iter_msgpack() -> None: values = [] @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() values.extend(await consume_gen(socket.iter_msgpack(), 2)) await socket.close() @@ -299,18 +300,18 @@ async def handler(socket: WebSocket) -> None: def test_websocket_concurrency_pattern() -> None: stream_send, stream_receive = anyio.create_memory_object_stream() # type: ignore[var-annotated] - async def reader(socket: WebSocket[Any, Any, Any]) -> None: + async def reader(socket: WebSocket[Any, Any, State]) -> None: async with stream_send: json_data = await socket.receive_json() await stream_send.send(json_data) - async def writer(socket: WebSocket[Any, Any, Any]) -> None: + async def writer(socket: WebSocket[Any, Any, State]) -> None: async with stream_receive: async for message in stream_receive: await socket.send_json(message) @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() async with anyio.create_task_group() as task_group: task_group.start_soon(reader, socket) @@ -327,7 +328,7 @@ def test_client_close() -> None: close_code = None @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: nonlocal close_code await socket.accept() try: @@ -342,7 +343,7 @@ async def handler(socket: WebSocket) -> None: def test_application_close() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() await socket.close(WS_1001_GOING_AWAY) @@ -353,7 +354,7 @@ async def handler(socket: WebSocket) -> None: def test_rejected_connection() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.close(WS_1001_GOING_AWAY) with pytest.raises(WebSocketDisconnect) as exc, create_test_client(handler).websocket_connect("/"): @@ -363,7 +364,7 @@ async def handler(socket: WebSocket) -> None: def test_subprotocol() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: assert socket.scope["subprotocols"] == ["soap", "wamp"] await socket.accept(subprotocols="wamp") await socket.close() @@ -374,7 +375,7 @@ async def handler(socket: WebSocket) -> None: def test_additional_headers() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept(headers=[(b"additional", b"header")]) await socket.close() @@ -384,7 +385,7 @@ async def handler(socket: WebSocket) -> None: def test_no_additional_headers() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() await socket.close() @@ -402,7 +403,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_duplicate_disconnect() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - socket = WebSocket[Any, Any, Any](scope, receive=receive, send=send) + socket = WebSocket[Any, Any, State](scope, receive=receive, send=send) await socket.accept() message = await socket.receive() assert message["type"] == "websocket.disconnect" @@ -414,7 +415,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_websocket_close_reason() -> None: @websocket("/") - async def handler(socket: WebSocket) -> None: + async def handler(socket: WebSocket[Any, Any, State]) -> None: await socket.accept() await socket.close(code=WS_1001_GOING_AWAY, reason="Going Away") @@ -426,7 +427,7 @@ async def handler(socket: WebSocket) -> None: def test_receive_text_before_accept() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - socket = WebSocket[Any, Any, Any](scope, receive=receive, send=send) + socket = WebSocket[Any, Any, State](scope, receive=receive, send=send) await socket.receive_text() with pytest.raises(WebSocketException), TestClient(app).websocket_connect("/"): @@ -435,7 +436,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_receive_bytes_before_accept() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - socket = WebSocket[Any, Any, Any](scope, receive=receive, send=send) + socket = WebSocket[Any, Any, State](scope, receive=receive, send=send) await socket.receive_bytes() with pytest.raises(WebSocketException), TestClient(app).websocket_connect("/"): @@ -444,7 +445,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_receive_json_before_accept() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: - socket = WebSocket[Any, Any, Any](scope, receive=receive, send=send) + socket = WebSocket[Any, Any, State](scope, receive=receive, send=send) await socket.receive_json() with pytest.raises(WebSocketException), TestClient(app).websocket_connect("/"): diff --git a/tests/unit/test_openapi/test_responses.py b/tests/unit/test_openapi/test_responses.py index b8e0baea48..bb5c213397 100644 --- a/tests/unit/test_openapi/test_responses.py +++ b/tests/unit/test_openapi/test_responses.py @@ -5,7 +5,7 @@ from http import HTTPStatus from pathlib import Path from types import ModuleType -from typing import Any, Callable, Dict, TypedDict +from typing import Any, Callable, Dict, TypedDict, TypeVar from unittest.mock import MagicMock import pytest @@ -31,7 +31,6 @@ from litestar.openapi.spec import Example, OpenAPIHeader, OpenAPIMediaType, Reference, Schema from litestar.openapi.spec.enums import OpenAPIType from litestar.response import File, Redirect, Stream, Template -from litestar.response.base import T from litestar.routes import HTTPRoute from litestar.status_codes import ( HTTP_200_OK, @@ -44,6 +43,9 @@ from tests.models import DataclassPerson, DataclassPersonFactory from tests.unit.test_openapi.utils import PetException +T = TypeVar("T") + + CreateFactoryFixture: TypeAlias = "Callable[..., ResponseFactory]"