From 96229a14e764cb803239931eef1d777277f606c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Antonio=20Perdiguero=20L=C3=B3pez?= Date: Thu, 5 Oct 2023 15:18:56 +0200 Subject: [PATCH] :sparkles: Decouple table manager logic from repositories (#118) --- flama/ddd/repositories.py | 166 +++++++++++++++++++--- flama/resources/rest.py | 4 +- flama/resources/workers.py | 10 +- tests/ddd/test_repositories.py | 243 ++++++++++++++++++++++---------- tests/ddd/test_workers.py | 36 ++++- tests/resources/test_workers.py | 4 +- 6 files changed, 355 insertions(+), 108 deletions(-) diff --git a/flama/ddd/repositories.py b/flama/ddd/repositories.py index 783f7685..f111a976 100644 --- a/flama/ddd/repositories.py +++ b/flama/ddd/repositories.py @@ -15,24 +15,35 @@ from sqlalchemy.ext.asyncio import AsyncConnection -__all__ = ["AbstractRepository", "SQLAlchemyRepository"] +__all__ = ["AbstractRepository", "SQLAlchemyRepository", "SQLAlchemyTableRepository", "SQLAlchemyTableManager"] class AbstractRepository(abc.ABC): + """Base class for repositories.""" + ... class SQLAlchemyRepository(AbstractRepository): - _table: t.ClassVar[sqlalchemy.Table] + """Base class for SQLAlchemy repositories. It provides a connection to the database.""" def __init__(self, connection: "AsyncConnection"): self._connection = connection + def __eq__(self, other): + return isinstance(other, SQLAlchemyRepository) and self._connection == other._connection + + +class SQLAlchemyTableManager: + def __init__(self, table: sqlalchemy.Table, connection: "AsyncConnection"): + self._connection = connection + self.table = table + def __eq__(self, other): return ( - isinstance(other, SQLAlchemyRepository) - and self._table == other._table + isinstance(other, SQLAlchemyTableManager) and self._connection == other._connection + and self.table == other.table ) @property @@ -42,7 +53,8 @@ def primary_key(self) -> sqlalchemy.Column: :return: sqlalchemy.Column: The primary key of the model. :raises: exceptions.IntegrityError: If the model has a composed primary key. """ - model_pk_columns = list(sqlalchemy.inspect(self._table).primary_key.columns.values()) + + model_pk_columns = list(sqlalchemy.inspect(self.table).primary_key.columns.values()) if len(model_pk_columns) != 1: raise exceptions.IntegrityError("Composed primary keys are not supported") @@ -52,79 +64,189 @@ def primary_key(self) -> sqlalchemy.Column: async def create(self, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> t.Optional[t.Tuple[t.Any, ...]]: """Creates a new element in the repository. + If the element already exists, it raises an `exceptions.IntegrityError`. If the element is created, it returns + the primary key of the element. + :param data: The data to create the element. :return: The primary key of the created element. :raises: exceptions.IntegrityError: If the element already exists. """ try: - result = await self._connection.execute(sqlalchemy.insert(self._table).values(**data)) + result = await self._connection.execute(sqlalchemy.insert(self.table).values(**data)) except sqlalchemy.exc.IntegrityError as e: raise exceptions.IntegrityError(str(e)) return tuple(result.inserted_primary_key) if result.inserted_primary_key else None - async def retrieve(self, id_: t.Any) -> types.Schema: + async def retrieve(self, id: t.Any) -> types.Schema: """Retrieves an element from the repository. - :param id_: The primary key of the element. + If the element does not exist, it raises a `NotFoundError`. + + :param id: The primary key of the element. :return: The element. :raises: exceptions.NotFoundError: If the element does not exist. """ element = ( await self._connection.execute( - sqlalchemy.select(self._table).where(self._table.c[self.primary_key.name] == id_) + sqlalchemy.select(self.table).where(self.table.c[self.primary_key.name] == id) ) ).first() if element is None: - raise exceptions.NotFoundError(str(id_)) + raise exceptions.NotFoundError(str(id)) return types.Schema(element._asdict()) - async def update(self, id_: t.Any, data: types.Schema) -> types.Schema: + async def update(self, id: t.Any, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> types.Schema: """Updates an element in the repository. - :param id_: The primary key of the element. + If the element does not exist, it raises a `NotFoundError`. If the element is updated, it returns the updated + element. + + :param id: The primary key of the element. :param data: The data to update the element. :return: The updated element. :raises: exceptions.NotFoundError: If the element does not exist. """ pk = self.primary_key result = await self._connection.execute( - sqlalchemy.update(self._table).where(self._table.c[pk.name] == id_).values(**data) + sqlalchemy.update(self.table).where(self.table.c[pk.name] == id).values(**data) ) + if result.rowcount == 0: - raise exceptions.NotFoundError(id_) - return types.Schema({pk.name: id_, **data}) + raise exceptions.NotFoundError(id) + + return types.Schema({pk.name: id, **data}) - async def delete(self, id_: t.Any) -> None: + async def delete(self, id: t.Any) -> None: """Deletes an element from the repository. - :param id_: The primary key of the element. + If the element does not exist, it raises a `NotFoundError`. + + :param id: The primary key of the element. :raises: exceptions.NotFoundError: If the element does not exist. """ result = await self._connection.execute( - sqlalchemy.delete(self._table).where(self._table.c[self.primary_key.name] == id_) + sqlalchemy.delete(self.table).where(self.table.c[self.primary_key.name] == id) ) + if result.rowcount == 0: - raise exceptions.NotFoundError(id_) + raise exceptions.NotFoundError(id) async def list(self, *clauses, **filters) -> t.List[types.Schema]: """Lists all the elements in the repository. + If no elements are found, it returns an empty list. If no clauses or filters are given, it returns all the + elements in the repository. + + Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements using + exact values to specific columns. Clauses and filters can be combined. + + Clause example: `table.c["id"]._in((1, 2, 3))` + Filter example: `id=1` + :param clauses: Clauses to filter the elements. :param filters: Filters to filter the elements. :return: The elements. """ - query = sqlalchemy.select(self._table) - where_clauses = tuple(clauses) + tuple(self._table.c[k] == v for k, v in filters.items()) + query = sqlalchemy.select(self.table) + + where_clauses = tuple(clauses) + tuple(self.table.c[k] == v for k, v in filters.items()) if where_clauses: query = query.where(sqlalchemy.and_(*where_clauses)) + return [types.Schema(row._asdict()) async for row in await self._connection.stream(query)] async def drop(self) -> int: """Drops all the elements in the repository. + Returns the number of elements dropped. + :return: The number of elements dropped. """ - result = await self._connection.execute(sqlalchemy.delete(self._table)) + result = await self._connection.execute(sqlalchemy.delete(self.table)) return result.rowcount + + +class SQLAlchemyTableRepository(SQLAlchemyRepository): + _table: t.ClassVar[sqlalchemy.Table] + + def __init__(self, connection: "AsyncConnection"): + super().__init__(connection) + self._table_manager = SQLAlchemyTableManager(self._table, connection) + + def __eq__(self, other): + return isinstance(other, SQLAlchemyTableRepository) and self._table == other._table and super().__eq__(other) + + async def create(self, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> t.Optional[t.Tuple[t.Any, ...]]: + """Creates a new element in the repository. + + If the element already exists, it raises an `exceptions.IntegrityError`. If the element is created, it returns + the primary key of the element. + + :param data: The data to create the element. + :return: The primary key of the created element. + :raises: exceptions.IntegrityError: If the element already exists. + """ + return await self._table_manager.create(data) + + async def retrieve(self, id: t.Any) -> types.Schema: + """Retrieves an element from the repository. + + If the element does not exist, it raises a `NotFoundError`. + + :param id: The primary key of the element. + :return: The element. + :raises: exceptions.NotFoundError: If the element does not exist. + """ + return await self._table_manager.retrieve(id) + + async def update(self, id: t.Any, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> types.Schema: + """Updates an element in the repository. + + If the element does not exist, it raises a `NotFoundError`. If the element is updated, it returns the updated + element. + + :param id: The primary key of the element. + :param data: The data to update the element. + :return: The updated element. + :raises: exceptions.NotFoundError: If the element does not exist. + """ + return await self._table_manager.update(id, data) + + async def delete(self, id: t.Any) -> None: + """Deletes an element from the repository. + + If the element does not exist, it raises a `NotFoundError`. + + :param id: The primary key of the element. + :raises: exceptions.NotFoundError: If the element does not exist. + """ + return await self._table_manager.delete(id) + + async def list(self, *clauses, **filters) -> t.List[types.Schema]: + """Lists all the elements in the repository. + + Lists all the elements in the repository that match the clauses and filters. If no clauses or filters are given, + it returns all the elements in the repository. If no elements are found, it returns an empty list. + + Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements using + exact values to specific columns. Clauses and filters can be combined. + + Clause example: `table.c["id"]._in((1, 2, 3))` + Filter example: `id=1` + + :param clauses: Clauses to filter the elements. + :param filters: Filters to filter the elements. + :return: The elements. + """ + return await self._table_manager.list(*clauses, **filters) + + async def drop(self) -> int: + """Drops all the elements in the repository. + + Returns the number of elements dropped. + + :return: The number of elements dropped. + """ + return await self._table_manager.drop() diff --git a/flama/resources/rest.py b/flama/resources/rest.py index e9c05e58..f2d3d40c 100644 --- a/flama/resources/rest.py +++ b/flama/resources/rest.py @@ -2,7 +2,7 @@ import typing as t import uuid -from flama.ddd.repositories import SQLAlchemyRepository +from flama.ddd.repositories import SQLAlchemyTableRepository from flama.resources import data_structures from flama.resources.exceptions import ResourceAttributeError from flama.resources.resource import BaseResource, ResourceType @@ -55,7 +55,7 @@ def __new__(mcs, name: str, bases: t.Tuple[type], namespace: t.Dict[str, t.Any]) namespace.setdefault("_meta", data_structures.Metadata()).namespaces.update( { "rest": {"model": model, "schemas": resource_schemas}, - "ddd": {"repository": type(f"{name}Repository", (SQLAlchemyRepository,), {"_table": model.table})}, + "ddd": {"repository": type(f"{name}Repository", (SQLAlchemyTableRepository,), {"_table": model.table})}, } ) diff --git a/flama/resources/workers.py b/flama/resources/workers.py index 040fb2fd..45b5d21e 100644 --- a/flama/resources/workers.py +++ b/flama/resources/workers.py @@ -4,16 +4,18 @@ if t.TYPE_CHECKING: from flama import Flama - from flama.ddd.repositories import SQLAlchemyRepository + from flama.ddd.repositories import SQLAlchemyTableRepository class FlamaWorker(SQLAlchemyWorker): + _repositories: t.ClassVar[t.Dict[str, t.Type["SQLAlchemyTableRepository"]]] + def __init__(self, app: t.Optional["Flama"] = None): super().__init__(app) - self._init_repositories: t.Optional[t.Dict[str, "SQLAlchemyRepository"]] = None + self._init_repositories: t.Optional[t.Dict[str, "SQLAlchemyTableRepository"]] = None @property - def repositories(self) -> t.Dict[str, "SQLAlchemyRepository"]: + def repositories(self) -> t.Dict[str, "SQLAlchemyTableRepository"]: assert self._init_repositories, "Repositories not initialized" return self._init_repositories @@ -26,5 +28,5 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() del self._init_repositories - def add_repository(self, name: str, cls: t.Type["SQLAlchemyRepository"]) -> None: + def add_repository(self, name: str, cls: t.Type["SQLAlchemyTableRepository"]) -> None: self._repositories[name] = cls diff --git a/tests/ddd/test_repositories.py b/tests/ddd/test_repositories.py index 5c5cafa4..d2001f58 100644 --- a/tests/ddd/test_repositories.py +++ b/tests/ddd/test_repositories.py @@ -1,87 +1,108 @@ +import uuid +from unittest.mock import Mock, call, patch + import pytest import sqlalchemy from sqlalchemy.ext.asyncio import AsyncConnection from flama import Flama -from flama.ddd import SQLAlchemyRepository, exceptions +from flama.ddd import SQLAlchemyRepository, SQLAlchemyTableManager, SQLAlchemyTableRepository, exceptions from flama.sqlalchemy import SQLAlchemyModule from tests.utils import SQLAlchemyContext -class TestCaseSQLAlchemyRepository: - @pytest.fixture(scope="function") - def app(self): - return Flama(schema=None, docs=None, modules={SQLAlchemyModule("sqlite+aiosqlite://")}) +@pytest.fixture(scope="function") +def app(): + return Flama(schema=None, docs=None, modules={SQLAlchemyModule("sqlite+aiosqlite://")}) - @pytest.fixture(scope="function") - async def connection(self, client): - # Exactly the same behavior than 'async with worker' - connection_: AsyncConnection = client.app.sqlalchemy.engine.connect() - await connection_.__aenter__() - transaction = connection_.begin() - await transaction.__aenter__() - yield connection_ +@pytest.fixture(scope="function") +async def connection(client): + # Exactly the same behavior than 'async with worker' + connection_: AsyncConnection = client.app.sqlalchemy.engine.connect() + await connection_.__aenter__() + transaction = connection_.begin() + await transaction.__aenter__() - await transaction.__aexit__(None, None, None) - await connection_.__aexit__(None, None, None) + yield connection_ - @pytest.fixture(scope="function") - def tables(self, app): - return { - "single": sqlalchemy.Table( - "repository_table_single_pk", - app.sqlalchemy.metadata, - sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), - sqlalchemy.Column("name", sqlalchemy.String, nullable=False), - ), - "composed": sqlalchemy.Table( - "repository_table_composed_pk", - app.sqlalchemy.metadata, - sqlalchemy.Column("id_first", sqlalchemy.Integer, primary_key=True), - sqlalchemy.Column("id_second", sqlalchemy.Integer, primary_key=True), - sqlalchemy.Column("name", sqlalchemy.String, nullable=False), - ), - } + await transaction.__aexit__(None, None, None) + await connection_.__aexit__(None, None, None) - @pytest.fixture(scope="function") - async def table(self, client, tables): - table = tables["single"] - async with SQLAlchemyContext(client.app, [table]): - yield table +@pytest.fixture(scope="function") +def tables(app): + return { + "single": sqlalchemy.Table( + "repository_table_single_pk", + app.sqlalchemy.metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), + sqlalchemy.Column("name", sqlalchemy.String, nullable=False), + ), + "composed": sqlalchemy.Table( + "repository_table_composed_pk", + app.sqlalchemy.metadata, + sqlalchemy.Column("id_first", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("id_second", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("name", sqlalchemy.String, nullable=False), + ), + } + + +@pytest.fixture(scope="function") +async def table(client, tables): + table = tables["single"] + async with SQLAlchemyContext(client.app, [table]): + yield table + + +class TestCaseSQLAlchemyRepository: @pytest.fixture(scope="function") - def repository(self, table, connection): + def connection(self): + return Mock(spec=AsyncConnection) + + async def test_init(self, connection): class Repository(SQLAlchemyRepository): - _table = table + ... + + repository = Repository(connection) + + assert repository._connection == connection - return Repository(connection) + def test_eq(self, connection): + assert SQLAlchemyRepository(connection) == SQLAlchemyRepository(connection) + +class TestCaseSQLAlchemyTableManager: @pytest.fixture(scope="function") - async def repository_select(self, request, client, tables, connection): - table = tables[request.param] + def table_manager(self, table, connection): + return SQLAlchemyTableManager(table, connection) - async with SQLAlchemyContext(client.app, [table]): + async def test_init(self, table, connection): + table_manager = SQLAlchemyTableManager(table, connection) - class Repository(SQLAlchemyRepository): - _table = tables[request.param] + assert table_manager._connection == connection + assert table_manager.table == table - yield Repository(connection) + def test_eq(self, table, connection): + assert SQLAlchemyTableManager(table, connection) == SQLAlchemyTableManager(table, connection) @pytest.mark.parametrize( - ["repository_select", "result", "exception"], + ["table", "result", "exception"], ( pytest.param("single", "id", None, id="single_pk"), pytest.param( "composed", None, exceptions.IntegrityError("Composed primary keys are not supported"), id="composed_pk" ), ), - indirect=["repository_select", "exception"], + indirect=["exception"], ) - async def test_primary_key(self, repository_select, result, exception): + async def test_primary_key(self, table, result, exception, tables): + table_manager = SQLAlchemyTableManager(tables[table], Mock()) + with exception: - assert repository_select.primary_key.name == result + assert table_manager.primary_key.name == result @pytest.mark.parametrize( ["data", "result", "exception"], @@ -91,9 +112,9 @@ async def test_primary_key(self, repository_select, result, exception): ), indirect=["exception"], ) - async def test_create(self, repository, data, result, exception): + async def test_create(self, table_manager, data, result, exception): with exception: - assert await repository.create(data) == result + assert await table_manager.create(data) == result @pytest.mark.parametrize( ["data", "result", "exception"], @@ -103,11 +124,11 @@ async def test_create(self, repository, data, result, exception): ), indirect=["exception"], ) - async def test_retrieve(self, data, result, exception, repository): - await repository.create({"name": "foo"}) + async def test_retrieve(self, data, result, exception, table_manager): + await table_manager.create({"name": "foo"}) with exception: - assert await repository.retrieve(data) == result + assert await table_manager.retrieve(data) == result @pytest.mark.parametrize( ["data", "result", "exception"], @@ -117,26 +138,26 @@ async def test_retrieve(self, data, result, exception, repository): ), indirect=["exception"], ) - async def test_update(self, data, result, exception, repository): + async def test_update(self, data, result, exception, table_manager): id_, data_ = data - await repository.create(data_) + await table_manager.create(data_) with exception: - assert await repository.update(id_, {"name": "foo"}) == result + assert await table_manager.update(id_, {"name": "foo"}) == result @pytest.mark.parametrize( - ["data", "result", "exception"], + ["data", "exception"], ( - pytest.param(1, {"id": 1, "name": "foo"}, None, id="ok"), - pytest.param(2, None, exceptions.NotFoundError(1), id="not_found"), + pytest.param(1, None, id="ok"), + pytest.param(2, exceptions.NotFoundError(1), id="not_found"), ), indirect=["exception"], ) - async def test_delete(self, data, result, exception, repository): - await repository.create({"name": "foo"}) + async def test_delete(self, data, exception, table_manager): + await table_manager.create({"name": "foo"}) with exception: - await repository.delete(data) + await table_manager.delete(data) @pytest.mark.parametrize( ["clauses", "filters", "result"], @@ -146,17 +167,97 @@ async def test_delete(self, data, result, exception, repository): pytest.param([], {"name": "foo"}, [{"id": 1, "name": "foo"}], id="filters"), ), ) - async def test_list(self, clauses, filters, result, repository): - await repository.create({"name": "foo"}) - await repository.create({"name": "bar"}) + async def test_list(self, clauses, filters, result, table, table_manager): + await table_manager.create({"name": "foo"}) + await table_manager.create({"name": "bar"}) - r = await repository.list(*[c(repository._table.c["name"]) for c in clauses], **filters) + r = await table_manager.list(*[c(table.c["name"]) for c in clauses], **filters) assert r == result - async def test_drop(self, repository): - await repository.create({"name": "foo"}) + async def test_drop(self, table_manager): + await table_manager.create({"name": "foo"}) - result = await repository.drop() + result = await table_manager.drop() assert result == 1 + + +class TestCaseSQLAlchemyTableRepository: + @pytest.fixture(scope="function") + def table(self): + return Mock(spec=sqlalchemy.Table) + + @pytest.fixture(scope="function") + def connection(self): + return Mock(spec=AsyncConnection) + + @pytest.fixture(scope="function") + def table_manager(self): + return Mock(spec=SQLAlchemyTableManager) + + @pytest.fixture(scope="function") + def repository(self, table, connection, table_manager): + class Repository(SQLAlchemyTableRepository): + _table = table + + r = Repository(connection) + with patch.object(r, "_table_manager", table_manager): + yield r + + async def test_init(self, table, connection): + class Repository(SQLAlchemyTableRepository): + _table = table + + repository = Repository(connection) + + assert repository._connection == connection + assert repository._table_manager == SQLAlchemyTableManager(table, connection) + + def test_eq(self, table, connection): + class Repository(SQLAlchemyTableRepository): + _table = table + + assert Repository(connection) == Repository(connection) + + async def test_create(self, repository, table_manager): + data = {"foo": "bar"} + + await repository.create(data) + + assert table_manager.create.call_args_list == [call(data)] + + async def test_retrieve(self, repository, table_manager): + id = uuid.uuid4() + + await repository.retrieve(id) + + assert table_manager.retrieve.call_args_list == [call(id)] + + async def test_update(self, repository, table_manager): + id = uuid.uuid4() + data = {"foo": "bar"} + + await repository.update(id, data) + + assert table_manager.update.call_args_list == [call(id, data)] + + async def test_delete(self, repository, table_manager): + id = uuid.uuid4() + + await repository.delete(id) + + assert table_manager.delete.call_args_list == [call(id)] + + async def test_list(self, repository, table_manager): + clauses = [Mock(), Mock()] + filters = {"foo": "bar"} + + await repository.list(*clauses, **filters) + + assert table_manager.list.call_args_list == [call(*clauses, **filters)] + + async def test_drop(self, repository, table_manager): + await repository.drop() + + assert table_manager.drop.call_args_list == [call()] diff --git a/tests/ddd/test_workers.py b/tests/ddd/test_workers.py index 2ab81b3e..7f4ae60e 100644 --- a/tests/ddd/test_workers.py +++ b/tests/ddd/test_workers.py @@ -60,19 +60,41 @@ async def test_begin(self, worker): assert connection_mock.begin.call_args_list == [call()] assert connection_mock.begin.await_args_list == [call()] - async def test_close(self, worker): + @pytest.mark.parametrize( + ["transaction", "transaction_calls", "connection", "connection_calls"], + ( + pytest.param(None, [], None, [], id="no_transaction_no_connection"), + pytest.param( + AsyncMock(AsyncTransaction), [call(None, None, None)], None, [], id="transaction_no_connection" + ), + pytest.param( + None, [], AsyncMock(AsyncConnection), [call(None, None, None)], id="no_transaction_connection" + ), + pytest.param( + AsyncMock(AsyncTransaction), + [call(None, None, None)], + AsyncMock(AsyncConnection), + [call(None, None, None)], + id="transaction_connection", + ), + ), + ) + async def test_close(self, transaction, transaction_calls, connection, connection_calls, worker): assert not hasattr(worker, "_transaction") assert not hasattr(worker, "_connection") - connection_mock = AsyncMock(spec=AsyncConnection) - transaction_mock = AsyncMock(spec=AsyncTransaction) - worker._connection = connection_mock - worker._transaction = transaction_mock + if connection: + worker._connection = connection + if transaction: + worker._transaction = transaction await worker.close() - assert transaction_mock.__aexit__.await_args_list == [call(None, None, None)] - assert connection_mock.__aexit__.await_args_list == [call(None, None, None)] + if transaction: + assert transaction.__aexit__.await_args_list == transaction_calls + if connection: + assert connection.__aexit__.await_args_list == connection_calls + assert not hasattr(worker, "_transaction") assert not hasattr(worker, "_connection") diff --git a/tests/resources/test_workers.py b/tests/resources/test_workers.py index 13d45e82..2a4703db 100644 --- a/tests/resources/test_workers.py +++ b/tests/resources/test_workers.py @@ -4,7 +4,7 @@ import sqlalchemy from flama import Flama -from flama.ddd import SQLAlchemyRepository +from flama.ddd import SQLAlchemyTableRepository from flama.resources.workers import FlamaWorker from flama.sqlalchemy import SQLAlchemyModule @@ -23,7 +23,7 @@ class FooWorker(FlamaWorker): @pytest.fixture(scope="function") def repository(self): - class FooRepository(SQLAlchemyRepository): + class FooRepository(SQLAlchemyTableRepository): _table = Mock(spec=sqlalchemy.Table) return FooRepository