Skip to content

Commit

Permalink
🐛 Atomic operations on SQLAlchemy connections management
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed May 13, 2024
1 parent e9aa6aa commit 77dc421
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
19 changes: 19 additions & 0 deletions flama/ddd/workers/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import asyncio
import logging
import threading
import typing as t

from flama.ddd.workers.base import AbstractWorker

if t.TYPE_CHECKING:
from flama.applications import Flama

try:
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncTransaction
except Exception: # pragma: no cover
...

__all__ = ["SQLAlchemyWorker"]

logger = logging.getLogger(__name__)


class SQLAlchemyWorker(AbstractWorker):
"""Worker for SQLAlchemy.
Expand All @@ -20,6 +27,11 @@ class SQLAlchemyWorker(AbstractWorker):
_connection: "AsyncConnection"
_transaction: "AsyncTransaction"

def __init__(self, app: t.Optional["Flama"] = None):
super().__init__(app=app)
self._locks = {x: threading.Lock() for x in ("begin", "end")}
self._lock = asyncio.Lock()

@property
def connection(self) -> "AsyncConnection":
"""Connection to the database.
Expand Down Expand Up @@ -67,10 +79,13 @@ async def begin(self) -> None:
Initialize the connection, begin a transaction, and create the repositories.
"""
await self._lock.acquire()

await self.begin_transaction()

for repository, repository_class in self._repositories.items():
setattr(self, repository, repository_class(self.connection))
logger.debug("Start unit of work")

async def end(self, *, rollback: bool = False) -> None:
"""End a unit of work.
Expand All @@ -84,6 +99,10 @@ async def end(self, *, rollback: bool = False) -> None:
for repository in self._repositories.keys():
delattr(self, repository)

self._lock.release()

logger.debug("End unit of work")

async def commit(self):
"""Commit the unit of work."""
await self.connection.commit()
Expand Down
12 changes: 8 additions & 4 deletions flama/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import logging
import typing as t

from flama import exceptions
Expand All @@ -23,6 +24,8 @@

__all__ = ["metadata", "SQLAlchemyModule"]

logger = logging.getLogger(__name__)


class ConnectionManager(abc.ABC):
"""Abstract class for connection managers.
Expand All @@ -37,6 +40,7 @@ def __init__(self, engine: "AsyncEngine") -> None:
"""
self._engine = engine

@abc.abstractmethod
async def open(self) -> "AsyncConnection":
"""Open a new connection to the database.
Expand Down Expand Up @@ -212,7 +216,7 @@ async def open(self) -> "AsyncConnection":
:return: Database connection.
"""
connection = self._engine.connect()
await connection.__aenter__()
await connection.start()
self._connections.add(connection)
return connection

Expand All @@ -228,8 +232,8 @@ async def close(self, connection: "AsyncConnection") -> None:
if connection in self._transactions:
await self.end(self._transactions[connection])

await connection.__aexit__(None, None, None)
self._connections.remove(connection)
await connection.close()

async def begin(self, connection: "AsyncConnection") -> "AsyncTransaction":
"""Begin a new transaction.
Expand Down Expand Up @@ -258,13 +262,13 @@ async def end(self, transaction: "AsyncTransaction", *, rollback: bool = False)
if transaction.connection not in self._transactions:
raise exceptions.SQLAlchemyError("Transaction not started")

del self._transactions[transaction.connection]

if rollback:
await transaction.rollback()
else:
await transaction.commit()

del self._transactions[transaction.connection]


class SQLAlchemyModule(Module):
"""SQLAlchemy module.
Expand Down

0 comments on commit 77dc421

Please sign in to comment.