Process inputs and outputs in fx interpreter (pytorch#74242)
Summary:
Pull Request resolved: pytorch#74242

With a custom codegen, the inputs and outputs of the graph module can differ from the inputs and outputs of the underlying graph. Since the interpreter runs the graph directly rather than the generated forward function, it can fail when given the inputs intended for the graph module. To close this gap, we call `process_inputs` and `process_outputs` inside the interpreter.
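As a rough usage sketch (not part of this diff; the function and tensors below are illustrative), with the default codegen `process_inputs`/`process_outputs` are identity transforms, so the new flag mainly matters when a custom codegen is installed:

```python
import torch
from torch.fx import symbolic_trace, Interpreter

def f(a, b):  # illustrative function, not from this PR
    return a + b, a * b

gm = symbolic_trace(f)
x, y = torch.randn(3), torch.randn(3)

# Default path: run() now routes args through gm.graph.process_inputs
# and the result through gm.graph.process_outputs.
out = Interpreter(gm).run(x, y)

# Opt out and use the raw graph calling convention instead.
raw = Interpreter(gm).run(x, y, enable_io_processing=False)
```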

Test Plan: unit test: test_interpreter_with_codegen

Reviewed By: jamesr66a, Chillee

Differential Revision: D34898108

fbshipit-source-id: 250bd236f6c8c1268a363cf19a09521a4f64b3a9
(cherry picked from commit b33076f)
842974287 authored and pytorchmergebot committed Mar 22, 2022
1 parent 65329f4 commit f65594f
Showing 2 changed files with 37 additions and 3 deletions.
32 changes: 32 additions & 0 deletions test/test_fx.py
@@ -3409,6 +3409,38 @@ def f(a, b):
         transformed_gm = Transformer(nf).transform()
         self.assertEqual(nf(vals), transformed_gm(vals))
 
+    def test_interpreter_with_codegen(self):
+        class ListCodeGen(CodeGen):
+            def gen_fn_def(self, free_vars, maybe_return_annotation):
+                lst_unpack = f"""
+def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
+    {', '.join(free_vars)} = args_list"""
+                return lst_unpack
+
+            def additional_globals(self):
+                return [('List', typing.List)]
+
+            def process_inputs(self, *inputs):
+                assert(len(inputs) == 1)
+                return inputs[0]
+
+            def generate_output(self, output_args):
+                return f'return list({repr(output_args)})'
+
+            def process_outputs(self, outputs):
+                return list(outputs)
+
+        def f(a, b):
+            a = a + b
+            b = a + b
+            return a, b
+
+        nf = symbolic_trace(f)
+        vals = [torch.randn(3), torch.randn(3)]
+        nf.graph.set_codegen(ListCodeGen())
+        nf.recompile()
+        self.assertEqual(Interpreter(nf).run(vals), nf(vals))
+
     def test_imul_code_print(self):
         graph = torch.fx.Graph()
         a = graph.placeholder("a")
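For reference, with `ListCodeGen` installed the recompiled `nf.code` looks roughly like the following (variable names are illustrative and fx's intermediate bookkeeping is omitted):

```python
from typing import List
import torch

# Approximate shape of the generated forward under ListCodeGen:
def forward(self, args_list: List[torch.Tensor]):
    a, b = args_list            # emitted by ListCodeGen.gen_fn_def
    add = a + b                 # traced body of f
    add_1 = add + b
    return list((add, add_1))   # emitted by ListCodeGen.generate_output
```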
8 changes: 5 additions & 3 deletions torch/fx/interpreter.py
@@ -89,7 +89,7 @@ def register_last_uses(n : Node, user : Node):
             map_arg(node.kwargs, lambda n: register_last_uses(n, node))
 
     @compatibility(is_backward_compatible=True)
-    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None) -> Any:
+    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
         """
         Run `module` via interpretation and return the result.
@@ -108,6 +108,8 @@ def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None) -> Any:
         # Positional function args are consumed left-to-right by
         # `placeholder` nodes. Use an iterator to keep track of
         # position and extract those values.
+        if enable_io_processing:
+            args = self.module.graph.process_inputs(*args)
         self.args_iter : Iterator[Any] = iter(args)
 
         for node in self.module.graph.nodes:
@@ -126,7 +128,7 @@ def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None) -> Any:
 
             if node.op == 'output':
                 output_val = self.env[node]
-                return output_val
+                return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
 
     @compatibility(is_backward_compatible=True)
     def run_node(self, n : Node) -> Any:
@@ -447,7 +449,7 @@ def transform(self) -> GraphModule:
         Transform ``self.module`` and return the transformed
         ``GraphModule``.
         """
-        result = super().run()
+        result = super().run(enable_io_processing=False)
         if result is not None:
             def strip_proxy(a : Union[Argument, Proxy]) -> Any:
                 return a.node if isinstance(a, Proxy) else a
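A note on the `Transformer.transform()` change above: `transform()` interprets the graph over `Proxy` placeholders to build a new graph rather than over user-supplied values, so module-level input/output processing does not apply there. A minimal sketch of the two entry points (the traced function is illustrative):

```python
import torch
from torch.fx import symbolic_trace, Interpreter, Transformer

def f(a, b):  # illustrative
    return a + b, a * b

gm = symbolic_trace(f)

# Concrete execution: inputs and outputs pass through the graph's codegen,
# matching gm's own calling convention.
out = Interpreter(gm).run(torch.randn(3), torch.randn(3))

# Graph-to-graph transformation: placeholders become Proxies rather than
# user inputs, hence super().run(enable_io_processing=False) in transform().
new_gm = Transformer(gm).transform()
```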
