Skip to content

Commit

Permalink
torchdynamo support self.modules() for nn_module (pytorch#88695)
Browse files Browse the repository at this point in the history
This PR allows models to call self.modules() during dynamo tracing.

Pull Request resolved: pytorch#88695
Approved by: https://github.com/voznesenskym
  • Loading branch information
ydwu4 authored and pytorchmergebot committed Nov 12, 2022
1 parent 27dc03e commit 3765621
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
20 changes: 20 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,6 +1792,26 @@ def fn(x):
res = opt_fn(a)
self.assertTrue(same(ref, res))

def test_modules(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(4, 3)

def forward(self, inp):
res = torch.zeros(3, 3)
for mod in self.modules():
res += self.fc(inp)
return res

mod = Foo()
args = (torch.ones(3, 4),)
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch._dynamo.optimize(cnt, nopython=True)(mod)
self.assertTrue(same(mod(*args), opt_mod(*args)))
self.assertEqual(cnt.op_count, 5)
self.assertEqual(cnt.frame_count, 1)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def __init__(self, expr_to_tensor_ref, id_to_name_map):
self.id_to_name_map = id_to_name_map

def _print_Symbol(self, expr) -> str:
assert isinstance(expr, sympy.core.symbol.Symbol)
assert isinstance(expr, sympy.Symbol)
if expr == 0:
return "0"
if expr == 1:
Expand Down
2 changes: 2 additions & 0 deletions torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,8 @@ def named_embed(name, obj):
):
result.append(named_embed(name, submod))
return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
elif name == "modules":
return wrap_values(module.named_modules())
elif name == "parameters":
return wrap_values(module.named_parameters(**get_kwargs("recurse")))
elif name == "values":
Expand Down

0 comments on commit 3765621

Please sign in to comment.