Skip to content

Commit

Permalink
✨ Decouple table manager logic from repositories (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy authored and migduroli committed Sep 3, 2024
1 parent 96d1150 commit 96229a1
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 108 deletions.
166 changes: 144 additions & 22 deletions flama/ddd/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,35 @@
from sqlalchemy.ext.asyncio import AsyncConnection


__all__ = ["AbstractRepository", "SQLAlchemyRepository"]
__all__ = ["AbstractRepository", "SQLAlchemyRepository", "SQLAlchemyTableRepository", "SQLAlchemyTableManager"]


class AbstractRepository(abc.ABC):
"""Base class for repositories."""

...


class SQLAlchemyRepository(AbstractRepository):
_table: t.ClassVar[sqlalchemy.Table]
"""Base class for SQLAlchemy repositories. It provides a connection to the database."""

def __init__(self, connection: "AsyncConnection"):
self._connection = connection

def __eq__(self, other):
return isinstance(other, SQLAlchemyRepository) and self._connection == other._connection


class SQLAlchemyTableManager:
def __init__(self, table: sqlalchemy.Table, connection: "AsyncConnection"):
self._connection = connection
self.table = table

def __eq__(self, other):
return (
isinstance(other, SQLAlchemyRepository)
and self._table == other._table
isinstance(other, SQLAlchemyTableManager)
and self._connection == other._connection
and self.table == other.table
)

@property
Expand All @@ -42,7 +53,8 @@ def primary_key(self) -> sqlalchemy.Column:
:return: sqlalchemy.Column: The primary key of the model.
:raises: exceptions.IntegrityError: If the model has a composed primary key.
"""
model_pk_columns = list(sqlalchemy.inspect(self._table).primary_key.columns.values())

model_pk_columns = list(sqlalchemy.inspect(self.table).primary_key.columns.values())

if len(model_pk_columns) != 1:
raise exceptions.IntegrityError("Composed primary keys are not supported")
Expand All @@ -52,79 +64,189 @@ def primary_key(self) -> sqlalchemy.Column:
async def create(self, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> t.Optional[t.Tuple[t.Any, ...]]:
"""Creates a new element in the repository.
If the element already exists, it raises an `exceptions.IntegrityError`. If the element is created, it returns
the primary key of the element.
:param data: The data to create the element.
:return: The primary key of the created element.
:raises: exceptions.IntegrityError: If the element already exists.
"""
try:
result = await self._connection.execute(sqlalchemy.insert(self._table).values(**data))
result = await self._connection.execute(sqlalchemy.insert(self.table).values(**data))
except sqlalchemy.exc.IntegrityError as e:
raise exceptions.IntegrityError(str(e))
return tuple(result.inserted_primary_key) if result.inserted_primary_key else None

async def retrieve(self, id_: t.Any) -> types.Schema:
async def retrieve(self, id: t.Any) -> types.Schema:
"""Retrieves an element from the repository.
:param id_: The primary key of the element.
If the element does not exist, it raises a `NotFoundError`.
:param id: The primary key of the element.
:return: The element.
:raises: exceptions.NotFoundError: If the element does not exist.
"""
element = (
await self._connection.execute(
sqlalchemy.select(self._table).where(self._table.c[self.primary_key.name] == id_)
sqlalchemy.select(self.table).where(self.table.c[self.primary_key.name] == id)
)
).first()

if element is None:
raise exceptions.NotFoundError(str(id_))
raise exceptions.NotFoundError(str(id))

return types.Schema(element._asdict())

async def update(self, id_: t.Any, data: types.Schema) -> types.Schema:
async def update(self, id: t.Any, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> types.Schema:
"""Updates an element in the repository.
:param id_: The primary key of the element.
If the element does not exist, it raises a `NotFoundError`. If the element is updated, it returns the updated
element.
:param id: The primary key of the element.
:param data: The data to update the element.
:return: The updated element.
:raises: exceptions.NotFoundError: If the element does not exist.
"""
pk = self.primary_key
result = await self._connection.execute(
sqlalchemy.update(self._table).where(self._table.c[pk.name] == id_).values(**data)
sqlalchemy.update(self.table).where(self.table.c[pk.name] == id).values(**data)
)

if result.rowcount == 0:
raise exceptions.NotFoundError(id_)
return types.Schema({pk.name: id_, **data})
raise exceptions.NotFoundError(id)

return types.Schema({pk.name: id, **data})

async def delete(self, id_: t.Any) -> None:
async def delete(self, id: t.Any) -> None:
"""Deletes an element from the repository.
:param id_: The primary key of the element.
If the element does not exist, it raises a `NotFoundError`.
:param id: The primary key of the element.
:raises: exceptions.NotFoundError: If the element does not exist.
"""
result = await self._connection.execute(
sqlalchemy.delete(self._table).where(self._table.c[self.primary_key.name] == id_)
sqlalchemy.delete(self.table).where(self.table.c[self.primary_key.name] == id)
)

if result.rowcount == 0:
raise exceptions.NotFoundError(id_)
raise exceptions.NotFoundError(id)

async def list(self, *clauses, **filters) -> t.List[types.Schema]:
"""Lists all the elements in the repository.
If no elements are found, it returns an empty list. If no clauses or filters are given, it returns all the
elements in the repository.
Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements using
exact values to specific columns. Clauses and filters can be combined.
Clause example: `table.c["id"]._in((1, 2, 3))`
Filter example: `id=1`
:param clauses: Clauses to filter the elements.
:param filters: Filters to filter the elements.
:return: The elements.
"""
query = sqlalchemy.select(self._table)
where_clauses = tuple(clauses) + tuple(self._table.c[k] == v for k, v in filters.items())
query = sqlalchemy.select(self.table)

where_clauses = tuple(clauses) + tuple(self.table.c[k] == v for k, v in filters.items())
if where_clauses:
query = query.where(sqlalchemy.and_(*where_clauses))

return [types.Schema(row._asdict()) async for row in await self._connection.stream(query)]

async def drop(self) -> int:
"""Drops all the elements in the repository.
Returns the number of elements dropped.
:return: The number of elements dropped.
"""
result = await self._connection.execute(sqlalchemy.delete(self._table))
result = await self._connection.execute(sqlalchemy.delete(self.table))
return result.rowcount


class SQLAlchemyTableRepository(SQLAlchemyRepository):
_table: t.ClassVar[sqlalchemy.Table]

def __init__(self, connection: "AsyncConnection"):
super().__init__(connection)
self._table_manager = SQLAlchemyTableManager(self._table, connection)

def __eq__(self, other):
return isinstance(other, SQLAlchemyTableRepository) and self._table == other._table and super().__eq__(other)

async def create(self, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> t.Optional[t.Tuple[t.Any, ...]]:
"""Creates a new element in the repository.
If the element already exists, it raises an `exceptions.IntegrityError`. If the element is created, it returns
the primary key of the element.
:param data: The data to create the element.
:return: The primary key of the created element.
:raises: exceptions.IntegrityError: If the element already exists.
"""
return await self._table_manager.create(data)

async def retrieve(self, id: t.Any) -> types.Schema:
"""Retrieves an element from the repository.
If the element does not exist, it raises a `NotFoundError`.
:param id: The primary key of the element.
:return: The element.
:raises: exceptions.NotFoundError: If the element does not exist.
"""
return await self._table_manager.retrieve(id)

async def update(self, id: t.Any, data: t.Union[t.Dict[str, t.Any], types.Schema]) -> types.Schema:
"""Updates an element in the repository.
If the element does not exist, it raises a `NotFoundError`. If the element is updated, it returns the updated
element.
:param id: The primary key of the element.
:param data: The data to update the element.
:return: The updated element.
:raises: exceptions.NotFoundError: If the element does not exist.
"""
return await self._table_manager.update(id, data)

async def delete(self, id: t.Any) -> None:
"""Deletes an element from the repository.
If the element does not exist, it raises a `NotFoundError`.
:param id: The primary key of the element.
:raises: exceptions.NotFoundError: If the element does not exist.
"""
return await self._table_manager.delete(id)

async def list(self, *clauses, **filters) -> t.List[types.Schema]:
"""Lists all the elements in the repository.
Lists all the elements in the repository that match the clauses and filters. If no clauses or filters are given,
it returns all the elements in the repository. If no elements are found, it returns an empty list.
Clauses are used to filter the elements using sqlalchemy clauses. Filters are used to filter the elements using
exact values to specific columns. Clauses and filters can be combined.
Clause example: `table.c["id"]._in((1, 2, 3))`
Filter example: `id=1`
:param clauses: Clauses to filter the elements.
:param filters: Filters to filter the elements.
:return: The elements.
"""
return await self._table_manager.list(*clauses, **filters)

async def drop(self) -> int:
"""Drops all the elements in the repository.
Returns the number of elements dropped.
:return: The number of elements dropped.
"""
return await self._table_manager.drop()
4 changes: 2 additions & 2 deletions flama/resources/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing as t
import uuid

from flama.ddd.repositories import SQLAlchemyRepository
from flama.ddd.repositories import SQLAlchemyTableRepository
from flama.resources import data_structures
from flama.resources.exceptions import ResourceAttributeError
from flama.resources.resource import BaseResource, ResourceType
Expand Down Expand Up @@ -55,7 +55,7 @@ def __new__(mcs, name: str, bases: t.Tuple[type], namespace: t.Dict[str, t.Any])
namespace.setdefault("_meta", data_structures.Metadata()).namespaces.update(
{
"rest": {"model": model, "schemas": resource_schemas},
"ddd": {"repository": type(f"{name}Repository", (SQLAlchemyRepository,), {"_table": model.table})},
"ddd": {"repository": type(f"{name}Repository", (SQLAlchemyTableRepository,), {"_table": model.table})},
}
)

Expand Down
10 changes: 6 additions & 4 deletions flama/resources/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@

if t.TYPE_CHECKING:
from flama import Flama
from flama.ddd.repositories import SQLAlchemyRepository
from flama.ddd.repositories import SQLAlchemyTableRepository


class FlamaWorker(SQLAlchemyWorker):
_repositories: t.ClassVar[t.Dict[str, t.Type["SQLAlchemyTableRepository"]]]

def __init__(self, app: t.Optional["Flama"] = None):
super().__init__(app)
self._init_repositories: t.Optional[t.Dict[str, "SQLAlchemyRepository"]] = None
self._init_repositories: t.Optional[t.Dict[str, "SQLAlchemyTableRepository"]] = None

@property
def repositories(self) -> t.Dict[str, "SQLAlchemyRepository"]:
def repositories(self) -> t.Dict[str, "SQLAlchemyTableRepository"]:
assert self._init_repositories, "Repositories not initialized"
return self._init_repositories

Expand All @@ -26,5 +28,5 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
del self._init_repositories

def add_repository(self, name: str, cls: t.Type["SQLAlchemyRepository"]) -> None:
def add_repository(self, name: str, cls: t.Type["SQLAlchemyTableRepository"]) -> None:
self._repositories[name] = cls
Loading

0 comments on commit 96229a1

Please sign in to comment.