Rewrite assert statement with torch._assert under config (pytorch#88246)
This diff rewrites Python assert statements with torch._assert under a config flag. The resulting graph looks something like:
```
SOURCE CODE:
def f(x):
      assert x[0] == 3
      return x.cos()

CAPTURED GRAPH:
graph():
    %arg0 : [#users=2] = placeholder[target=arg0]
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%arg0, 0), kwargs = {})
    %eq : [#users=1] = call_function[target=operator.eq](args = (%getitem, 3), kwargs = {})
    %_assert : [#users=0] = call_function[target=torch._assert](args = (%eq, "assertion_error"), kwargs = {})
    %cos : [#users=1] = call_method[target=cos](args = (%arg0,), kwargs = {})
    return cos
```
Note that this introduces a side effect, since the graph can now error out during execution, but the assertion can be eliminated via DCE if we choose to ignore it.

Pull Request resolved: pytorch#88246
Approved by: https://github.com/jansel
tugsbayasgalan authored and pytorchmergebot committed Nov 17, 2022
1 parent af448e8 commit 04169c5
Showing 3 changed files with 189 additions and 0 deletions.
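
For context, a minimal usage sketch (not part of this diff) of the new flag end to end; it assumes the torch._dynamo.export API exercised by the tests below, which returns a callable graph module plus guards:

```
import torch
import torch._dynamo

# Enable the rewrite added by this diff (it defaults to True in config.py below).
torch._dynamo.config.rewrite_assert_with_torch_assert = True


def f(x):
    assert x[0] == 3, "First dim need to be 3"
    return x.cos()


# Instead of graph-breaking on the assert, export captures it as a
# call_function node targeting torch._assert, as in the graph above.
graph_module, guards = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
print(graph_module.graph)

# A failing assertion surfaces as AssertionError while tracing/executing:
# torch._dynamo.export(f, torch.Tensor([4, 4, 5]))  # raises AssertionError
```
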
92 changes: 92 additions & 0 deletions test/dynamo/test_repros.py
@@ -1938,6 +1938,98 @@ def fn(x):
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
    def test_rewrite_assert_with_msg(self):
        def f(x):
            b = x.sin()
            assert x[0] == 3, "First dim need to be 3"
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        cnt = torch._dynamo.testing.CompileCounter()

        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
        self.assertTrue(same(f(*args), opt_f(*args)))
        self.assertEqual(cnt.op_count, 6)
        self.assertEqual(cnt.frame_count, 1)

        exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))

        with self.assertRaisesRegex(AssertionError, ""):
            exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))

    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
    def test_not_rewrite_assert_for_other_errors(self):
        def f(x):
            b = x.sin()
            if not x.sum() <= 3:
                raise ValueError("input sum needs to be 3")
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        opt_fn = torch._dynamo.optimize("eager")(f)
        with self.assertRaisesRegex(ValueError, "input sum needs to be 3"):
            opt_fn(*args)

    # TODO (tmanlaibaatar) handle data-dependent fstring in assert statement.
    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
    def test_rewrite_assert_with_fstring_msg(self):
        def f(x):
            b = x.sin()
            assert x[0] == 3, f"First dim need to be {x[0]}"
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"):
            exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))

    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
    def test_rewrite_assert_without_msg(self):
        def f(x):
            b = x.sin()
            assert x[0] == 3
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))

        with self.assertRaisesRegex(AssertionError, ""):
            exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))

    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
    def test_rewrite_assert_noop(self):
        def f(x):
            b = x.sin()
            assert True
            assert x.dtype == torch.float32
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))

        cnt = torch._dynamo.testing.CompileCounter()
        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
        self.assertTrue(same(f(*args), opt_f(*args)))
        # torch._assert shouldn't be in the graph
        self.assertEqual(cnt.op_count, 3)
        self.assertEqual(cnt.frame_count, 1)

        exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))

    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", False)
    def test_not_rewrite_assert(self):
        def f(x):
            b = x.sin()
            assert x[0] == 3
            return x.cos() + b

        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"):
            torch._dynamo.export(f, torch.Tensor([3, 4, 5]))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
3 changes: 3 additions & 0 deletions torch/_dynamo/config.py
@@ -87,6 +87,9 @@
# if an exception is encountered
replay_record_enabled = False

# Rewrite assert statement in python with torch._assert
rewrite_assert_with_torch_assert = True

# Show a warning on every graph break
print_graph_breaks = False

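
For a scoped override rather than a global assignment, the flag can also be patched the way the new tests above do; a minimal sketch:

```
from unittest.mock import patch

import torch._dynamo

# Enable the assert rewrite only inside this block, mirroring the
# @patch.object decorators used in test_repros.py above.
with patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True):
    ...  # compile or export code containing Python asserts here
```
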
94 changes: 94 additions & 0 deletions torch/_dynamo/symbolic_convert.py
@@ -53,6 +53,7 @@
    fake_tensors_available,
    graph_break_dup_warning_checker,
    istype,
    proxy_args_kwargs,
)
from .variables.base import MutableLocal, typestr, VariableTracker
from .variables.builder import VariableBuilder, wrap_fx_proxy
@@ -121,10 +122,103 @@ def impl(self: "InstructionTranslatorBase", inst: Instruction):
    return impl


def _detect_and_normalize_assert_statement(
    self: "InstructionTranslatorBase", truth_fn: typing.Callable, push: bool
):
    # Detect whether this jump instruction is part of an assert statement and
    # normalize the assert by pushing a dummy error message when none is given.
    #
    # A Python 3.9 assertion has the following bytecode format:
    # 18 POP_JUMP_IF_TRUE 28
    # 20 LOAD_ASSERTION_ERROR
    # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION 1 -> optional instruction
    # 26 RAISE_VARARGS
    #
    # A Python 3.8 assertion has the following bytecode format:
    # 18 POP_JUMP_IF_TRUE 28
    # 20 LOAD_GLOBAL 0 (AssertionError)
    # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION 1 -> optional instruction
    # 26 RAISE_VARARGS 1

    if (truth_fn is not operator.truth) or push:
        return False

    current_instruction_pointer = self.instruction_pointer
    inst = self.instructions[current_instruction_pointer]
    # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
    if sys.version_info < (3, 9):
        if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
            return False
    else:
        if inst.opname != "LOAD_ASSERTION_ERROR":
            return False

    current_instruction_pointer += 1

    if current_instruction_pointer >= len(self.instructions):
        return False

    inst = self.instructions[current_instruction_pointer]
    has_error_msg = False
    # Detect RAISE_VARARGS or LOAD_CONST
    if inst.opname == "LOAD_CONST":
        if not isinstance(inst.argval, str):
            return False
        self.LOAD_CONST(inst)
        has_error_msg = True

        # if it is LOAD_CONST, it must be followed by CALL_FUNCTION
        current_instruction_pointer += 1
        if current_instruction_pointer >= len(self.instructions):
            return False
        inst = self.instructions[current_instruction_pointer]
        if inst.opname != "CALL_FUNCTION":
            return False

        # CALL_FUNCTION should be followed by RAISE_VARARGS
        current_instruction_pointer += 1
        if current_instruction_pointer >= len(self.instructions):
            return False
        inst = self.instructions[current_instruction_pointer]

    if inst.opname != "RAISE_VARARGS":
        return False

    if not has_error_msg:
        # Push dummy value instead of error message
        self.push(ConstantVariable("assertion error"))

    return True


def generic_jump(truth_fn: typing.Callable, push: bool):
    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()
        self.output.guards.update(value.guards)
        if (
            config.rewrite_assert_with_torch_assert
            and _detect_and_normalize_assert_statement(self, truth_fn, push)
        ):
            error_msg: VariableTracker = self.pop()
            self.output.guards.update(error_msg.guards)
            # Skip over things like `assert True`
            if value.is_python_constant() and bool(value.as_python_constant()):
                self.jump(inst)
                return

            # Manually insert torch._assert instead of python assert and jump over
            # assert related instructions as we don't need them anymore.
            self.output.create_proxy(
                "call_function",
                torch._assert,
                *proxy_args_kwargs((value, error_msg), {}),
                current_tx=self,
            )
            self.jump(inst)
            return

        if value.is_python_constant():
            if truth_fn(value.as_python_constant()):
                push and self.push(value)
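
For reference, a small sketch (not part of this diff) that prints the bytecode pattern _detect_and_normalize_assert_statement walks over; the opcode names assume CPython 3.9/3.10 (3.8 emits LOAD_GLOBAL AssertionError instead of LOAD_ASSERTION_ERROR, and 3.11+ changes the call opcodes):

```
import dis


def f(x):
    # Same shape as the functions in the new tests above; f is only
    # disassembled here, never called, so no torch import is needed.
    assert x[0] == 3, "First dim need to be 3"
    return x.cos()


# On CPython 3.9/3.10 the assert lowers to
#   POP_JUMP_IF_TRUE -> LOAD_ASSERTION_ERROR -> LOAD_CONST (message)
#   -> CALL_FUNCTION -> RAISE_VARARGS
# which is the instruction sequence the helper matches before emitting
# torch._assert and jumping past the raise.
dis.dis(f)
```
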
