From a66d5ee55b3848b65c497d56887867b199ddba1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Antonio=20Perdiguero=20L=C3=B3pez?= Date: Sun, 8 Oct 2023 10:14:42 +0200 Subject: [PATCH] :sparkles: Enhanced actions for DDD repositories (#124) --- flama/client.py | 2 +- flama/ddd/exceptions.py | 6 +- flama/ddd/repositories.py | 245 ++++++++++++++++++++------------- flama/resources/crud.py | 24 ++-- tests/ddd/test_repositories.py | 93 +++++++------ tests/resources/test_crud.py | 2 +- 6 files changed, 222 insertions(+), 150 deletions(-) diff --git a/flama/client.py b/flama/client.py index 42ab49d5..fcc05a60 100644 --- a/flama/client.py +++ b/flama/client.py @@ -102,7 +102,7 @@ def __init__( if models: app = Flama() if not app else app - for (name, url, path) in models: + for name, url, path in models: app.models.add_model(url, path, name) self.models = {m[0]: m[1] for m in models or {}} diff --git a/flama/ddd/exceptions.py b/flama/ddd/exceptions.py index fa2583a0..1b8d4d2a 100644 --- a/flama/ddd/exceptions.py +++ b/flama/ddd/exceptions.py @@ -1,4 +1,4 @@ -__all__ = ["RepositoryException", "IntegrityError", "NotFoundError"] +__all__ = ["RepositoryException", "IntegrityError", "NotFoundError", "MultipleRecordsError"] class RepositoryException(Exception): @@ -11,3 +11,7 @@ class IntegrityError(RepositoryException): class NotFoundError(RepositoryException): ... + + +class MultipleRecordsError(RepositoryException): + ... diff --git a/flama/ddd/repositories.py b/flama/ddd/repositories.py index f111a976..9338cf70 100644 --- a/flama/ddd/repositories.py +++ b/flama/ddd/repositories.py @@ -46,95 +46,95 @@ def __eq__(self, other): and self.table == other.table ) - @property - def primary_key(self) -> sqlalchemy.Column: - """Returns the primary key of the model. + async def create(self, *data: t.Union[t.Dict[str, t.Any], types.Schema]) -> t.List[t.Tuple[t.Any, ...]]: + """Creates new elements in the table. - :return: sqlalchemy.Column: The primary key of the model. - :raises: exceptions.IntegrityError: If the model has a composed primary key. - """ - - model_pk_columns = list(sqlalchemy.inspect(self.table).primary_key.columns.values()) - - if len(model_pk_columns) != 1: - raise exceptions.IntegrityError("Composed primary keys are not supported") - - return model_pk_columns[0] - - async def create(self, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> t.Optional[t.Tuple[t.Any, ...]]: - """Creates a new element in the repository. - - If the element already exists, it raises an `exceptions.IntegrityError`. If the element is created, it returns + If the element already exists, it raises an `IntegrityError`. If the element is created, it returns the primary key of the element. :param data: The data to create the element. :return: The primary key of the created element. - :raises: exceptions.IntegrityError: If the element already exists. + :raises IntegrityError: If the element already exists or cannot be inserted. """ try: - result = await self._connection.execute(sqlalchemy.insert(self.table).values(**data)) + result = await self._connection.execute(sqlalchemy.insert(self.table), data) except sqlalchemy.exc.IntegrityError as e: raise exceptions.IntegrityError(str(e)) - return tuple(result.inserted_primary_key) if result.inserted_primary_key else None + return [tuple(x) for x in result.inserted_primary_key_rows] - async def retrieve(self, id: t.Any) -> types.Schema: - """Retrieves an element from the repository. + async def retrieve(self, *clauses, **filters) -> types.Schema: + """Retrieves an element from the table. + + If the element does not exist, it raises a `NotFoundError`. If more than one element is found, it raises a + `MultipleRecordsError`. If the element is found, it returns the element. - If the element does not exist, it raises a `NotFoundError`. + Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements + using exact values to specific columns. Clauses and filters can be combined. + + Clause example: `table.c["id"]._in((1, 2, 3))` + Filter example: `id=1` :param id: The primary key of the element. :return: The element. - :raises: exceptions.NotFoundError: If the element does not exist. + :raises NotFoundError: If the element does not exist. + :raises MultipleRecordsError: If more than one element is found. """ - element = ( - await self._connection.execute( - sqlalchemy.select(self.table).where(self.table.c[self.primary_key.name] == id) - ) - ).first() + query = self._filter_query(sqlalchemy.select(self.table), *clauses, **filters) - if element is None: - raise exceptions.NotFoundError(str(id)) + try: + element = (await self._connection.execute(query)).one() + except sqlalchemy.exc.NoResultFound: + raise exceptions.NotFoundError() + except sqlalchemy.exc.MultipleResultsFound: + raise exceptions.MultipleRecordsError() return types.Schema(element._asdict()) - async def update(self, id: t.Any, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> types.Schema: - """Updates an element in the repository. + async def update(self, data: t.Union[t.Dict[str, t.Any], types.Schema], *clauses, **filters) -> int: + """Updates elements in the table. + + Using clauses and filters, it filters the elements to update. If no clauses or filters are given, it updates + all the elements in the table. - If the element does not exist, it raises a `NotFoundError`. If the element is updated, it returns the updated - element. :param id: The primary key of the element. :param data: The data to update the element. - :return: The updated element. - :raises: exceptions.NotFoundError: If the element does not exist. + :return: The number of elements updated. + :raises IntegrityError: If the element cannot be updated. """ - pk = self.primary_key - result = await self._connection.execute( - sqlalchemy.update(self.table).where(self.table.c[pk.name] == id).values(**data) - ) + query = self._filter_query(sqlalchemy.update(self.table), *clauses, **filters).values(**data) - if result.rowcount == 0: - raise exceptions.NotFoundError(id) + try: + result = await self._connection.execute(query) + except sqlalchemy.exc.IntegrityError: + raise exceptions.IntegrityError - return types.Schema({pk.name: id, **data}) + return result.rowcount - async def delete(self, id: t.Any) -> None: - """Deletes an element from the repository. + async def delete(self, *clauses, **filters) -> None: + """Delete elements from the table. - If the element does not exist, it raises a `NotFoundError`. + If no clauses or filters are given, it deletes all the elements in the repository. - :param id: The primary key of the element. - :raises: exceptions.NotFoundError: If the element does not exist. + Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements using + exact values to specific columns. Clauses and filters can be combined. + + Clause example: `table.c["id"]._in((1, 2, 3))` + Filter example: `id=1` + + :param clauses: Clauses to filter the elements. + :param filters: Filters to filter the elements. + :raises NotFoundError: If the element does not exist. + :raises MultipleRecordsError: If more than one element is found. """ - result = await self._connection.execute( - sqlalchemy.delete(self.table).where(self.table.c[self.primary_key.name] == id) - ) + await self.retrieve(*clauses, **filters) - if result.rowcount == 0: - raise exceptions.NotFoundError(id) + query = self._filter_query(sqlalchemy.delete(self.table), *clauses, **filters) - async def list(self, *clauses, **filters) -> t.List[types.Schema]: - """Lists all the elements in the repository. + await self._connection.execute(query) + + async def list(self, *clauses, **filters) -> t.AsyncIterable[types.Schema]: + """Lists all the elements in the table. If no elements are found, it returns an empty list. If no clauses or filters are given, it returns all the elements in the repository. @@ -147,26 +147,61 @@ async def list(self, *clauses, **filters) -> t.List[types.Schema]: :param clauses: Clauses to filter the elements. :param filters: Filters to filter the elements. - :return: The elements. + :return: Async iterable of the elements. """ - query = sqlalchemy.select(self.table) + query = self._filter_query(sqlalchemy.select(self.table), *clauses, **filters) - where_clauses = tuple(clauses) + tuple(self.table.c[k] == v for k, v in filters.items()) - if where_clauses: - query = query.where(sqlalchemy.and_(*where_clauses)) + result = await self._connection.stream(query) + + async for row in result: + yield types.Schema(row._asdict()) + + async def drop(self, *clauses, **filters) -> int: + """Drops elements in the table. - return [types.Schema(row._asdict()) async for row in await self._connection.stream(query)] + Returns the number of elements dropped. If no clauses or filters are given, it deletes all the elements in the + table. - async def drop(self) -> int: - """Drops all the elements in the repository. + Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements using + exact values to specific columns. Clauses and filters can be combined. - Returns the number of elements dropped. + Clause example: `table.c["id"]._in((1, 2, 3))` + Filter example: `id=1` + :param clauses: Clauses to filter the elements. + :param filters: Filters to filter the elements. :return: The number of elements dropped. """ - result = await self._connection.execute(sqlalchemy.delete(self.table)) + query = self._filter_query(sqlalchemy.delete(self.table), *clauses, **filters) + + result = await self._connection.execute(query) + return result.rowcount + def _filter_query(self, query, *clauses, **filters): + """Filters a query using clauses and filters. + + Returns the filtered query. If no clauses or filters are given, it returns the query without any applying + filter. + + Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements using + exact values to specific columns. Clauses and filters can be combined. + + Clause example: `table.c["id"]._in((1, 2, 3))` + Filter example: `id=1` + + :param query: The query to filter. + :param clauses: Clauses to filter the elements. + :param filters: Filters to filter the elements. + :return: The filtered query. + """ + where_clauses = tuple(clauses) + tuple(self.table.c[k] == v for k, v in filters.items()) + + if where_clauses: + query = query.where(sqlalchemy.and_(*where_clauses)) + + return query + class SQLAlchemyTableRepository(SQLAlchemyRepository): _table: t.ClassVar[sqlalchemy.Table] @@ -178,53 +213,70 @@ def __init__(self, connection: "AsyncConnection"): def __eq__(self, other): return isinstance(other, SQLAlchemyTableRepository) and self._table == other._table and super().__eq__(other) - async def create(self, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> t.Optional[t.Tuple[t.Any, ...]]: - """Creates a new element in the repository. + async def create(self, *data: t.Union[t.Dict[str, t.Any], types.Schema]) -> t.List[t.Tuple[t.Any, ...]]: + """Creates new elements in the repository. If the element already exists, it raises an `exceptions.IntegrityError`. If the element is created, it returns the primary key of the element. :param data: The data to create the element. :return: The primary key of the created element. - :raises: exceptions.IntegrityError: If the element already exists. + :raises IntegrityError: If the element already exists or cannot be inserted. """ - return await self._table_manager.create(data) + return await self._table_manager.create(*data) - async def retrieve(self, id: t.Any) -> types.Schema: + async def retrieve(self, *clauses, **filters) -> types.Schema: """Retrieves an element from the repository. - If the element does not exist, it raises a `NotFoundError`. + If the element does not exist, it raises a `NotFoundError`. If more than one element is found, it raises a + `MultipleRecordsError`. If the element is found, it returns the element. - :param id: The primary key of the element. + Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements + using exact values to specific columns. Clauses and filters can be combined. + + Clause example: `table.c["id"]._in((1, 2, 3))` + Filter example: `id=1` + + :param clauses: Clauses to filter the elements. + :param filters: Filters to filter the elements. :return: The element. - :raises: exceptions.NotFoundError: If the element does not exist. + :raises NotFoundError: If the element does not exist. + :raises MultipleRecordsError: If more than one element is found. """ - return await self._table_manager.retrieve(id) + return await self._table_manager.retrieve(*clauses, **filters) - async def update(self, id: t.Any, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> types.Schema: + async def update(self, data: t.Union[t.Dict[str, t.Any], types.Schema], *clauses, **filters) -> int: """Updates an element in the repository. If the element does not exist, it raises a `NotFoundError`. If the element is updated, it returns the updated element. - :param id: The primary key of the element. :param data: The data to update the element. - :return: The updated element. - :raises: exceptions.NotFoundError: If the element does not exist. + :param clauses: Clauses to filter the elements. + :param filters: Filters to filter the elements. + :return: The number of elements updated. + :raises IntegrityError: If the element cannot be updated. """ - return await self._table_manager.update(id, data) + return await self._table_manager.update(data, *clauses, **filters) - async def delete(self, id: t.Any) -> None: + async def delete(self, *clauses, **filters) -> None: """Deletes an element from the repository. - If the element does not exist, it raises a `NotFoundError`. + Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements + using exact values to specific columns. Clauses and filters can be combined. + + Clause example: `table.c["id"]._in((1, 2, 3))` + Filter example: `id=1` :param id: The primary key of the element. - :raises: exceptions.NotFoundError: If the element does not exist. + :param clauses: Clauses to filter the elements. + :param filters: Filters to filter the elements. + :raises NotFoundError: If the element does not exist. + :raises MultipleRecordsError: If more than one element is found. """ - return await self._table_manager.delete(id) + return await self._table_manager.delete(*clauses, **filters) - async def list(self, *clauses, **filters) -> t.List[types.Schema]: + def list(self, *clauses, **filters) -> t.AsyncIterable[types.Schema]: """Lists all the elements in the repository. Lists all the elements in the repository that match the clauses and filters. If no clauses or filters are given, @@ -238,15 +290,24 @@ async def list(self, *clauses, **filters) -> t.List[types.Schema]: :param clauses: Clauses to filter the elements. :param filters: Filters to filter the elements. - :return: The elements. + :return: Async iterable of the elements. """ - return await self._table_manager.list(*clauses, **filters) + return self._table_manager.list(*clauses, **filters) + + async def drop(self, *clauses, **filters) -> int: + """Drops elements in the repository. - async def drop(self) -> int: - """Drops all the elements in the repository. + Returns the number of elements dropped. If no clauses or filters are given, it deletes all the elements in the + repository. + + Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements using + exact values to specific columns. Clauses and filters can be combined. - Returns the number of elements dropped. + Clause example: `table.c["id"]._in((1, 2, 3))` + Filter example: `id=1` + :param clauses: Clauses to filter the elements. + :param filters: Filters to filter the elements. :return: The number of elements dropped. """ - return await self._table_manager.drop() + return await self._table_manager.drop(*clauses, **filters) diff --git a/flama/resources/crud.py b/flama/resources/crud.py index aa70f6b2..12c84c8b 100644 --- a/flama/resources/crud.py +++ b/flama/resources/crud.py @@ -44,7 +44,7 @@ async def create( return http.APIResponse( # type: ignore[return-value] schema=rest_schemas.output.schema, - content={**element, **dict(zip([x.name for x in self.model.primary_key], result or []))}, + content={**element, **dict(zip([x.name for x in self.model.primary_key], result[0] if result else []))}, status_code=201, ) @@ -82,7 +82,9 @@ async def retrieve( ) -> types.Schema[rest_schemas.output.schema]: try: async with worker: - return await worker.repositories[self._meta.name].retrieve(element_id) + return await worker.repositories[self._meta.name].retrieve( + **{rest_model.primary_key.name: element_id} + ) except ddd_exceptions.NotFoundError: raise exceptions.HTTPException(status_code=404) @@ -126,12 +128,18 @@ async def update( clean_element = types.Schema[rest_schemas.input.schema]( {k: v for k, v in schema.dump(element).items() if k != rest_model.primary_key.name} ) - try: - async with worker: - return await worker.repositories[self._meta.name].update(element_id, clean_element) - except ddd_exceptions.NotFoundError: + async with worker: + result = await worker.repositories[self._meta.name].update( + clean_element, **{rest_model.primary_key.name: element_id} + ) + + if result == 0: raise exceptions.HTTPException(status_code=404) + return types.Schema[rest_schemas.output.schema]( + {**clean_element, **{rest_model.primary_key.name: element_id}} + ) + update.__doc__ = f""" tags: - {verbose_name} @@ -160,7 +168,7 @@ def _add_delete( async def delete(self, worker: FlamaWorker, element_id: rest_model.primary_key.type): try: async with worker: - await worker.repositories[self._meta.name].delete(element_id) + await worker.repositories[self._meta.name].delete(**{rest_model.primary_key.name: element_id}) except ddd_exceptions.NotFoundError: raise exceptions.HTTPException(status_code=404) @@ -193,7 +201,7 @@ def _add_list( @resource_method("/", methods=["GET"], name=f"{name}-list", pagination="page_number") async def list(self, worker: FlamaWorker, **kwargs) -> types.Schema[rest_schemas.output.schema]: async with worker: - return await worker.repositories[self._meta.name].list() # type: ignore[return-value] + return [x async for x in worker.repositories[self._meta.name].list()] # type: ignore[return-value] list.__doc__ = f""" tags: diff --git a/tests/ddd/test_repositories.py b/tests/ddd/test_repositories.py index d2001f58..fff47e85 100644 --- a/tests/ddd/test_repositories.py +++ b/tests/ddd/test_repositories.py @@ -89,75 +89,67 @@ def test_eq(self, table, connection): assert SQLAlchemyTableManager(table, connection) == SQLAlchemyTableManager(table, connection) @pytest.mark.parametrize( - ["table", "result", "exception"], + ["data", "result", "exception"], ( - pytest.param("single", "id", None, id="single_pk"), + pytest.param([{"name": "foo"}], [(1,)], None, id="single"), pytest.param( - "composed", None, exceptions.IntegrityError("Composed primary keys are not supported"), id="composed_pk" + [{"name": "foo"}, {"name": "bar"}], + [(None,), (None,)], # SQlite doesn't allow to retrieve pk from bulk inserts + None, + id="multiple", ), - ), - indirect=["exception"], - ) - async def test_primary_key(self, table, result, exception, tables): - table_manager = SQLAlchemyTableManager(tables[table], Mock()) - - with exception: - assert table_manager.primary_key.name == result - - @pytest.mark.parametrize( - ["data", "result", "exception"], - ( - pytest.param({"name": "foo"}, (1,), None, id="ok"), - pytest.param({"name": None}, None, exceptions.IntegrityError, id="integrity_error"), + pytest.param([{"name": None}], None, exceptions.IntegrityError, id="integrity_error"), ), indirect=["exception"], ) async def test_create(self, table_manager, data, result, exception): with exception: - assert await table_manager.create(data) == result + assert await table_manager.create(*data) == result @pytest.mark.parametrize( - ["data", "result", "exception"], + ["clauses", "filters", "result", "exception"], ( - pytest.param(1, {"id": 1, "name": "foo"}, None, id="ok"), - pytest.param(2, None, exceptions.NotFoundError(1), id="not_found"), + pytest.param([], {"id": 1}, {"id": 1, "name": "foo"}, None, id="ok"), + pytest.param([], {"id": 0}, None, exceptions.NotFoundError(), id="not_found"), + pytest.param( + [lambda x: x.ilike("fo%")], {}, None, exceptions.MultipleRecordsError(), id="multiple_results" + ), ), indirect=["exception"], ) - async def test_retrieve(self, data, result, exception, table_manager): - await table_manager.create({"name": "foo"}) + async def test_retrieve(self, clauses, filters, result, exception, table, table_manager): + await table_manager.create({"name": "foo"}, {"name": "foo"}) with exception: - assert await table_manager.retrieve(data) == result + assert await table_manager.retrieve(*[c(table.c["name"]) for c in clauses], **filters) == result @pytest.mark.parametrize( - ["data", "result", "exception"], + ["clauses", "filters", "data", "result"], ( - pytest.param((1, {"name": "bar"}), {"id": 1, "name": "foo"}, None, id="ok"), - pytest.param((2, {"name": "bar"}), None, exceptions.NotFoundError(1), id="not_found"), + pytest.param([], {"id": 1}, {"name": "bar"}, 1, id="ok"), + pytest.param([], {"id": 0}, {"name": "bar"}, 0, id="not_found"), + pytest.param([lambda x: x.ilike("fo%")], {}, {"name": "bar"}, 2, id="multiple_results"), ), - indirect=["exception"], ) - async def test_update(self, data, result, exception, table_manager): - id_, data_ = data - await table_manager.create(data_) + async def test_update(self, clauses, filters, data, result, table, table_manager): + await table_manager.create({"name": "foo"}, {"name": "foo"}) - with exception: - assert await table_manager.update(id_, {"name": "foo"}) == result + assert await table_manager.update(data, *[c(table.c["name"]) for c in clauses], **filters) == result @pytest.mark.parametrize( - ["data", "exception"], + ["clauses", "filters", "exception"], ( - pytest.param(1, None, id="ok"), - pytest.param(2, exceptions.NotFoundError(1), id="not_found"), + pytest.param([], {"id": 1}, None, id="ok"), + pytest.param([], {"id": 0}, exceptions.NotFoundError(), id="not_found"), + pytest.param([lambda x: x.ilike("fo%")], {}, exceptions.MultipleRecordsError(), id="multiple_results"), ), indirect=["exception"], ) - async def test_delete(self, data, exception, table_manager): - await table_manager.create({"name": "foo"}) + async def test_delete(self, clauses, filters, exception, table, table_manager): + await table_manager.create({"name": "foo"}, {"name": "foo"}) with exception: - await table_manager.delete(data) + await table_manager.delete(*[c(table.c["name"]) for c in clauses], **filters) @pytest.mark.parametrize( ["clauses", "filters", "result"], @@ -168,19 +160,26 @@ async def test_delete(self, data, exception, table_manager): ), ) async def test_list(self, clauses, filters, result, table, table_manager): - await table_manager.create({"name": "foo"}) - await table_manager.create({"name": "bar"}) + await table_manager.create({"name": "foo"}, {"name": "bar"}) - r = await table_manager.list(*[c(table.c["name"]) for c in clauses], **filters) + r = [x async for x in table_manager.list(*[c(table.c["name"]) for c in clauses], **filters)] assert r == result - async def test_drop(self, table_manager): - await table_manager.create({"name": "foo"}) + @pytest.mark.parametrize( + ["clauses", "filters", "result"], + ( + pytest.param([], {}, 2, id="all"), + pytest.param([lambda x: x.ilike("fo%")], {}, 1, id="clauses"), + pytest.param([], {"name": "foo"}, 1, id="filters"), + ), + ) + async def test_drop(self, clauses, filters, result, table, table_manager): + await table_manager.create({"name": "foo"}, {"name": "bar"}) - result = await table_manager.drop() + r = await table_manager.drop(*[c(table.c["name"]) for c in clauses], **filters) - assert result == 1 + assert r == result class TestCaseSQLAlchemyTableRepository: @@ -253,7 +252,7 @@ async def test_list(self, repository, table_manager): clauses = [Mock(), Mock()] filters = {"foo": "bar"} - await repository.list(*clauses, **filters) + repository.list(*clauses, **filters) assert table_manager.list.call_args_list == [call(*clauses, **filters)] diff --git a/tests/resources/test_crud.py b/tests/resources/test_crud.py index 777a500d..09398cab 100644 --- a/tests/resources/test_crud.py +++ b/tests/resources/test_crud.py @@ -362,7 +362,7 @@ async def list( filters["name"] = name async with worker: - return await worker.repositories[self._meta.name].list(*clauses, **filters) + return [x async for x in worker.repositories[self._meta.name].list(*clauses, **filters)] return PuppyResource()