Skip to content

Commit

Permalink
Merge pull request sympy#21799 from asmeurer/no-cache-guard
Browse files Browse the repository at this point in the history
Don't try to guard against non-hashable results in the cache
  • Loading branch information
asmeurer authored Aug 2, 2021
2 parents d1ea82c + f5563b2 commit aa538a2
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 9 deletions.
2 changes: 1 addition & 1 deletion sympy/codegen/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sympy.codegen.futils import render_as_module as f_module
from sympy.codegen.pyutils import render_as_module as py_module
from sympy.external import import_module
from sympy.printing import ccode
from sympy.printing.codeprinter import ccode
from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran
from sympy.utilities._compilation.util import may_xfail
from sympy.testing.pytest import skip, raises
Expand Down
2 changes: 1 addition & 1 deletion sympy/codegen/tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tempfile

from sympy.external import import_module
from sympy.printing import ccode
from sympy.printing.codeprinter import ccode
from sympy.utilities._compilation import compile_link_import_strings, has_c
from sympy.utilities._compilation.util import may_xfail
from sympy.testing.pytest import skip
Expand Down
2 changes: 1 addition & 1 deletion sympy/codegen/tests/test_cnodes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sympy.core.symbol import symbols
from sympy.printing import ccode
from sympy.printing.codeprinter import ccode
from sympy.codegen.ast import Declaration, Variable, float64, int64, String, CodeBlock
from sympy.codegen.cnodes import (
alignof, CommaOperator, goto, Label, PreDecrement, PostDecrement, PreIncrement, PostIncrement,
Expand Down
2 changes: 1 addition & 1 deletion sympy/codegen/tests/test_cxxnodes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sympy.core.symbol import Symbol
from sympy.codegen.ast import Type
from sympy.codegen.cxxnodes import using
from sympy.printing import cxxcode
from sympy.printing.codeprinter import cxxcode

x = Symbol('x')

Expand Down
2 changes: 1 addition & 1 deletion sympy/codegen/tests/test_fnodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sympy.codegen.futils import render_as_module
from sympy.core.expr import unchanged
from sympy.external import import_module
from sympy.printing import fcode
from sympy.printing.codeprinter import fcode
from sympy.utilities._compilation import has_fortran, compile_run_strings, compile_link_import_strings
from sympy.utilities._compilation.util import may_xfail
from sympy.testing.pytest import skip, XFAIL
Expand Down
2 changes: 1 addition & 1 deletion sympy/codegen/tests/test_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sympy import log, exp, cos, S, Symbol, Pow, sin, MatrixSymbol, sinc, pi
from sympy.assumptions import assuming, Q
from sympy.external import import_module
from sympy.printing import ccode
from sympy.printing.codeprinter import ccode
from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.codegen.cfunctions import log2, exp2, expm1, log1p
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
Expand Down
2 changes: 1 addition & 1 deletion sympy/core/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ def _eval_as_leading_term(self, x, logx=None, cdir=0):
logflags = dict(deep=True, log=True, mul=False, power_exp=False,
power_base=False, multinomial=False, basic=False, force=False,
factor=False)
old = old.expand(logflags)
old = old.expand(**logflags)
expr = expand_mul(old)

if not expr.is_Add:
Expand Down
4 changes: 3 additions & 1 deletion sympy/core/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def func_wrapper(func):
def wrapper(*args, **kwargs):
try:
retval = cfunc(*args, **kwargs)
except TypeError:
except TypeError as e:
if not e.args or not e.args[0].startswith('unhashable type:'):
raise
retval = func(*args, **kwargs)
return retval

Expand Down
35 changes: 34 additions & 1 deletion sympy/core/tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sympy.core.cache import cacheit

from sympy.testing.pytest import raises

def test_cacheit_doc():
@cacheit
Expand All @@ -21,3 +21,36 @@ def testit(x):
assert testit(a) == {}
a[1] = 2
assert testit(a) == {1: 2}

def test_cachit_exception():
# Make sure the cache doesn't call functions multiple times when they
# raise TypeError

a = []

@cacheit
def testf(x):
a.append(0)
raise TypeError

raises(TypeError, lambda: testf(1))
assert len(a) == 1

a.clear()
# Unhashable type
raises(TypeError, lambda: testf([]))
assert len(a) == 1

@cacheit
def testf2(x):
a.append(0)
raise TypeError("Error")

a.clear()
raises(TypeError, lambda: testf2(1))
assert len(a) == 1

a.clear()
# Unhashable type
raises(TypeError, lambda: testf2([]))
assert len(a) == 1

0 comments on commit aa538a2

Please sign in to comment.