Skip to content

Commit

Permalink
[fx] throw exceptions on invalid input in FloorDiv (pytorch#93143)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#93143
Approved by: https://github.com/ezyang
  • Loading branch information
nkaretnikov authored and pytorchmergebot committed Feb 3, 2023
1 parent ba614f3 commit 34bcbfb
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 4 deletions.
76 changes: 72 additions & 4 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,66 @@ def test_floordiv_float_int(self):
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))

@skipIfNoSympy
def test_floordiv_bool(self):
values = (
(False, True),
(True, 2.5),
(2.5, True),
(False, 7),
(7, True),
)

for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
# Compares to int since our FloorDiv has no bool support
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(int(x), int(y)))
# Tests that our impl throws
self.assertRaisesRegex(
TypeError,
(rf"unsupported operand type\(s\) for //: "
rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
rf", expected integer or real"),
lambda: TestFloorDiv.torch_floordiv(x, y))

@skipIfNoSympy
def test_floordiv_complex(self):
values = (
(1.5 + 2.5j, 1.3 + 3.5j),
(1.5 + 2.5j, 2.5),
(2.5, 1.5 + 2.5j),
(1.5 + 2.5j, 7),
(7, 1.5 + 2.5j),
)

for x, y in TestFloorDiv.yield_test_cases(values):
# We don't test error messages to avoid depending on Python
# interpreter version
self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y))
self.assertRaisesRegex(
TypeError,
(rf"unsupported operand type\(s\) for //: "
rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
rf", expected integer or real"),
lambda: TestFloorDiv.torch_floordiv(x, y))

@skipIfNoSympy
def test_floordiv_div_by_zero(self):
values = (
(2.5, 0),
(2.1, 0.0),
(2.3, sympy.Symbol("s", zero=True)),
)

for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
# We don't test error messages to avoid depending on Python
# interpreter version
if type(y) is not sympy.Symbol:
self.assertRaises(ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y))
self.assertRaisesRegex(
ZeroDivisionError,
"division by zero",
lambda: TestFloorDiv.torch_floordiv(x, y))

@skipIfNoSympy
def test_floordiv_zero_base(self):
values = (
Expand Down Expand Up @@ -723,11 +783,22 @@ def test_floordiv_assumptions(self):
)

for base, divisor in itertools.product(cases, repeat=2):
op = FloorDiv(base, divisor)
def op():
return FloorDiv(base, divisor)

def is_complex(x):
return x.is_integer is False and x.is_real is False and x.is_complex

if is_complex(base) or is_complex(divisor):
self.assertRaisesRegex(
TypeError,
(r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
r" expected integer or real"),
op)
continue

op = op()

# In regular Python, x//x == 1.0 if x is a float, but FloorDiv
# always returns an integer 1 when both args are the same object.
# This even works for Symbols with no assumptions specified.
Expand All @@ -737,9 +808,6 @@ def is_complex(x):
elif base.is_integer and divisor.is_integer:
self.assertTrue(op.is_integer)
self.assertTrue(op.is_real)
elif is_complex(base) or is_complex(divisor):
self.assertEqual(op.is_integer, False)
self.assertTrue(op.is_real)
else:
self.assertEqual(op.is_integer, None)
self.assertTrue(op.is_real)
Expand Down
15 changes: 15 additions & 0 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,21 @@ def _eval_is_integer(self):
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
@classmethod
def eval(cls, base, divisor):
def check_supported_type(x):
if (x.is_integer is False and x.is_real is False and x.is_complex) or x.is_Boolean:
raise TypeError(
f"unsupported operand type(s) for //: "
f"'{type(base).__name__}' and '{type(divisor).__name__}'"
f", expected integer or real")

check_supported_type(base)
check_supported_type(divisor)

# We don't provide the same error message as in Python because SymPy
# makes it difficult to check the types.
if divisor.is_zero:
raise ZeroDivisionError("division by zero")

if base.is_zero:
return sympy.S.Zero
if base.is_integer and divisor == 1:
Expand Down

0 comments on commit 34bcbfb

Please sign in to comment.