Skip to content

Commit

Permalink
✨ Allows Flama application to decide which schema lib to use
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Jan 19, 2023
1 parent ae6e8cf commit e20a3f9
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 19 deletions.
28 changes: 27 additions & 1 deletion flama/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from flama.components import Component, Components
from flama.http import Request, Response
from flama.modules import Module
from flama.routing import BaseRoute, Mount
from flama.routing import BaseRoute, Mount, WebSocketRoute

__all__ = ["Flama"]

Expand All @@ -50,6 +50,7 @@ def __init__(
schema: typing.Optional[str] = "/schema/",
docs: typing.Optional[str] = "/docs/",
redoc: typing.Optional[str] = None,
schema_library: typing.Optional[str] = None,
*args,
**kwargs
) -> None:
Expand Down Expand Up @@ -93,6 +94,9 @@ def __init__(
},
)

# Setup schema library
self.modules.schema.set_schema_library(schema_library) # type: ignore[attr-defined]

# Reference to paginator from within app
self.paginator = paginator

Expand All @@ -102,6 +106,28 @@ def __getattr__(self, item: str) -> "Module":
except KeyError:
return None # type: ignore[return-value]

def add_route( # type: ignore[override]
self,
path: typing.Optional[str] = None,
endpoint: typing.Optional[typing.Callable] = None,
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
include_in_schema: bool = True,
route: typing.Optional["BaseRoute"] = None,
) -> None: # pragma: no cover
self.router.add_route(
path, endpoint, methods=methods, name=name, include_in_schema=include_in_schema, route=route
)

def add_websocket_route( # type: ignore[override]
self,
path: typing.Optional[str] = None,
endpoint: typing.Optional[typing.Callable] = None,
name: typing.Optional[str] = None,
route: typing.Optional["WebSocketRoute"] = None,
) -> None: # pragma: no cover
self.router.add_websocket_route(path, endpoint, name=name, route=route)

@property
def injector(self) -> Injector:
return Injector(self.components)
Expand Down
33 changes: 26 additions & 7 deletions flama/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,27 +276,33 @@ def mount(self, path: str, app: ASGIApp, name: str = None) -> None:

def add_route(
self,
path: str,
endpoint: typing.Callable,
path: typing.Optional[str] = None,
endpoint: typing.Optional[typing.Callable] = None,
methods: typing.List[str] = None,
name: str = None,
include_in_schema: bool = True,
route: BaseRoute = None,
):
try:
main_app = self.main_app
except AttributeError:
main_app = None

self.routes.append(
Route(
if path is not None and endpoint is not None:
route = Route(
path,
endpoint=endpoint,
methods=methods,
name=name,
include_in_schema=include_in_schema,
main_app=main_app,
)
)
elif route is not None:
route.main_app = main_app
else:
raise ValueError("Either 'path' and 'endpoint' or 'route' variables are needed")

self.routes.append(route)

def route(
self, path: str, methods: typing.List[str] = None, name: str = None, include_in_schema: bool = True
Expand All @@ -307,13 +313,26 @@ def decorator(func: typing.Callable) -> typing.Callable:

return decorator

def add_websocket_route(self, path: str, endpoint: typing.Callable, name: str = None):
def add_websocket_route(
self,
path: typing.Optional[str] = None,
endpoint: typing.Optional[typing.Callable] = None,
name: str = None,
route: typing.Optional[WebSocketRoute] = None,
):
try:
main_app = self.main_app
except AttributeError:
main_app = None

self.routes.append(WebSocketRoute(path, endpoint=endpoint, name=name, main_app=main_app))
if path is not None and endpoint is not None:
route = WebSocketRoute(path, endpoint=endpoint, name=name, main_app=main_app)
elif route is not None:
route.main_app = main_app
else:
raise ValueError("Either 'path' and 'endpoint' or 'route' variables are needed")

self.routes.append(route)

def websocket_route(self, path: str, name: str = None) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
Expand Down
9 changes: 6 additions & 3 deletions flama/schemas/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
def schema_view():
return OpenAPIResponse(self.schema)

self.app.add_route(path=schema_url, route=schema_view, methods=["GET"], include_in_schema=False)
self.app.add_route(schema_url, schema_view, methods=["GET"], include_in_schema=False)

# Adds swagger ui endpoint
if docs:
Expand All @@ -58,7 +58,7 @@ def swagger_ui() -> HTMLResponse:

return HTMLResponse(content)

self.app.add_route(path=docs_url, route=swagger_ui, methods=["GET"], include_in_schema=False)
self.app.add_route(docs_url, swagger_ui, methods=["GET"], include_in_schema=False)

# Adds redoc endpoint
if redoc:
Expand All @@ -70,7 +70,7 @@ def redoc_view() -> HTMLResponse:

return HTMLResponse(content)

self.app.add_route(path=redoc_url, route=redoc_view, methods=["GET"], include_in_schema=False)
self.app.add_route(redoc_url, redoc_view, methods=["GET"], include_in_schema=False)

def register_schema(self, name: str, schema):
self.schemas[name] = schema
Expand All @@ -85,3 +85,6 @@ def schema_generator(self) -> SchemaGenerator:
@property
def schema(self) -> typing.Dict[str, typing.Any]:
return self.schema_generator.get_api_schema(self.app.routes)

def set_schema_library(self, library: str):
schemas._module.setup(library)
5 changes: 2 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import typesystem
from faker import Faker

from flama import Flama, schemas
from flama import Flama
from flama.sqlalchemy import metadata
from flama.testclient import TestClient

Expand Down Expand Up @@ -88,8 +88,6 @@ def clear_metadata():
params=[pytest.param("typesystem", id="typesystem"), pytest.param("marshmallow", id="marshmallow")],
)
def app(request):
schemas._module.setup(request.param)

return Flama(
components=[],
title="Foo",
Expand All @@ -99,6 +97,7 @@ def app(request):
docs="/docs/",
redoc="/redoc/",
sqlalchemy_database="sqlite+aiosqlite://",
schema_library=request.param,
)


Expand Down
6 changes: 5 additions & 1 deletion tests/schemas/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def on_receive(self, websocket: websockets.WebSocket, data: websockets.Data) ->

@pytest.fixture(autouse=True)
def app(self, app, component, route, endpoint, websocket):
return Flama(routes=[route, endpoint, websocket], components=[component()], schema=None, docs=None)
app.add_component(component())
app.add_route(route=route)
app.add_route(route=endpoint)
app.add_websocket_route(route=websocket)
return app

def test_inspect_parameters_from_handler(self, route, app, foo_schema):
expected_parameters = {
Expand Down
19 changes: 15 additions & 4 deletions tests/schemas/test_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,34 @@
import pytest
import typesystem

from flama import schemas
from flama import Flama


class TestCaseSetup:
def test_setup_default(self):
schemas._module.setup()
Flama()

from flama import schemas

assert schemas.lib == typesystem

def test_setup_typesystem(self):
schemas._module.setup("typesystem")
Flama(schema_library="typesystem")

from flama import schemas

assert schemas.lib == typesystem

def test_setup_marshmallow(self):
schemas._module.setup("marshmallow")
Flama(schema_library="marshmallow")

from flama import schemas

assert schemas.lib == marshmallow

def test_setup_no_lib_installed(self):
from flama import schemas

with patch("flama.schemas.Module.available", PropertyMock(return_value=iter(()))), pytest.raises(
AssertionError,
match="No schema library is installed. Install one of your preference following instructions from: "
Expand Down

0 comments on commit e20a3f9

Please sign in to comment.