diff --git a/flama/applications.py b/flama/applications.py index 5c464f50..beda6485 100644 --- a/flama/applications.py +++ b/flama/applications.py @@ -24,7 +24,7 @@ from flama.components import Component, Components from flama.http import Request, Response from flama.modules import Module - from flama.routing import BaseRoute, Mount + from flama.routing import BaseRoute, Mount, WebSocketRoute __all__ = ["Flama"] @@ -50,6 +50,7 @@ def __init__( schema: typing.Optional[str] = "/schema/", docs: typing.Optional[str] = "/docs/", redoc: typing.Optional[str] = None, + schema_library: typing.Optional[str] = None, *args, **kwargs ) -> None: @@ -93,6 +94,9 @@ def __init__( }, ) + # Setup schema library + self.modules.schema.set_schema_library(schema_library) # type: ignore[attr-defined] + # Reference to paginator from within app self.paginator = paginator @@ -102,6 +106,28 @@ def __getattr__(self, item: str) -> "Module": except KeyError: return None # type: ignore[return-value] + def add_route( # type: ignore[override] + self, + path: typing.Optional[str] = None, + endpoint: typing.Optional[typing.Callable] = None, + methods: typing.Optional[typing.List[str]] = None, + name: typing.Optional[str] = None, + include_in_schema: bool = True, + route: typing.Optional["BaseRoute"] = None, + ) -> None: # pragma: no cover + self.router.add_route( + path, endpoint, methods=methods, name=name, include_in_schema=include_in_schema, route=route + ) + + def add_websocket_route( # type: ignore[override] + self, + path: typing.Optional[str] = None, + endpoint: typing.Optional[typing.Callable] = None, + name: typing.Optional[str] = None, + route: typing.Optional["WebSocketRoute"] = None, + ) -> None: # pragma: no cover + self.router.add_websocket_route(path, endpoint, name=name, route=route) + @property def injector(self) -> Injector: return Injector(self.components) diff --git a/flama/routing.py b/flama/routing.py index dc5baead..f6a4d25d 100644 --- a/flama/routing.py +++ b/flama/routing.py @@ -276,19 +276,20 @@ def mount(self, path: str, app: ASGIApp, name: str = None) -> None: def add_route( self, - path: str, - endpoint: typing.Callable, + path: typing.Optional[str] = None, + endpoint: typing.Optional[typing.Callable] = None, methods: typing.List[str] = None, name: str = None, include_in_schema: bool = True, + route: BaseRoute = None, ): try: main_app = self.main_app except AttributeError: main_app = None - self.routes.append( - Route( + if path is not None and endpoint is not None: + route = Route( path, endpoint=endpoint, methods=methods, @@ -296,7 +297,12 @@ def add_route( include_in_schema=include_in_schema, main_app=main_app, ) - ) + elif route is not None: + route.main_app = main_app + else: + raise ValueError("Either 'path' and 'endpoint' or 'route' variables are needed") + + self.routes.append(route) def route( self, path: str, methods: typing.List[str] = None, name: str = None, include_in_schema: bool = True @@ -307,13 +313,26 @@ def decorator(func: typing.Callable) -> typing.Callable: return decorator - def add_websocket_route(self, path: str, endpoint: typing.Callable, name: str = None): + def add_websocket_route( + self, + path: typing.Optional[str] = None, + endpoint: typing.Optional[typing.Callable] = None, + name: str = None, + route: typing.Optional[WebSocketRoute] = None, + ): try: main_app = self.main_app except AttributeError: main_app = None - self.routes.append(WebSocketRoute(path, endpoint=endpoint, name=name, main_app=main_app)) + if path is not None and endpoint is not None: + route = WebSocketRoute(path, endpoint=endpoint, name=name, main_app=main_app) + elif route is not None: + route.main_app = main_app + else: + raise ValueError("Either 'path' and 'endpoint' or 'route' variables are needed") + + self.routes.append(route) def websocket_route(self, path: str, name: str = None) -> typing.Callable: def decorator(func: typing.Callable) -> typing.Callable: diff --git a/flama/schemas/modules.py b/flama/schemas/modules.py index 0f2b4e57..94984c4b 100644 --- a/flama/schemas/modules.py +++ b/flama/schemas/modules.py @@ -46,7 +46,7 @@ def __init__( def schema_view(): return OpenAPIResponse(self.schema) - self.app.add_route(path=schema_url, route=schema_view, methods=["GET"], include_in_schema=False) + self.app.add_route(schema_url, schema_view, methods=["GET"], include_in_schema=False) # Adds swagger ui endpoint if docs: @@ -58,7 +58,7 @@ def swagger_ui() -> HTMLResponse: return HTMLResponse(content) - self.app.add_route(path=docs_url, route=swagger_ui, methods=["GET"], include_in_schema=False) + self.app.add_route(docs_url, swagger_ui, methods=["GET"], include_in_schema=False) # Adds redoc endpoint if redoc: @@ -70,7 +70,7 @@ def redoc_view() -> HTMLResponse: return HTMLResponse(content) - self.app.add_route(path=redoc_url, route=redoc_view, methods=["GET"], include_in_schema=False) + self.app.add_route(redoc_url, redoc_view, methods=["GET"], include_in_schema=False) def register_schema(self, name: str, schema): self.schemas[name] = schema @@ -85,3 +85,6 @@ def schema_generator(self) -> SchemaGenerator: @property def schema(self) -> typing.Dict[str, typing.Any]: return self.schema_generator.get_api_schema(self.app.routes) + + def set_schema_library(self, library: str): + schemas._module.setup(library) diff --git a/tests/conftest.py b/tests/conftest.py index 2483a869..77e1625e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ import typesystem from faker import Faker -from flama import Flama, schemas +from flama import Flama from flama.sqlalchemy import metadata from flama.testclient import TestClient @@ -88,8 +88,6 @@ def clear_metadata(): params=[pytest.param("typesystem", id="typesystem"), pytest.param("marshmallow", id="marshmallow")], ) def app(request): - schemas._module.setup(request.param) - return Flama( components=[], title="Foo", @@ -99,6 +97,7 @@ def app(request): docs="/docs/", redoc="/redoc/", sqlalchemy_database="sqlite+aiosqlite://", + schema_library=request.param, ) diff --git a/tests/schemas/test_routing.py b/tests/schemas/test_routing.py index 61f61670..a341bfc1 100644 --- a/tests/schemas/test_routing.py +++ b/tests/schemas/test_routing.py @@ -68,7 +68,11 @@ def on_receive(self, websocket: websockets.WebSocket, data: websockets.Data) -> @pytest.fixture(autouse=True) def app(self, app, component, route, endpoint, websocket): - return Flama(routes=[route, endpoint, websocket], components=[component()], schema=None, docs=None) + app.add_component(component()) + app.add_route(route=route) + app.add_route(route=endpoint) + app.add_websocket_route(route=websocket) + return app def test_inspect_parameters_from_handler(self, route, app, foo_schema): expected_parameters = { diff --git a/tests/schemas/test_setup.py b/tests/schemas/test_setup.py index 2be7e038..23d0b6c7 100644 --- a/tests/schemas/test_setup.py +++ b/tests/schemas/test_setup.py @@ -4,23 +4,34 @@ import pytest import typesystem -from flama import schemas +from flama import Flama class TestCaseSetup: def test_setup_default(self): - schemas._module.setup() + Flama() + + from flama import schemas + assert schemas.lib == typesystem def test_setup_typesystem(self): - schemas._module.setup("typesystem") + Flama(schema_library="typesystem") + + from flama import schemas + assert schemas.lib == typesystem def test_setup_marshmallow(self): - schemas._module.setup("marshmallow") + Flama(schema_library="marshmallow") + + from flama import schemas + assert schemas.lib == marshmallow def test_setup_no_lib_installed(self): + from flama import schemas + with patch("flama.schemas.Module.available", PropertyMock(return_value=iter(()))), pytest.raises( AssertionError, match="No schema library is installed. Install one of your preference following instructions from: "