Skip to content

Commit

Permalink
Add should_traverse_fn to torch.fx.node.map_aggregate (pytorch#81510)
Browse files Browse the repository at this point in the history
Adds an optional callback that lets callers decide whether map_aggregate should continue recursive traversal into a value. The main motivation is to avoid traversing torch.Size, which is a tuple subclass.

Pull Request resolved: pytorch#81510
Approved by: https://github.com/SherlockNoMad, https://github.com/jamesr66a
  • Loading branch information
pbelevich authored and pytorchmergebot committed Jul 15, 2022
1 parent 7af0200 commit d52f8c2
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user
torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node')
torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None
torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument
torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument], should_traverse_fn: Optional[Callable[[torch.fx.node.Argument], bool]] = None) -> torch.fx.node.Argument
torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None)
torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
Expand Down
38 changes: 38 additions & 0 deletions test/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3501,6 +3501,44 @@ def test_imul_code_print(self):
self.assertEqual(gm(2, 3), 6)
self.assertIn("a *= b", gm.code)

def test_map_aggregate_doesnt_traverse_size(self):
    """By default ``map_aggregate`` flattens a ``torch.Size`` (a tuple
    subclass) into a plain tuple; supplying a ``should_traverse_fn`` that
    rejects ``torch.Size`` keeps it intact, both at top level and nested."""
    def dont_traverse_size(a):
        return type(a) != torch.Size

    size = torch.Size([1, 2, 3])

    # Default traversal: Size degrades to a plain tuple.
    flattened = torch.fx.node.map_aggregate(size, lambda a: a)
    self.assertEqual(type(flattened), tuple)
    self.assertEqual(flattened, (1, 2, 3))

    # With the callback: Size is handed to fn whole and survives.
    preserved = torch.fx.node.map_aggregate(size, lambda a: a, dont_traverse_size)
    self.assertEqual(type(preserved), torch.Size)
    self.assertEqual(preserved, size)

    # Sizes nested inside tuples, dicts, and lists behave the same way.
    data = (torch.empty(3, 4), size,
            {'tensor': torch.empty(4, 5), 'size': size, 'list': [size, (size,), torch.empty(5, 6)]})

    flattened = torch.fx.node.map_aggregate(data, lambda a: a)
    for leaf in (flattened[1], flattened[2]['size'],
                 flattened[2]['list'][0], flattened[2]['list'][1][0]):
        self.assertEqual(type(leaf), tuple)
        self.assertEqual(leaf, (1, 2, 3))

    preserved = torch.fx.node.map_aggregate(data, lambda a: a, dont_traverse_size)
    for leaf in (preserved[1], preserved[2]['size'],
                 preserved[2]['list'][0], preserved[2]['list'][1][0]):
        self.assertEqual(type(leaf), torch.Size)
        self.assertEqual(leaf, size)



def run_getitem_target():
from torch.fx._symbolic_trace import _wrapped_methods_to_patch
Expand Down
17 changes: 12 additions & 5 deletions torch/fx/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,20 +600,27 @@ def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)


@compatibility(is_backward_compatible=True)
def map_aggregate(a: Argument, fn: Callable[[Argument], Argument],
                  should_traverse_fn: Optional[Callable[[Argument], bool]] = None) -> Argument:
    """
    Recursively apply ``fn`` to every leaf value inside ``a``.

    ``a`` may be a list, tuple, slice, or dict with string keys; those
    containers are descended into, and every other value is treated as a
    leaf and passed to ``fn``. If ``should_traverse_fn`` is supplied and
    returns False for a value, that value is handed to ``fn`` whole instead
    of being traversed (e.g. to keep a ``torch.Size`` intact).
    """
    if should_traverse_fn is not None and not should_traverse_fn(a):
        return fn(a)

    def recurse(x: Argument) -> Argument:
        return map_aggregate(x, fn, should_traverse_fn)

    if isinstance(a, tuple):
        mapped = tuple(recurse(elem) for elem in a)
        # NamedTuples expose `_fields`; repack them into their original type.
        if hasattr(a, '_fields'):
            return type(a)(*mapped)
        return mapped
    if isinstance(a, list):
        return immutable_list(recurse(elem) for elem in a)
    if isinstance(a, dict):
        return immutable_dict((k, recurse(v)) for k, v in a.items())
    if isinstance(a, slice):
        return slice(recurse(a.start), recurse(a.stop), recurse(a.step))
    return fn(a)

0 comments on commit d52f8c2

Please sign in to comment.