Skip to content

Commit

Permalink
🐛 Atomic operations on SQLAlchemy connections management
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed May 13, 2024
1 parent e9aa6aa commit d22272f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
10 changes: 10 additions & 0 deletions flama/ddd/workers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import abc
import asyncio
import inspect
import logging
import typing as t

from flama.ddd.repositories import AbstractRepository
Expand All @@ -8,9 +10,12 @@
if t.TYPE_CHECKING:
from flama import Flama

logger = logging.getLogger(__name__)

Repositories = t.NewType("Repositories", t.Dict[str, t.Type[AbstractRepository]])

__all__ = ["WorkerType", "AbstractWorker"]


class WorkerType(abc.ABCMeta):
"""Metaclass for workers.
Expand Down Expand Up @@ -57,6 +62,7 @@ def __init__(self, app: t.Optional["Flama"] = None):
:param app: Application instance.
"""
self._app = app
self._lock = asyncio.Lock()

@property
def app(self) -> "Flama":
Expand Down Expand Up @@ -97,12 +103,16 @@ async def end(self, *, rollback: bool = False) -> None:

async def __aenter__(self) -> "AbstractWorker":
"""Start a unit of work."""
await self._lock.acquire()
logger.debug("Start unit of work")
await self.begin()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""End a unit of work."""
await self.end(rollback=exc_type is not None)
logger.debug("End unit of work")
self._lock.release()

@abc.abstractmethod
async def commit(self) -> None:
Expand Down
3 changes: 3 additions & 0 deletions flama/ddd/workers/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import typing as t

from flama.ddd.workers.base import AbstractWorker
Expand All @@ -10,6 +11,8 @@

__all__ = ["SQLAlchemyWorker"]

logger = logging.getLogger(__name__)


class SQLAlchemyWorker(AbstractWorker):
"""Worker for SQLAlchemy.
Expand Down
12 changes: 8 additions & 4 deletions flama/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import logging
import typing as t

from flama import exceptions
Expand All @@ -23,6 +24,8 @@

__all__ = ["metadata", "SQLAlchemyModule"]

logger = logging.getLogger(__name__)


class ConnectionManager(abc.ABC):
"""Abstract class for connection managers.
Expand All @@ -37,6 +40,7 @@ def __init__(self, engine: "AsyncEngine") -> None:
"""
self._engine = engine

@abc.abstractmethod
async def open(self) -> "AsyncConnection":
"""Open a new connection to the database.
Expand Down Expand Up @@ -212,7 +216,7 @@ async def open(self) -> "AsyncConnection":
:return: Database connection.
"""
connection = self._engine.connect()
await connection.__aenter__()
await connection.start()
self._connections.add(connection)
return connection

Expand All @@ -228,8 +232,8 @@ async def close(self, connection: "AsyncConnection") -> None:
if connection in self._transactions:
await self.end(self._transactions[connection])

await connection.__aexit__(None, None, None)
self._connections.remove(connection)
await connection.close()

async def begin(self, connection: "AsyncConnection") -> "AsyncTransaction":
"""Begin a new transaction.
Expand Down Expand Up @@ -258,13 +262,13 @@ async def end(self, transaction: "AsyncTransaction", *, rollback: bool = False)
if transaction.connection not in self._transactions:
raise exceptions.SQLAlchemyError("Transaction not started")

del self._transactions[transaction.connection]

if rollback:
await transaction.rollback()
else:
await transaction.commit()

del self._transactions[transaction.connection]


class SQLAlchemyModule(Module):
"""SQLAlchemy module.
Expand Down

0 comments on commit d22272f

Please sign in to comment.