Skip to content

Commit

Permalink
adapt to sqlalchemy 2.0 changes
Browse files Browse the repository at this point in the history
  • Loading branch information
karlicoss committed Jan 28, 2023
1 parent 6fdd070 commit 4514ed1
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 24 deletions.
32 changes: 20 additions & 12 deletions src/cachew/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from pkg_resources import get_distribution, DistributionNotFound

try:
Expand Down Expand Up @@ -34,8 +36,8 @@

import appdirs # type: ignore[import]

import sqlalchemy # type: ignore[import]
from sqlalchemy import Column, Table, event
import sqlalchemy
from sqlalchemy import Column, Table, event, text


from .compat import fix_sqlalchemy_StatementError_str
Expand Down Expand Up @@ -180,7 +182,7 @@ def python_type(self): return Exception

def process_literal_param(self, value, dialect): raise NotImplementedError() # make pylint happy

def process_bind_param(self, value: Optional[Exception], dialect) -> Optional[List[Any]]:
def process_bind_param(self, value: Optional[Exception], dialect) -> Optional[List[Any]]: # type: ignore[override]
if value is None:
return None
sargs: List[Any] = []
Expand Down Expand Up @@ -339,9 +341,11 @@ def strip_generic(tp):
NT = TypeVar('NT')
# sadly, bound=NamedTuple is not working yet in mypy
# https://github.com/python/mypy/issues/685
# also needs to support dataclasses?


class NTBinder(NamedTuple):
@dataclasses.dataclass
class NTBinder(Generic[NT]):
"""
>>> class Job(NamedTuple):
... company: str
Expand Down Expand Up @@ -387,7 +391,7 @@ class NTBinder(NamedTuple):
fields : Sequence[Any] # mypy can't handle cyclic definition at this point :(

@staticmethod
def make(tp: Type, name: Optional[str]=None) -> 'NTBinder':
def make(tp: Type[NT], name: Optional[str]=None) -> NTBinder:
tp, optional = strip_optional(tp)
union: Optional[Type]
fields: Tuple[Any, ...]
Expand Down Expand Up @@ -582,9 +586,9 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
# pylint: disable=unused-variable
def do_begin(conn):
# NOTE there is also BEGIN CONCURRENT in newer versions of sqlite. could use it later?
conn.execute('BEGIN DEFERRED')
conn.execute(text('BEGIN DEFERRED'))

self.meta = sqlalchemy.MetaData(self.connection)
self.meta = sqlalchemy.MetaData()
self.table_hash = Table('hash', self.meta, Column('value', sqlalchemy.String))

self.binder = NTBinder.make(tp=cls)
Expand All @@ -593,10 +597,10 @@ def do_begin(conn):
# temporary table, we use it to insert and then (atomically?) rename to the above table at the very end
self.table_cache_tmp = Table('cache_tmp', self.meta, *self.binder.columns)

def __enter__(self):
def __enter__(self) -> 'DbHelper':
return self

def __exit__(self, *args):
def __exit__(self, *args) -> None:
self.connection.close()


Expand Down Expand Up @@ -939,15 +943,19 @@ def cachew_wrapper(
table_cache_tmp = db.table_cache_tmp

# first, try to do as much as possible read-only, benefiting from deferred transaction
old_hashes: Sequence
try:
# not sure if there is a better way...
old_hashes = conn.execute(db.table_hash.select()).fetchall()
cursor = conn.execute(db.table_hash.select())
except sqlalchemy.exc.OperationalError as e:
# meh. not sure if this is a good way to handle this..
if 'no such table: hash' in str(e):
old_hashes = []
else:
raise e
else:
old_hashes = cursor.fetchall()


assert len(old_hashes) <= 1, old_hashes # shouldn't happen
old_hash: Optional[SourceHash]
Expand Down Expand Up @@ -1035,7 +1043,7 @@ def missing_keys(cached: List[str], wanted: List[str]) -> Optional[List[str]]:

# 'table' used to be old 'cache' table name, so we just delete it regardless
# otherwise it might overinfalte the cache db with stale values
conn.execute(f'DROP TABLE IF EXISTS `table`')
conn.execute(text(f'DROP TABLE IF EXISTS `table`'))

# NOTE: we have to use .drop and then .create (e.g. instead of some sort of replace)
# since it's possible to have schema changes inbetween calls
Expand Down Expand Up @@ -1095,7 +1103,7 @@ def flush() -> None:

# meh https://docs.sqlalchemy.org/en/14/faq/metadata_schema.html#does-sqlalchemy-support-alter-table-create-view-create-trigger-schema-upgrade-functionality
# also seems like sqlalchemy doesn't have any primitives to escape table names.. sigh
conn.execute(f"ALTER TABLE `{table_cache_tmp.name}` RENAME TO `{table_cache.name}`")
conn.execute(text(f"ALTER TABLE `{table_cache_tmp.name}` RENAME TO `{table_cache.name}`"))

# pylint: disable=no-value-for-parameter
conn.execute(db.table_hash.insert().values([{'value': new_hash}]))
Expand Down
23 changes: 13 additions & 10 deletions src/cachew/compat.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import sys
def fix_sqlalchemy_StatementError_str():

def fix_sqlalchemy_StatementError_str() -> None:
# see https://github.com/sqlalchemy/sqlalchemy/issues/5632
import sqlalchemy # type: ignore
v = sqlalchemy.__version__
import sqlalchemy
v = sqlalchemy.__version__ # type: ignore[attr-defined]
if v != '1.3.19':
# sigh... will still affect smaller versions.. but patching code to remove import dynamically would be far too mad
return

from sqlalchemy.util import compat # type: ignore
from sqlalchemy.exc import StatementError as SE # type: ignore
from sqlalchemy.util import compat
from sqlalchemy.exc import StatementError as SE

def _sql_message(self, as_unicode):
details = [self._message(as_unicode=as_unicode)]
if self.statement:
if not as_unicode and not compat.py3k:
stmt_detail = "[SQL: %s]" % compat.safe_bytestring(
# pylint: disable=no-member
if not as_unicode and not compat.py3k: # type: ignore[attr-defined]
# pylint: disable=no-member
stmt_detail = "[SQL: %s]" % compat.safe_bytestring( # type: ignore[attr-defined]
self.statement
)
else:
Expand All @@ -27,9 +30,9 @@ def _sql_message(self, as_unicode):
)
else:
# NOTE: this will still cause issues
from sqlalchemy.sql import util # type: ignore
from sqlalchemy.sql import util

params_repr = util._repr_params(
params_repr = util._repr_params( # type: ignore[attr-defined]
self.params, 10, ismulti=self.ismulti
)
details.append("[parameters: %r]" % params_repr)
Expand All @@ -38,4 +41,4 @@ def _sql_message(self, as_unicode):
details.append(code_str)
return "\n".join(["(%s)" % det for det in self.detail] + details)

SE._sql_message = _sql_message
SE._sql_message = _sql_message # type: ignore[attr-defined,assignment]
9 changes: 9 additions & 0 deletions src/cachew/tests/test_cachew.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ class UUU(NamedTuple):
yy: int


def test_simple() -> None:
# just make sure all the high level cachew stuff is working
@cachew
def fun() -> Iterable[UUU]:
yield from []

list(fun())


def test_custom_hash(tdir):
"""
Demo of using argument's modification time to determine if underlying data changed
Expand Down
7 changes: 5 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ minversion = 3.5
envlist = tests,pylint,mypy

[testenv]
passenv = CI CI_* CIRCLE*
passenv =
CI
CI_*
CIRCLE*

[testenv:tests]
commands =
Expand All @@ -17,7 +20,7 @@ commands =
commands =
pip install -e .[testing]
python -m mypy --install-types --non-interactive \
src \
-p cachew \
# txt report is a bit more convenient to view on CI
--txt-report .coverage.mypy \
--html-report .coverage.mypy \
Expand Down

0 comments on commit 4514ed1

Please sign in to comment.