Skip to content

Commit

Permalink
Handle filestream json objects in multipart/form-data requests
Browse files Browse the repository at this point in the history
  • Loading branch information
cgearing committed Jan 12, 2024
1 parent 7b123da commit 7ab5ad8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
17 changes: 13 additions & 4 deletions flask_pydantic_spec/flask_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,28 @@ def request_validation(
else:
parsed_body = request.get_json(silent=True) or {}
elif request.content_type and "multipart/form-data" in request.content_type:
parsed_body = parse_multi_dict(request.form) if request.form else {}
# It's possible there is a binary json object in the files - iterate through and find it
parsed_body = {}
for key, value in request.files.items():
if value.mimetype == "application/json":
parsed_body[key] = json.loads(value.stream.read().decode(encoding="utf-8"))
# Finally, find any JSON objects in the form and add them to the body
parsed_body.update(parse_multi_dict(request.form) or {})
else:
parsed_body = request.get_data() or {}

req_headers: Optional[Headers] = request.headers or None
req_cookies: Optional[Mapping[str, str]] = request.cookies or None
setattr(
request,
"context",
Context(
query=query.parse_obj(req_query) if query else None,
body=getattr(body, "model").parse_obj(parsed_body)
if body and getattr(body, "model")
else None,
body=(
getattr(body, "model").parse_obj(parsed_body)
if body and getattr(body, "model")
else None
),
headers=headers.parse_obj(req_headers or {}) if headers else None,
cookies=cookies.parse_obj(req_cookies or {}) if cookies else None,
),
Expand Down
31 changes: 21 additions & 10 deletions tests/test_plugin_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from io import BytesIO
from random import randint
import gzip
from typing import Union

import pytest
import json
from flask import Flask, jsonify, request
from werkzeug.datastructures import FileStorage
from werkzeug.test import Client

from flask_pydantic_spec.types import Response, MultipartFormRequest
from flask_pydantic_spec import FlaskPydanticSpec
Expand Down Expand Up @@ -111,7 +114,7 @@ def client(request):


@pytest.mark.parametrize("client", [422], indirect=True)
def test_flask_validate(client):
def test_flask_validate(client: Client):
resp = client.get("/ping")
assert resp.status_code == 422
assert resp.headers.get("X-Error") == "Validation Error"
Expand Down Expand Up @@ -158,23 +161,31 @@ def test_flask_validate(client):


@pytest.mark.parametrize("client", [422], indirect=True)
def test_sending_file(client):
@pytest.mark.parametrize(
"data",
[
FileStorage(
BytesIO(json.dumps({"type": "foo", "created_at": str(datetime.now().date())}).encode()),
),
json.dumps({"type": "foo", "created_at": str(datetime.now().date())}),
],
)
def test_sending_file(client: Client, data: Union[FileStorage, str]):
file = FileStorage(BytesIO(b"abcde"), filename="test.jpg", name="test.jpg")
resp = client.post(
"/api/file",
data={
"file": file,
"file_name": "another_test.jpg",
"data": json.dumps({"type": "foo", "created_at": str(datetime.now().date())}),
"data": data,
},
content_type="multipart/form-data",
)
assert resp.status_code == 200
assert resp.json["name"] == "another_test.jpg"


@pytest.mark.parametrize("client", [422], indirect=True)
def test_query_params(client):
def test_query_params(client: Client):
resp = client.get("api/user?name=james&name=bethany&name=claire")
assert resp.status_code == 200
assert len(resp.json["data"]) == 2
Expand All @@ -189,15 +200,15 @@ def test_query_params(client):


@pytest.mark.parametrize("client", [200], indirect=True)
def test_flask_skip_validation(client):
def test_flask_skip_validation(client: Client):
resp = client.get("api/group/test")
assert resp.status_code == 200
assert resp.json["name"] == "test"
assert resp.json["score"] == ["a", "b", "c", "d", "e"]


@pytest.mark.parametrize("client", [422], indirect=True)
def test_flask_doc(client):
def test_flask_doc(client: Client):
resp = client.get("/apidoc/openapi.json")
assert resp.json == api.spec

Expand All @@ -211,7 +222,7 @@ def test_flask_doc(client):


@pytest.mark.parametrize("client", [400], indirect=True)
def test_flask_validate_with_alternative_code(client):
def test_flask_validate_with_alternative_code(client: Client):
resp = client.get("/ping")
assert resp.status_code == 400
assert resp.headers.get("X-Error") == "Validation Error"
Expand All @@ -222,7 +233,7 @@ def test_flask_validate_with_alternative_code(client):


@pytest.mark.parametrize("client", [400], indirect=True)
def test_flask_post_gzip(client):
def test_flask_post_gzip(client: Client):
body = dict(name="flask", limit=10)
compressed = gzip.compress(bytes(json.dumps(body), encoding="utf-8"))

Expand All @@ -240,7 +251,7 @@ def test_flask_post_gzip(client):


@pytest.mark.parametrize("client", [400], indirect=True)
def test_flask_post_gzip_failure(client):
def test_flask_post_gzip_failure(client: Client):
body = dict(name="flask")
compressed = gzip.compress(bytes(json.dumps(body), encoding="utf-8"))

Expand Down

0 comments on commit 7ab5ad8

Please sign in to comment.