diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c60d31d4..19a1acdd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8] + python-version: ['3.6', '3.7', '3.8', '3.9', '3.10'] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/mangum/adapter.py b/mangum/adapter.py index 5c299863..52dffc00 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -2,7 +2,6 @@ from contextlib import ExitStack from typing import ( Any, - ContextManager, Dict, Optional, TYPE_CHECKING, @@ -44,49 +43,45 @@ class Mangum: * **text_mime_types** - A list of MIME types to include with the defaults that should not return a binary response in API Gateway. * **dsn** - A connection string required to configure a supported WebSocket backend. + * **api_gateway_base_path** - A string specifying the part of the url path after + which the server routing begins. * **api_gateway_endpoint_url** - A string endpoint url to use for API Gateway when sending data to WebSocket connections. Default is to determine this automatically. * **api_gateway_region_name** - A string region name to use for API Gateway when sending data to WebSocket connections. Default is `AWS_REGION` environment variable. """ - app: ASGIApp - lifespan: str = "auto" - dsn: Optional[str] = None - api_gateway_endpoint_url: Optional[str] = None - api_gateway_region_name: Optional[str] = None - def __init__( self, app: ASGIApp, lifespan: str = "auto", dsn: Optional[str] = None, + api_gateway_base_path: str = "/", api_gateway_endpoint_url: Optional[str] = None, api_gateway_region_name: Optional[str] = None, - **handler_kwargs: Dict[str, Any] ) -> None: self.app = app self.lifespan = lifespan self.dsn = dsn + self.api_gateway_base_path = api_gateway_base_path self.api_gateway_endpoint_url = api_gateway_endpoint_url self.api_gateway_region_name = api_gateway_region_name - self.handler_kwargs = handler_kwargs if self.lifespan not in ("auto", "on", "off"): raise ConfigurationError( "Invalid argument supplied for `lifespan`. Choices are: auto|on|off" ) - def __call__(self, event: dict, context: "LambdaContext") -> dict: + def __call__(self, event: Dict[str, Any], context: "LambdaContext") -> dict: logger.debug("Event received.") with ExitStack() as stack: if self.lifespan != "off": - lifespan_cycle: ContextManager = LifespanCycle(self.app, self.lifespan) + lifespan_cycle = LifespanCycle(self.app, self.lifespan) stack.enter_context(lifespan_cycle) handler = AbstractHandler.from_trigger( - event, context, **self.handler_kwargs + event, context, self.api_gateway_base_path ) request = handler.request diff --git a/mangum/backends/__init__.py b/mangum/backends/__init__.py index 29a0ac39..fd71c010 100644 --- a/mangum/backends/__init__.py +++ b/mangum/backends/__init__.py @@ -1,13 +1,14 @@ import asyncio import logging -from typing import Dict, Optional +from typing import Dict, Optional, Type import json from functools import partial -from dataclasses import dataclass, InitVar +from dataclasses import dataclass from urllib.parse import urlparse try: import httpx + from httpx import AsyncClient, Response except ImportError: # pragma: no cover httpx = None # type: ignore @@ -30,6 +31,8 @@ def get_sigv4_headers( data: Optional[bytes] = None, region_name: Optional[str] = None, ) -> Dict: + if boto3 is None: # pragma: no cover + raise WebSocketError("boto3 must be installed to use WebSockets.") session = boto3.Session() credentials = session.get_credentials() creds = credentials.get_frozen_credentials() @@ -48,24 +51,21 @@ class WebSocket: selected `WebSocketBackend` subclass """ - dsn: InitVar[Optional[str]] + dsn: Optional[str] api_gateway_endpoint_url: str api_gateway_region_name: Optional[str] = None - def __post_init__(self, dsn: Optional[str]) -> None: - if boto3 is None: # pragma: no cover - raise WebSocketError("boto3 must be installed to use WebSockets.") - - if httpx is None: # pragma: no cover + def __post_init__(self) -> None: + if not httpx: # pragma: no cover raise WebSocketError("httpx must be installed to use WebSockets.") - if dsn is None: + if self.dsn is None: raise ConfigurationError( "The `dsn` parameter must be provided for WebSocket connections." ) self.logger: logging.Logger = logging.getLogger("mangum.backends") - parsed_dsn = urlparse(dsn) + parsed_dsn = urlparse(self.dsn) if not any((parsed_dsn.hostname, parsed_dsn.path)): raise ConfigurationError("Invalid value for `dsn` provided.") @@ -74,6 +74,7 @@ def __post_init__(self, dsn: Optional[str]) -> None: f"Attempting WebSocket backend connection using scheme: {scheme}" ) + self._Backend: Type[WebSocketBackend] if scheme == "sqlite": self.logger.info( "The `SQLiteBackend` should be only be used for local " @@ -81,33 +82,31 @@ def __post_init__(self, dsn: Optional[str]) -> None: ) from mangum.backends.sqlite import SQLiteBackend - self._Backend = SQLiteBackend # type: ignore + self._Backend = SQLiteBackend elif scheme == "dynamodb": from mangum.backends.dynamodb import DynamoDBBackend - self._Backend = DynamoDBBackend # type: ignore + self._Backend = DynamoDBBackend elif scheme == "s3": from mangum.backends.s3 import S3Backend - self._Backend = S3Backend # type: ignore + self._Backend = S3Backend elif scheme in ("postgresql", "postgres"): from mangum.backends.postgresql import PostgreSQLBackend - self._Backend = PostgreSQLBackend # type: ignore + self._Backend = PostgreSQLBackend elif scheme == "redis": from mangum.backends.redis import RedisBackend - self._Backend = RedisBackend # type: ignore + self._Backend = RedisBackend else: raise ConfigurationError(f"{scheme} does not match a supported backend.") - self.dsn = dsn - self.logger.info("WebSocket backend connection established.") async def load_scope(self, backend: WebSocketBackend, connection_id: str) -> Scope: @@ -153,18 +152,18 @@ async def on_disconnect(self, connection_id: str) -> None: await backend.delete(connection_id) async def post_to_connection(self, connection_id: str, body: bytes) -> None: - async with httpx.AsyncClient() as client: + async with AsyncClient() as client: await self._post_to_connection(connection_id, client=client, body=body) async def delete_connection(self, connection_id: str) -> None: - async with httpx.AsyncClient() as client: + async with AsyncClient() as client: await self._request_to_connection("DELETE", connection_id, client=client) async def _post_to_connection( self, connection_id: str, *, - client: "httpx.AsyncClient", + client: "AsyncClient", body: bytes, ) -> None: # pragma: no cover response = await self._request_to_connection( @@ -181,9 +180,9 @@ async def _request_to_connection( method: str, connection_id: str, *, - client: "httpx.AsyncClient", + client: "AsyncClient", body: Optional[bytes] = None, - ) -> "httpx.Response": + ) -> "Response": loop = asyncio.get_event_loop() url = f"{self.api_gateway_endpoint_url}/{connection_id}" headers = await loop.run_in_executor( diff --git a/mangum/handlers/abstract_handler.py b/mangum/handlers/abstract_handler.py index 25c294d9..4c7ce7ab 100644 --- a/mangum/handlers/abstract_handler.py +++ b/mangum/handlers/abstract_handler.py @@ -13,7 +13,6 @@ def __init__( self, trigger_event: Dict[str, Any], trigger_context: "LambdaContext", - **kwargs: Dict[str, Any], ): self.trigger_event = trigger_event self.trigger_context = trigger_context @@ -62,7 +61,7 @@ def api_gateway_endpoint_url(self) -> str: def from_trigger( trigger_event: Dict[str, Any], trigger_context: "LambdaContext", - **kwargs: Dict[str, Any], + api_gateway_base_path: str = "/", ) -> "AbstractHandler": """ A factory method that determines which handler to use. All this code should @@ -77,7 +76,7 @@ def from_trigger( ): from . import AwsAlb - return AwsAlb(trigger_event, trigger_context, **kwargs) + return AwsAlb(trigger_event, trigger_context) if ( "requestContext" in trigger_event @@ -85,9 +84,7 @@ def from_trigger( ): from . import AwsWsGateway - return AwsWsGateway( - trigger_event, trigger_context, **kwargs # type: ignore - ) + return AwsWsGateway(trigger_event, trigger_context) if ( "Records" in trigger_event @@ -96,20 +93,24 @@ def from_trigger( ): from . import AwsCfLambdaAtEdge - return AwsCfLambdaAtEdge(trigger_event, trigger_context, **kwargs) + return AwsCfLambdaAtEdge(trigger_event, trigger_context) if "version" in trigger_event and "requestContext" in trigger_event: from . import AwsHttpGateway return AwsHttpGateway( - trigger_event, trigger_context, **kwargs # type: ignore + trigger_event, + trigger_context, + api_gateway_base_path, ) if "resource" in trigger_event: from . import AwsApiGateway return AwsApiGateway( - trigger_event, trigger_context, **kwargs # type: ignore + trigger_event, + trigger_context, + api_gateway_base_path, ) raise TypeError("Unable to determine handler from trigger event") diff --git a/mangum/handlers/aws_alb.py b/mangum/handlers/aws_alb.py index 4a918f00..f178c323 100644 --- a/mangum/handlers/aws_alb.py +++ b/mangum/handlers/aws_alb.py @@ -1,13 +1,15 @@ import base64 -import urllib.parse +from urllib.parse import urlencode, unquote, unquote_plus from typing import Any, Dict, Generator, List, Tuple from itertools import islice +from mangum.types import QueryParams + from .abstract_handler import AbstractHandler from .. import Response, Request -def all_casings(input_string: str) -> Generator: +def all_casings(input_string: str) -> Generator[str, None, None]: """ Permute all casings of a given string. A pretty algoritm, via @Amber @@ -28,7 +30,7 @@ def all_casings(input_string: str) -> Generator: def case_mutated_headers(multi_value_headers: Dict[str, List[str]]) -> Dict[str, str]: """Create str/str key/value headers, with duplicate keys case mutated.""" - headers = {} + headers: Dict[str, str] = {} for key, values in multi_value_headers.items(): if len(values) > 0: casings = list(islice(all_casings(key), len(values))) @@ -49,7 +51,7 @@ class AwsAlb(AbstractHandler): TYPE = "AWS_ALB" - def encode_query_string(self) -> bytes: + def _encode_query_string(self) -> bytes: """ Encodes the queryStringParameters. The parameters must be decoded, and then encoded again to prevent double @@ -62,28 +64,20 @@ def encode_query_string(self) -> bytes: Issue: https://github.com/jordaneremieff/mangum/issues/178 """ - params = self.trigger_event.get("multiValueQueryStringParameters") + params: QueryParams = self.trigger_event.get( + "multiValueQueryStringParameters", {} + ) if not params: - params = self.trigger_event.get("queryStringParameters") + params = self.trigger_event.get("queryStringParameters", {}) if not params: - return b"" # No query parameters, exit early with an empty byte string. - - # Loop through the query parameters, unquote each key and value and append the - # pair as a tuple to the query list. If value is a list or a tuple, loop - # through the nested struture and unqote. - query = [] - for key, value in params.items(): - if isinstance(value, (tuple, list)): - for v in value: - query.append( - (urllib.parse.unquote_plus(key), urllib.parse.unquote_plus(v)) - ) - else: - query.append( - (urllib.parse.unquote_plus(key), urllib.parse.unquote_plus(value)) - ) - - return urllib.parse.urlencode(query).encode() + return b"" + params = { + unquote_plus(key): unquote_plus(value) + if isinstance(value, str) + else tuple(unquote_plus(element) for element in value) + for key, value in params.items() + } + return urlencode(params, doseq=True).encode() def transform_headers(self) -> List[Tuple[bytes, bytes]]: """Convert headers to a list of two-tuples per ASGI spec. @@ -92,7 +86,7 @@ def transform_headers(self) -> List[Tuple[bytes, bytes]]: trigger event. However, we act as though they both might exist and pull headers out of both. """ - headers = [] + headers: List[Tuple[bytes, bytes]] = [] if "multiValueHeaders" in self.trigger_event: for k, v in self.trigger_event["multiValueHeaders"].items(): for inner_v in v: @@ -112,9 +106,9 @@ def request(self) -> Request: uq_headers = {k.decode(): v.decode() for k, v in headers} source_ip = uq_headers.get("x-forwarded-for", "") - path = event["path"] + path = unquote(event["path"]) if event["path"] else "/" http_method = event["httpMethod"] - query_string = self.encode_query_string() + query_string = self._encode_query_string() server_name = uq_headers.get("host", "mangum") if ":" not in server_name: @@ -124,13 +118,10 @@ def request(self) -> Request: server = (server_name, int(server_port)) client = (source_ip, 0) - if not path: - path = "/" - return Request( method=http_method, headers=list_headers, - path=urllib.parse.unquote(path), + path=path, scheme=uq_headers.get("x-forwarded-proto", "https"), query_string=query_string, server=server, diff --git a/mangum/handlers/aws_api_gateway.py b/mangum/handlers/aws_api_gateway.py index eb4a51f6..c20558f2 100644 --- a/mangum/handlers/aws_api_gateway.py +++ b/mangum/handlers/aws_api_gateway.py @@ -1,7 +1,9 @@ import base64 -import urllib.parse +from urllib.parse import urlencode, unquote from typing import Dict, Any, TYPE_CHECKING +from mangum.types import QueryParams + from .abstract_handler import AbstractHandler from .. import Response, Request @@ -24,45 +26,39 @@ def __init__( self, trigger_event: Dict[str, Any], trigger_context: "LambdaContext", - api_gateway_base_path: str = "/", - **kwargs: Dict[str, Any], # type: ignore + api_gateway_base_path: str, ): - super().__init__(trigger_event, trigger_context, **kwargs) + super().__init__(trigger_event, trigger_context) self.api_gateway_base_path = api_gateway_base_path @property def request(self) -> Request: event = self.trigger_event - # multiValue versions of headers take precedence over their plain versions - # https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format + # See this for more info on headers: + # https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#apigateway-multivalue-headers-and-parameters + headers = {} + # Read headers + if event.get("headers"): + headers.update({k.lower(): v for k, v in event.get("headers", {}).items()}) + # Read multiValueHeaders + # This overrides headers that have the same name + # That means that multiValue versions of headers take precedence + # over their plain versions if event.get("multiValueHeaders"): - headers = { - k.lower(): ", ".join(v) if isinstance(v, list) else "" - for k, v in event.get("multiValueHeaders", {}).items() - } - elif event.get("headers"): - headers = {k.lower(): v for k, v in event.get("headers", {}).items()} - else: - headers = {} + headers.update( + { + k.lower(): ", ".join(v) if isinstance(v, list) else "" + for k, v in event.get("multiValueHeaders", {}).items() + } + ) request_context = event["requestContext"] source_ip = request_context.get("identity", {}).get("sourceIp") - - path = event["path"] + path = unquote(self._strip_base_path(event["path"])) if event["path"] else "/" http_method = event["httpMethod"] - - if event.get("multiValueQueryStringParameters"): - query_string = urllib.parse.urlencode( - event.get("multiValueQueryStringParameters", {}), doseq=True - ).encode() - elif event.get("queryStringParameters"): - query_string = urllib.parse.urlencode( - event.get("queryStringParameters", {}) - ).encode() - else: - query_string = b"" + query_string = self._encode_query_string() server_name = headers.get("host", "mangum") if ":" not in server_name: @@ -72,15 +68,10 @@ def request(self) -> Request: server = (server_name, int(server_port)) client = (source_ip, 0) - if not path: - path = "/" - else: - path = self._strip_base_path(path) - return Request( method=http_method, headers=[[k.encode(), v.encode()] for k, v in headers.items()], - path=urllib.parse.unquote(path), + path=path, scheme=headers.get("x-forwarded-proto", "https"), query_string=query_string, server=server, @@ -90,6 +81,20 @@ def request(self) -> Request: event_type=self.TYPE, ) + def _encode_query_string(self) -> bytes: + """ + Encodes the queryStringParameters. + """ + + params: QueryParams = self.trigger_event.get( + "multiValueQueryStringParameters", {} + ) + if not params: + params = self.trigger_event.get("queryStringParameters", {}) + if not params: + return b"" + return urlencode(params, doseq=True).encode() + def _strip_base_path(self, path: str) -> str: if self.api_gateway_base_path and self.api_gateway_base_path != "/": if not self.api_gateway_base_path.startswith("/"): diff --git a/mangum/handlers/aws_http_gateway.py b/mangum/handlers/aws_http_gateway.py index 2d8982da..f6372082 100644 --- a/mangum/handlers/aws_http_gateway.py +++ b/mangum/handlers/aws_http_gateway.py @@ -1,6 +1,6 @@ import base64 import urllib.parse -from typing import Dict, Any +from typing import Dict, Any, List, Tuple from . import AwsApiGateway from .. import Response, Request @@ -52,24 +52,9 @@ def request(self) -> Request: ) source_ip = request_context.get("identity", {}).get("sourceIp") - path = event["path"] http_method = event["httpMethod"] - - # AWS Blog Post on this: - # https://aws.amazon.com/blogs/compute/support-for-multi-value-parameters-in-amazon-api-gateway/ # noqa: E501 - # A multi value param will be in multi value _and_ regular - # queryStringParameters. Multi value takes precedence. - if event.get("multiValueQueryStringParameters", False): - query_string = urllib.parse.urlencode( - event.get("multiValueQueryStringParameters", {}), doseq=True - ).encode() - elif event.get("queryStringParameters", False): - query_string = urllib.parse.urlencode( - event.get("queryStringParameters", {}) - ).encode() - else: - query_string = b"" + query_string = self._encode_query_string() else: raise RuntimeError( "Unsupported version of HTTP Gateway Spec, only v1.0 and v2.0 are " @@ -122,37 +107,65 @@ def transform_response(self, response: Response) -> Dict[str, Any]: https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#http-api-develop-integrations-lambda.response """ + if self.event_version == "1.0": + return self.transform_response_v1(response) + elif self.event_version == "2.0": + return self.transform_response_v2(response) + raise RuntimeError( # pragma: no cover + "Misconfigured event unable to return value, unsupported version." + ) + + def transform_response_v1(self, response: Response) -> Dict[str, Any]: headers, multi_value_headers = self._handle_multi_value_headers( response.headers ) - if self.event_version == "1.0": - body, is_base64_encoded = self._handle_base64_response_body( - response.body, headers - ) - return { - "statusCode": response.status, - "headers": headers, - "multiValueHeaders": multi_value_headers, - "body": body, - "isBase64Encoded": is_base64_encoded, - } - elif self.event_version == "2.0": - # The API Gateway will infer stuff for us, but we'll just do that inference - # here and keep the output consistent - if "content-type" not in headers and response.body is not None: - headers["content-type"] = "application/json" + body, is_base64_encoded = self._handle_base64_response_body( + response.body, headers + ) + return { + "statusCode": response.status, + "headers": headers, + "multiValueHeaders": multi_value_headers, + "body": body, + "isBase64Encoded": is_base64_encoded, + } + + def _combine_headers_v2( + self, input_headers: List[List[bytes]] + ) -> Tuple[Dict[str, str], List[str]]: + output_headers: Dict[str, str] = {} + cookies: List[str] = [] + for key, value in input_headers: + normalized_key: str = key.decode().lower() + normalized_value: str = value.decode() + if normalized_key == "set-cookie": + cookies.append(normalized_value) + else: + if normalized_key in output_headers: + normalized_value = ( + f"{output_headers[normalized_key]},{normalized_value}" + ) + output_headers[normalized_key] = normalized_value + return output_headers, cookies - body, is_base64_encoded = self._handle_base64_response_body( - response.body, headers - ) - return { - "statusCode": response.status, - "headers": headers, - "multiValueHeaders": multi_value_headers, - "body": body, - "isBase64Encoded": is_base64_encoded, - } - raise RuntimeError( # pragma: no cover - "Misconfigured event unable to return value, unsupported version." + def transform_response_v2(self, response_in: Response) -> Dict[str, Any]: + # The API Gateway will infer stuff for us, but we'll just do that inference + # here and keep the output consistent + + headers, cookies = self._combine_headers_v2(response_in.headers) + + if "content-type" not in headers and response_in.body is not None: + headers["content-type"] = "application/json" + + body, is_base64_encoded = self._handle_base64_response_body( + response_in.body, headers ) + response_out = { + "statusCode": response_in.status, + "body": body, + "headers": headers or None, + "cookies": cookies or None, + "isBase64Encoded": is_base64_encoded, + } + return {key: value for key, value in response_out.items() if value is not None} diff --git a/mangum/handlers/aws_ws_gateway.py b/mangum/handlers/aws_ws_gateway.py index a1dae734..3fdf6e97 100644 --- a/mangum/handlers/aws_ws_gateway.py +++ b/mangum/handlers/aws_ws_gateway.py @@ -6,7 +6,7 @@ from .abstract_handler import AbstractHandler -def get_server_and_headers(event: dict) -> Tuple: # pragma: no cover +def get_server_and_headers(event: Dict[str, Any]) -> Tuple: # pragma: no cover if event.get("multiValueHeaders"): headers = { k.lower(): ", ".join(v) if isinstance(v, list) else "" diff --git a/mangum/protocols/http.py b/mangum/protocols/http.py index e6da9058..5b33636e 100644 --- a/mangum/protocols/http.py +++ b/mangum/protocols/http.py @@ -51,7 +51,7 @@ class HTTPCycle: def __post_init__(self) -> None: self.logger: logging.Logger = logging.getLogger("mangum.http") self.loop = asyncio.get_event_loop() - self.app_queue: asyncio.Queue = asyncio.Queue() + self.app_queue: asyncio.Queue[Message] = asyncio.Queue() self.body: BytesIO = BytesIO() def __call__(self, app: ASGIApp, initial_body: bytes) -> Response: diff --git a/mangum/protocols/lifespan.py b/mangum/protocols/lifespan.py index 790277bb..4ceb91dd 100644 --- a/mangum/protocols/lifespan.py +++ b/mangum/protocols/lifespan.py @@ -1,7 +1,7 @@ import asyncio import logging -import types -import typing +from types import TracebackType +from typing import Optional, Type import enum from dataclasses import dataclass @@ -44,26 +44,25 @@ class LifespanCycle: and `off`. Default is `auto`. * **state** - An enumerated `LifespanCycleState` type that indicates the state of the ASGI connection. - * **exception** - An exception raised while handling the ASGI event. + * **exception** - An exception raised while handling the ASGI event. This may or + may not be raised depending on the state. * **app_queue** - An asyncio queue (FIFO) containing messages to be received by the application. * **startup_event** - An asyncio event object used to control the application startup flow. * **shutdown_event** - An asyncio event object used to control the application shutdown flow. - * **exception** - An exception raised while handling the ASGI event. This may or - may not be raised depending on the state. """ app: ASGIApp lifespan: str state: LifespanCycleState = LifespanCycleState.CONNECTING - exception: typing.Optional[BaseException] = None + exception: Optional[BaseException] = None def __post_init__(self) -> None: self.logger = logging.getLogger("mangum.lifespan") self.loop = asyncio.get_event_loop() - self.app_queue: asyncio.Queue = asyncio.Queue() + self.app_queue: asyncio.Queue[Message] = asyncio.Queue() self.startup_event: asyncio.Event = asyncio.Event() self.shutdown_event: asyncio.Event = asyncio.Event() @@ -76,9 +75,9 @@ def __enter__(self) -> None: def __exit__( self, - exc_type: typing.Optional[typing.Type[BaseException]], - exc_value: typing.Optional[BaseException], - traceback: typing.Optional[types.TracebackType], + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], ) -> None: """ Runs the event loop for application shutdown. diff --git a/mangum/protocols/websockets.py b/mangum/protocols/websockets.py index c305e3ce..ef15725b 100644 --- a/mangum/protocols/websockets.py +++ b/mangum/protocols/websockets.py @@ -1,5 +1,7 @@ import enum import asyncio +import copy +import typing import logging from io import BytesIO from dataclasses import dataclass @@ -61,7 +63,7 @@ class WebSocketCycle: def __post_init__(self) -> None: self.logger: logging.Logger = logging.getLogger("mangum.websocket") self.loop = asyncio.get_event_loop() - self.app_queue: asyncio.Queue = asyncio.Queue() + self.app_queue: asyncio.Queue[typing.Dict[str, typing.Any]] = asyncio.Queue() self.body: BytesIO = BytesIO() self.response: Response = Response(200, [], b"") @@ -93,7 +95,7 @@ async def run(self, app: ASGIApp) -> None: Calls the application with the `websocket` connection scope. """ self.scope = await self.websocket.on_message(self.connection_id) - scope = self.scope.copy() # type: ignore + scope = copy.copy(self.scope) scope.update( { "aws.event": self.request.trigger_event, diff --git a/mangum/types.py b/mangum/types.py index 3946f1c0..c1a42917 100644 --- a/mangum/types.py +++ b/mangum/types.py @@ -6,18 +6,19 @@ Any, Union, Optional, + Sequence, MutableMapping, Awaitable, Callable, TYPE_CHECKING, ) -from typing_extensions import Protocol +from typing_extensions import Protocol, TypeAlias - -Message = MutableMapping[str, Any] -Scope = MutableMapping[str, Any] -Receive = Callable[[], Awaitable[Message]] -Send = Callable[[Message], Awaitable[None]] +QueryParams: TypeAlias = MutableMapping[str, Union[str, Sequence[str]]] +Message: TypeAlias = MutableMapping[str, Any] +Scope: TypeAlias = MutableMapping[str, Any] +Receive: TypeAlias = Callable[[], Awaitable[Message]] +Send: TypeAlias = Callable[[Message], Awaitable[None]] if TYPE_CHECKING: # pragma: no cover @@ -53,6 +54,7 @@ class BaseRequest: root_path: str = "" asgi: Dict[str, str] = field(default_factory=lambda: {"version": "3.0"}) + @property def scope(self) -> Scope: return { "http_version": self.http_version, @@ -85,7 +87,7 @@ class Request(BaseRequest): @property def scope(self) -> Scope: - scope = super().scope() + scope = super().scope scope.update({"type": self.type, "method": self.method}) return scope @@ -103,7 +105,7 @@ class WsRequest(BaseRequest): @property def scope(self) -> Scope: - scope = super().scope() + scope = super().scope scope.update({"type": self.type, "subprotocols": self.subprotocols}) return scope diff --git a/requirements.txt b/requirements.txt index f409fbbd..ec8d79fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ pytest-cov black flake8 starlette -quart; python_version == '3.7' +quart; python_version >= '3.7' moto[server] mypy brotli diff --git a/setup.py b/setup.py index 3e32a15b..35951b22 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ def get_long_description(): setup( name="mangum", - version="0.12.3", + version="0.12.4", packages=find_packages(exclude=["tests*"]), license="MIT", url="https://github.com/jordaneremieff/mangum", @@ -26,5 +26,7 @@ def get_long_description(): "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], ) diff --git a/tests/conftest.py b/tests/conftest.py index c4ad042e..ab04eafc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,7 +65,7 @@ def mock_aws_api_gateway_event(request): @pytest.fixture -def mock_http_api_event(request): +def mock_http_api_event_v2(request): method = request.param[0] body = request.param[1] multi_value_query_parameters = request.param[2] @@ -120,6 +120,67 @@ def mock_http_api_event(request): return event +@pytest.fixture +def mock_http_api_event_v1(request): + method = request.param[0] + body = request.param[1] + multi_value_query_parameters = request.param[2] + query_string = request.param[3] + event = { + "version": "1.0", + "routeKey": "$default", + "rawPath": "/my/path", + "path": "/my/path", + "httpMethod": method, + "rawQueryString": query_string, + "cookies": ["cookie1", "cookie2"], + "headers": { + "accept-encoding": "gzip,deflate", + "x-forwarded-port": "443", + "x-forwarded-proto": "https", + "host": "test.execute-api.us-west-2.amazonaws.com", + }, + "queryStringParameters": { + k: v[-1] for k, v in multi_value_query_parameters.items() + } + if multi_value_query_parameters + else None, + "multiValueQueryStringParameters": { + k: v for k, v in multi_value_query_parameters.items() + } + if multi_value_query_parameters + else None, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authorizer": { + "jwt": { + "claims": {"claim1": "value1", "claim2": "value2"}, + "scopes": ["scope1", "scope2"], + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "protocol": "HTTP/1.1", + "sourceIp": "192.168.100.1", + "userAgent": "agent", + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390, + }, + "body": body, + "pathParameters": {"parameter1": "value1"}, + "isBase64Encoded": False, + "stageVariables": {"stageVariable1": "value1", "stageVariable2": "value2"}, + } + + return event + + @pytest.fixture def mock_lambda_at_edge_event(request): method = request.param[0] diff --git a/tests/handlers/test_aws_api_gateway.py b/tests/handlers/test_aws_api_gateway.py index 1ce16211..eac4a2d3 100644 --- a/tests/handlers/test_aws_api_gateway.py +++ b/tests/handlers/test_aws_api_gateway.py @@ -47,7 +47,9 @@ def get_mock_aws_api_gateway_event( "cognitoAuthenticationType": "", "cognitoAuthenticationProvider": "", "userArn": "", - "userAgent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/52.0.2743.82 Safari/537.36 OPR/39.0.2256.48", # noqa: E501 + "userAgent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_6) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/52.0.2743.82 Safari/537.36 OPR/39.0.2256.48", "user": "", }, "resourcePath": "/{proxy+}", @@ -75,15 +77,21 @@ def test_aws_api_gateway_scope_basic(): "httpMethod": "GET", "requestContext": {"resourcePath": "/", "httpMethod": "GET", "path": "/Prod/"}, "headers": { - "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", # noqa: E501 + "accept": "text/html,application/xhtml+xml,application/xml;" + "q=0.9,image/webp,image/apng,*/*;q=0.8," + "application/signed-exchange;v=b3;q=0.9", "accept-encoding": "gzip, deflate, br", "Host": "70ixmpl4fl.execute-api.us-east-2.amazonaws.com", - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.132 Safari/537.36", # noqa: E501 + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/80.0.3987.132 Safari/537.36", "X-Amzn-Trace-Id": "Root=1-5e66d96f-7491f09xmpl79d18acf3d050", }, "multiValueHeaders": { "accept": [ - "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" # noqa: E501 + "text/html,application/xhtml+xml,application/xml;" + "q=0.9,image/webp,image/apng,*/*;q=0.8," + "application/signed-exchange;v=b3;q=0.9" ], "accept-encoding": ["gzip, deflate, br"], }, @@ -95,7 +103,7 @@ def test_aws_api_gateway_scope_basic(): "isBase64Encoded": False, } example_context = {} - handler = AwsApiGateway(example_event, example_context) + handler = AwsApiGateway(example_event, example_context, "/") assert type(handler.body) == bytes assert handler.request.scope == { @@ -107,11 +115,20 @@ def test_aws_api_gateway_scope_basic(): "headers": [ [ b"accept", - b"text/html,application/xhtml+xml,application/xml;q=0.9," - b"image/webp,image/apng,*/*;q=0.8," + b"text/html,application/xhtml+xml,application/xml;" + b"q=0.9,image/webp,image/apng,*/*;q=0.8," b"application/signed-exchange;v=b3;q=0.9", ], [b"accept-encoding", b"gzip, deflate, br"], + [b"host", b"70ixmpl4fl.execute-api.us-east-2.amazonaws.com"], + [ + b"user-agent", + b"Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + b"AppleWebKit/537.36 (KHTML, like Gecko) " + b"Chrome/80.0.3987.132 " + b"Safari/537.36", + ], + [b"x-amzn-trace-id", b"Root=1-5e66d96f-7491f09xmpl79d18acf3d050"], ], "http_version": "1.1", "method": "GET", @@ -120,7 +137,7 @@ def test_aws_api_gateway_scope_basic(): "raw_path": None, "root_path": "", "scheme": "https", - "server": ("mangum", 80), + "server": ("70ixmpl4fl.execute-api.us-east-2.amazonaws.com", 80), "type": "http", } @@ -184,7 +201,7 @@ def test_aws_api_gateway_scope_real( method, path, multi_value_query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = AwsApiGateway(event, example_context) + handler = AwsApiGateway(event, example_context, "/") scope_path = path if scope_path == "": diff --git a/tests/handlers/test_aws_http_gateway.py b/tests/handlers/test_aws_http_gateway.py index 4e87f9df..bdedfc99 100644 --- a/tests/handlers/test_aws_http_gateway.py +++ b/tests/handlers/test_aws_http_gateway.py @@ -194,7 +194,7 @@ def test_aws_http_gateway_scope_basic_v1(): "isBase64Encoded": False, } example_context = {} - handler = AwsHttpGateway(example_event, example_context) + handler = AwsHttpGateway(example_event, example_context, "/") assert type(handler.body) == bytes assert handler.request.scope == { @@ -226,7 +226,7 @@ def test_aws_http_gateway_scope_v1_only_non_multi_headers(): ) del example_event["multiValueQueryStringParameters"] example_context = {} - handler = AwsHttpGateway(example_event, example_context) + handler = AwsHttpGateway(example_event, example_context, "/") assert handler.request.scope["query_string"] == b"hello=world" @@ -240,7 +240,7 @@ def test_aws_http_gateway_scope_v1_no_headers(): del example_event["multiValueQueryStringParameters"] del example_event["queryStringParameters"] example_context = {} - handler = AwsHttpGateway(example_event, example_context) + handler = AwsHttpGateway(example_event, example_context, "/") assert handler.request.scope["query_string"] == b"" @@ -298,7 +298,7 @@ def test_aws_http_gateway_scope_basic_v2(): "stageVariables": {"stageVariable1": "value1", "stageVariable2": "value2"}, } example_context = {} - handler = AwsHttpGateway(example_event, example_context) + handler = AwsHttpGateway(example_event, example_context, "/") assert type(handler.body) == bytes assert handler.request.scope == { @@ -334,7 +334,7 @@ def test_aws_http_gateway_scope_bad_version(): example_event = get_mock_aws_http_gateway_event_v2("GET", "/test", {}, None, False) example_event["version"] = "9001.1" example_context = {} - handler = AwsHttpGateway(example_event, example_context) + handler = AwsHttpGateway(example_event, example_context, "/") with pytest.raises(RuntimeError): handler.request.scope @@ -371,7 +371,7 @@ def test_aws_http_gateway_scope_real_v1( method, path, query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = AwsHttpGateway(event, example_context) + handler = AwsHttpGateway(event, example_context, "/") scope_path = path if scope_path == "": @@ -438,7 +438,7 @@ def test_aws_http_gateway_scope_real_v2( method, path, query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = AwsHttpGateway(event, example_context) + handler = AwsHttpGateway(event, example_context, "/") scope_path = path if scope_path == "": @@ -605,6 +605,5 @@ async def app(scope, receive, send): "statusCode": 200, "isBase64Encoded": res_base64_encoded, "headers": {"content-type": content_type.decode()}, - "multiValueHeaders": {}, "body": res_body, } diff --git a/tests/test_adapter.py b/tests/test_adapter.py index c2cffa1a..6a014b34 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -11,6 +11,7 @@ async def app(scope, receive, send): def test_default_settings(): handler = Mangum(app) assert handler.lifespan == "auto" + assert handler.api_gateway_base_path == "/" @pytest.mark.parametrize( diff --git a/tests/test_http.py b/tests/test_http.py index c82efd8c..4328aba3 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -253,7 +253,7 @@ async def app(scope, receive, send): @pytest.mark.parametrize( - "mock_http_api_event", + "mock_http_api_event_v2", [ (["GET", None, None, ""]), (["GET", None, {"name": ["me"]}, "name=me"]), @@ -267,9 +267,9 @@ async def app(scope, receive, send): ] ), ], - indirect=["mock_http_api_event"], + indirect=["mock_http_api_event_v2"], ) -def test_set_cookies(mock_http_api_event) -> None: +def test_set_cookies_v2(mock_http_api_event_v2) -> None: async def app(scope, receive, send): assert scope == { "asgi": {"version": "3.0"}, @@ -279,7 +279,7 @@ async def app(scope, receive, send): "version": "2.0", "routeKey": "$default", "rawPath": "/my/path", - "rawQueryString": mock_http_api_event["rawQueryString"], + "rawQueryString": mock_http_api_event_v2["rawQueryString"], "cookies": ["cookie1", "cookie2"], "headers": { "accept-encoding": "gzip,deflate", @@ -287,7 +287,9 @@ async def app(scope, receive, send): "x-forwarded-proto": "https", "host": "test.execute-api.us-west-2.amazonaws.com", }, - "queryStringParameters": mock_http_api_event["queryStringParameters"], + "queryStringParameters": mock_http_api_event_v2[ + "queryStringParameters" + ], "requestContext": { "accountId": "123456789012", "apiId": "api-id", @@ -331,7 +333,127 @@ async def app(scope, receive, send): "http_version": "1.1", "method": "GET", "path": "/my/path", - "query_string": mock_http_api_event["rawQueryString"].encode(), + "query_string": mock_http_api_event_v2["rawQueryString"].encode(), + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("test.execute-api.us-west-2.amazonaws.com", 443), + "type": "http", + } + + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"content-type", b"text/plain; charset=utf-8"], + [b"set-cookie", b"cookie1=cookie1; Secure"], + [b"set-cookie", b"cookie2=cookie2; Secure"], + [b"multivalue", b"foo"], + [b"multivalue", b"bar"], + ], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + handler = Mangum(app, lifespan="off") + response = handler(mock_http_api_event_v2, {}) + assert response == { + "statusCode": 200, + "isBase64Encoded": False, + "headers": { + "content-type": "text/plain; charset=utf-8", + "multivalue": "foo,bar", + }, + "cookies": ["cookie1=cookie1; Secure", "cookie2=cookie2; Secure"], + "body": "Hello, world!", + } + + +@pytest.mark.parametrize( + "mock_http_api_event_v1", + [ + (["GET", None, None, ""]), + (["GET", None, {"name": ["me"]}, "name=me"]), + (["GET", None, {"name": ["me", "you"]}, "name=me&name=you"]), + ( + [ + "GET", + None, + {"name": ["me", "you"], "pet": ["dog"]}, + "name=me&name=you&pet=dog", + ] + ), + ], + indirect=["mock_http_api_event_v1"], +) +def test_set_cookies_v1(mock_http_api_event_v1) -> None: + async def app(scope, receive, send): + assert scope == { + "asgi": {"version": "3.0"}, + "aws.eventType": "AWS_HTTP_GATEWAY", + "aws.context": {}, + "aws.event": { + "version": "1.0", + "routeKey": "$default", + "rawPath": "/my/path", + "path": "/my/path", + "httpMethod": "GET", + "rawQueryString": mock_http_api_event_v1["rawQueryString"], + "cookies": ["cookie1", "cookie2"], + "headers": { + "accept-encoding": "gzip,deflate", + "x-forwarded-port": "443", + "x-forwarded-proto": "https", + "host": "test.execute-api.us-west-2.amazonaws.com", + }, + "queryStringParameters": mock_http_api_event_v1[ + "queryStringParameters" + ], + "multiValueQueryStringParameters": mock_http_api_event_v1[ + "multiValueQueryStringParameters" + ], + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authorizer": { + "jwt": { + "claims": {"claim1": "value1", "claim2": "value2"}, + "scopes": ["scope1", "scope2"], + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "protocol": "HTTP/1.1", + "sourceIp": "192.168.100.1", + "userAgent": "agent", + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1_583_348_638_390, + }, + "body": None, + "pathParameters": {"parameter1": "value1"}, + "isBase64Encoded": False, + "stageVariables": { + "stageVariable1": "value1", + "stageVariable2": "value2", + }, + }, + "client": (None, 0), + "headers": [ + [b"accept-encoding", b"gzip,deflate"], + [b"x-forwarded-port", b"443"], + [b"x-forwarded-proto", b"https"], + [b"host", b"test.execute-api.us-west-2.amazonaws.com"], + ], + "http_version": "1.1", + "method": "GET", + "path": "/my/path", + "query_string": mock_http_api_event_v1["rawQueryString"].encode(), "raw_path": None, "root_path": "", "scheme": "https", @@ -353,7 +475,7 @@ async def app(scope, receive, send): await send({"type": "http.response.body", "body": b"Hello, world!"}) handler = Mangum(app, lifespan="off") - response = handler(mock_http_api_event, {}) + response = handler(mock_http_api_event_v1, {}) assert response == { "statusCode": 200, "isBase64Encoded": False, diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py index af02b651..7114c24c 100644 --- a/tests/test_lifespan.py +++ b/tests/test_lifespan.py @@ -8,11 +8,10 @@ from mangum import Mangum from mangum.exceptions import LifespanFailure -# One (or more) of Quart's dependencies does not support Python 3.8, ignore this case. -IS_PY38 = sys.version_info[:2] == (3, 8) +# Quart no longer support python3.6. IS_PY36 = sys.version_info[:2] == (3, 6) -if not (IS_PY38 or IS_PY36): +if not IS_PY36: from quart import Quart else: Quart = None @@ -269,9 +268,6 @@ def homepage(request): } -@pytest.mark.skipif( - IS_PY38, reason="One (or more) of Quart's dependencies does not support Python 3.8." -) @pytest.mark.skipif(IS_PY36, reason="Quart does not support Python 3.6.") @pytest.mark.parametrize( "mock_aws_api_gateway_event", [["GET", None, None]], indirect=True