Skip to content

Commit

Permalink
🐛 Atomic operations on SQLAlchemy connections management
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed May 13, 2024
1 parent e9aa6aa commit 400dd85
Showing 1 changed file with 48 additions and 25 deletions.
73 changes: 48 additions & 25 deletions flama/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import abc
import asyncio
import logging
import threading
import typing as t

from flama import exceptions
Expand All @@ -23,6 +26,8 @@

__all__ = ["metadata", "SQLAlchemyModule"]

logger = logging.getLogger(__name__)


class ConnectionManager(abc.ABC):
"""Abstract class for connection managers.
Expand All @@ -36,7 +41,11 @@ def __init__(self, engine: "AsyncEngine") -> None:
:param engine: SQLAlchemy engine.
"""
self._engine = engine
self._lock_open = threading.Lock()
self._lock_close = threading.Lock()
self._alock = asyncio.Lock()

@abc.abstractmethod
async def open(self) -> "AsyncConnection":
"""Open a new connection to the database.
Expand Down Expand Up @@ -123,14 +132,16 @@ async def open(self) -> "AsyncConnection":
:return: Database connection.
"""
try:
connection = self.connection
except exceptions.SQLAlchemyError:
self._connection = connection = self._engine.connect()
await self._connection.__aenter__()
with self._lock:
async with self._alock:
try:
connection = self.connection
except exceptions.SQLAlchemyError:
self._connection = connection = self._engine.connect()
await self._connection.__aenter__()

self._connection_clients += 1
return connection
self._connection_clients += 1
return connection

async def close(self, connection: "AsyncConnection") -> None:
"""Close the connection to the database.
Expand All @@ -140,14 +151,16 @@ async def close(self, connection: "AsyncConnection") -> None:
:param connection: Database connection.
:raises SQLAlchemyError: If the connection is a different connection from the one opened.
"""
if connection != self.connection:
raise exceptions.SQLAlchemyError("Wrong connection")
with self._lock:
async with self._alock:
if connection != self.connection:
raise exceptions.SQLAlchemyError("Wrong connection")

self._connection_clients -= 1
self._connection_clients -= 1

if self._connection_clients == 0:
await connection.__aexit__(None, None, None)
self._connection = None
if self._connection_clients == 0:
await connection.__aexit__(None, None, None)
self._connection = None

async def begin(self, connection: "AsyncConnection") -> "AsyncTransaction":
"""Begin a new transaction.
Expand Down Expand Up @@ -211,25 +224,35 @@ async def open(self) -> "AsyncConnection":
:return: Database connection.
"""
connection = self._engine.connect()
await connection.__aenter__()
self._connections.add(connection)
return connection
logger.debug("Start opening connection")
with self._lock_open:
connection = self._engine.connect()
logger.debug("Connection created %s", hash(connection))
if connection in self._connections:
raise ValueError("Connection already on use")

await connection.__aenter__()
self._connections.add(connection)
logger.debug("Open connection %s", hash(connection))
return connection

async def close(self, connection: "AsyncConnection") -> None:
"""Close the connection to the database.
:param connection: Database connection.
:raises SQLAlchemyError: If the connection is not initialized.
"""
if connection not in self._connections:
raise exceptions.SQLAlchemyError("Connection not initialized")
logger.debug("Start closing connection %s", hash(connection))
with self._lock_open:
if connection not in self._connections:
raise exceptions.SQLAlchemyError("Connection not initialized")

if connection in self._transactions:
await self.end(self._transactions[connection])
if connection in self._transactions:
await self.end(self._transactions[connection])

await connection.__aexit__(None, None, None)
self._connections.remove(connection)
self._connections.remove(connection)
await connection.__aexit__(None, None, None)
logger.debug("Close connection %s", hash(connection))

async def begin(self, connection: "AsyncConnection") -> "AsyncTransaction":
"""Begin a new transaction.
Expand Down Expand Up @@ -258,13 +281,13 @@ async def end(self, transaction: "AsyncTransaction", *, rollback: bool = False)
if transaction.connection not in self._transactions:
raise exceptions.SQLAlchemyError("Transaction not started")

del self._transactions[transaction.connection]

if rollback:
await transaction.rollback()
else:
await transaction.commit()

del self._transactions[transaction.connection]


class SQLAlchemyModule(Module):
"""SQLAlchemy module.
Expand Down

0 comments on commit 400dd85

Please sign in to comment.