Skip to content

Commit

Permalink
✨ ASGI Cookies component
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy authored and migduroli committed Sep 3, 2024
1 parent 4f47f49 commit 0695c77
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 3 deletions.
17 changes: 14 additions & 3 deletions flama/asgi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from http.cookies import SimpleCookie
from urllib.parse import parse_qsl

from flama import types
from flama import http, types
from flama.injection.components import Component, Components

__all__ = [
Expand Down Expand Up @@ -66,8 +67,17 @@ def resolve(self, scope: types.Scope) -> types.QueryParams:


class HeadersComponent(Component):
def resolve(self, scope: types.Scope) -> types.Headers:
return types.Headers(scope=scope)
def resolve(self, request: http.Request) -> types.Headers:
return request.headers


class CookiesComponent(Component):
def resolve(self, headers: types.Headers) -> types.Cookies:
cookie = SimpleCookie()
cookie.load(headers.get("cookie", ""))
return types.Cookies(
{str(name): {str(k): str(v) for k, v in morsel.items()} for name, morsel in cookie.items()}
)


class BodyComponent(Component):
Expand Down Expand Up @@ -95,6 +105,7 @@ async def resolve(self, receive: types.Receive) -> types.Body:
QueryStringComponent(),
QueryParamsComponent(),
HeadersComponent(),
CookiesComponent(),
BodyComponent(),
]
)
2 changes: 2 additions & 0 deletions flama/types/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"URL",
"Headers",
"MutableHeaders",
"Cookies",
"QueryParams",
"PARAMETERS_TYPES",
]
Expand All @@ -41,6 +42,7 @@
RequestData = t.NewType("RequestData", t.Dict[str, t.Any])
Headers = starlette.datastructures.Headers
MutableHeaders = starlette.datastructures.MutableHeaders
Cookies = t.NewType("Cookies", t.Dict[str, t.Dict[str, str]])
QueryParams = starlette.datastructures.QueryParams

PARAMETERS_TYPES: t.Dict[t.Type, t.Type] = {
Expand Down
99 changes: 99 additions & 0 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import http.cookiejar

import pytest

from flama import types
Expand Down Expand Up @@ -301,6 +303,103 @@ async def test_headers(self, client, path, method, request_kwargs, expected):
assert response_json == expected


class TestCaseCookiesComponent:
@pytest.fixture(scope="function", autouse=True)
def add_endpoints(self, app):
@app.route("/cookies/", methods=["GET", "POST"])
def get_cookies(cookies: types.Cookies):
return {"cookies": dict(cookies)}

@pytest.mark.parametrize(
["path", "method", "cookies", "expected"],
[
pytest.param(
"http://example.com/cookies/",
"get",
{},
{"cookies": {}},
id="default",
),
pytest.param(
"http://example.com/cookies/",
"get",
[
http.cookiejar.Cookie(
version=0,
name="foo",
value="bar",
port=None,
port_specified=False,
domain="",
domain_specified=False,
domain_initial_dot=False,
path="/",
path_specified=True,
secure=False,
expires=None,
discard=True,
comment=None,
comment_url=None,
rest={"HttpOnly": ""},
rfc2109=False,
)
],
{
"cookies": {
"foo": {
"expires": "",
"path": "",
"comment": "",
"domain": "",
"max-age": "",
"secure": "",
"httponly": "",
"version": "",
"samesite": "",
}
}
},
id="cookie",
),
pytest.param(
"http://example.com/cookies/",
"get",
[
http.cookiejar.Cookie(
version=0,
name="foo",
value="bar",
port=None,
port_specified=False,
domain="",
domain_specified=False,
domain_initial_dot=False,
path="/",
path_specified=True,
secure=True,
expires=None,
discard=True,
comment=None,
comment_url=None,
rest={"HttpOnly": "true"},
rfc2109=False,
)
],
{"cookies": {}}, # Cannot get cookie because secure and no https
id="cookie_secure",
),
],
)
async def test_cookies(self, client, path, method, cookies, expected):
cookies_jar = http.cookiejar.CookieJar()
for cookie in cookies:
cookies_jar.set_cookie(cookie)
response = await client.request(method, path, cookies=cookies_jar)
response_json = response.json()

assert response_json == expected


class TestCaseBodyComponent:
@pytest.fixture(scope="function", autouse=True)
def add_endpoints(self, app):
Expand Down

0 comments on commit 0695c77

Please sign in to comment.