[Usability] Capture argument names for traced functions and modules (pytorch#51775)

Summary:
Previously, `torch.jit.trace` relied on autograd hooks to infer the names of tensors in the computation, including those of function/method arguments. This often doesn't work out because:

- These names often do not exist
- The tracer uses the argument name of the first tensor operation on each tensor as that tensor's inferred argument name, and these operations often have programmatically generated names like `argument_1`

This PR extracts argument names directly from Python functions and passes them down to the tracer, which then assigns them to the correct graph inputs. This way, the correct argument names are always captured in the IR.

This is useful both for debugging and for supporting the use of `InterfaceType` to represent traced modules.
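
For example, a minimal sketch of the new behavior (based on the tests added in this PR; the exact textual form of the graph dump may vary):

```python
import torch

def add(first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
    return first_arg + second_arg

traced = torch.jit.trace(add, (torch.ones(1), torch.ones(1)))
# With this change the graph inputs carry the Python argument names
# (e.g. %first_arg, %second_arg) instead of auto-generated ones.
print(traced.graph)
```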

Pull Request resolved: pytorch#51775

Reviewed By: izdeby

Differential Revision: D26273105

Pulled By: gmagogsfm

fbshipit-source-id: 934a385041137dc3731bb6fa8657b11532fed9e5
gmagogsfm authored and facebook-github-bot committed Feb 11, 2021
1 parent 4add850 commit 705fa7e
Showing 13 changed files with 290 additions and 46 deletions.
4 changes: 2 additions & 2 deletions test/expect/TestTensorBoard.test_pytorch_graph.expect
@@ -1,5 +1,5 @@
node {
name: "input/input"
name: "input/x"
op: "IO Node"
attr {
key: "_output_shapes"
@@ -111,7 +111,7 @@ node {
name: "myLinear/Linear[l]/23"
op: "aten::addmm"
input: "myLinear/Linear[l]/bias/20"
input: "input/input"
input: "input/x"
input: "myLinear/Linear[l]/22"
input: "myLinear/Linear[l]/19"
input: "myLinear/Linear[l]/19"
79 changes: 79 additions & 0 deletions test/jit/test_jit_utils.py
@@ -0,0 +1,79 @@
import os
import sys
from textwrap import dedent
import unittest

import torch

from torch.testing._internal import jit_utils

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")

# Tests various JIT-related utility functions.
class TestJitUtils(JitTestCase):
# Tests that POSITIONAL_OR_KEYWORD arguments are captured.
def test_get_callable_argument_names_positional_or_keyword(self):
def fn_positional_or_keyword_args_only(x, y):
return x + y
self.assertEqual(
["x", "y"],
torch._jit_internal.get_callable_argument_names(fn_positional_or_keyword_args_only))

# Tests that POSITIONAL_ONLY arguments are ignored.
@unittest.skipIf(sys.version_info < (3, 8), 'POSITIONAL_ONLY arguments are not supported before 3.8')
def test_get_callable_argument_names_positional_only(self):
code = dedent('''
def fn_positional_only_arg(x, /, y):
return x + y
''')

fn_positional_only_arg = jit_utils._get_py3_code(code, 'fn_positional_only_arg')
self.assertEqual(
[],
torch._jit_internal.get_callable_argument_names(fn_positional_only_arg))

# Tests that VAR_POSITIONAL arguments are ignored.
def test_get_callable_argument_names_var_positional(self):
def fn_var_positional_arg(x, *arg):
return x + arg[0]
self.assertEqual(
[],
torch._jit_internal.get_callable_argument_names(fn_var_positional_arg))

# Tests that KEYWORD_ONLY arguments are ignored.
def test_get_callable_argument_names_keyword_only(self):
def fn_keyword_only_arg(x, *, y):
return x + y
self.assertEqual(
[],
torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg))

# Tests that VAR_KEYWORD arguments are ignored.
def test_get_callable_argument_names_var_keyword(self):
def fn_var_keyword_arg(**args):
return args['x'] + args['y']
self.assertEqual(
[],
torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg))

# Tests that all arguments are ignored for a function signature containing
# a mix of different argument kinds.
@unittest.skipIf(sys.version_info < (3, 8), 'POSITIONAL_ONLY arguments are not supported before 3.8')
def test_get_callable_argument_names_hybrid(self):
code = dedent('''
def fn_hybrid_args(x, /, y, *args, **kwargs):
return x + y + args[0] + kwargs['z']
''')
fn_hybrid_args = jit_utils._get_py3_code(code, 'fn_hybrid_args')
self.assertEqual(
[],
torch._jit_internal.get_callable_argument_names(fn_hybrid_args))
65 changes: 64 additions & 1 deletion test/jit/test_tracer.py
@@ -17,7 +17,8 @@
skipIfCompiledWithoutNumpy, enable_profiling_mode_for_profiling_tests, \
IS_SANDCASTLE, TemporaryFileName
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, \
_tmp_donotuse_dont_inline_everything, _trace, RUN_CUDA, RUN_CUDA_MULTI_GPU
_tmp_donotuse_dont_inline_everything, _trace, RUN_CUDA, \
RUN_CUDA_MULTI_GPU, make_global
from torch.testing._internal.common_cuda import with_tf32_off
from torch import Tensor

@@ -1895,6 +1896,44 @@ def forward(self, x):
x = torch.ones(1)
torch.jit.trace(Net(), x)

def test_trace_func_argument_names_captured(self):
def fn(first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
return first_arg + second_arg

traced_fn = torch.jit.trace(fn, (torch.ones(1), torch.ones(1)))
FileCheck().check("first_arg").check_next("second_arg") \
.run(str(traced_fn.graph))

def test_trace_partial_func_argument_names_captured(self):
def fn(first_arg: torch.Tensor, second_arg=1) -> torch.Tensor:
return first_arg + second_arg

traced_fn = torch.jit.trace(fn, (torch.ones(1),))
FileCheck().check("first_arg").check_not("second_arg") \
.run(str(traced_fn.graph))

def test_trace_module_argument_names_captured(self):
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.conv = nn.Conv2d(1, 1, 3)

def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor):
return self.conv(first_arg) + second_arg

m = TestModule()
example_input = (torch.ones(1, 1, 3, 3), torch.ones(1, 1, 3, 3))

# Explicitly tracing module's forward method
traced_module_forward = torch.jit.trace(m.forward, example_input)
FileCheck().check("first_arg").check_next("second_arg") \
.run(str(traced_module_forward.graph))

# Tracing the module directly
traced_module = torch.jit.trace(m, example_input)
FileCheck().check("first_arg").check_next("second_arg") \
.run(str(traced_module.graph))


class TestMixTracingScripting(JitTestCase):
def test_trace_script(self):
@@ -2367,3 +2406,27 @@ def forward(
with self.assertRaisesRegex(RuntimeError, "cannot be understood by the tracer, only outputs matching"):
mod = ReturnsBadDict()
traced_module = torch.jit.trace(mod, [torch.ones(1), torch.ones(1)], strict=False)


def test_traced_module_implements_interface(self):
@torch.jit.interface
class TestModuleInterface(nn.Module):
def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
pass

make_global(TestModuleInterface)

class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.conv = nn.Conv2d(1, 1, 3)

def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
return self.conv(first_arg) + second_arg

def fn_takes_interface(x: TestModuleInterface):
ones = torch.ones(1, 1, 3, 3)
return x.forward(ones, ones)

scripted_test_module = torch.jit.script(TestModule())
self.checkScript(fn_takes_interface, (scripted_test_module,))
25 changes: 6 additions & 19 deletions test/test_jit.py
@@ -41,6 +41,7 @@
from jit.test_cuda import TestCUDA # noqa: F401
from jit.test_hash import TestHash # noqa: F401
from jit.test_complex import TestComplex # noqa: F401
from jit.test_jit_utils import TestJitUtils # noqa: F401

# Torch
from torch import Tensor
@@ -12518,20 +12519,6 @@ def foo(self, x, y):
test_str.append(str(tm.foo.schema))
self.assertExpectedStripMangled("\n".join(test_str))

# Helper function to eval Python3 code without causing a syntax error for
# this file under py2
def _get_py3_code(self, code, fn_name):
with tempfile.TemporaryDirectory() as tmp_dir:
script_path = os.path.join(tmp_dir, 'script.py')
with open(script_path, 'w') as f:
f.write(code)
import importlib.util
spec = importlib.util.spec_from_file_location(fn_name, script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
fn = getattr(module, fn_name)
return fn

# Python AST Frontend , Python 3-style type annotations , Script function
def test_annot_ast_py3_fn(self):
code = dedent('''
@@ -12545,7 +12532,7 @@ def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
''')
test_str = []
for pair in self.type_input_return_pairs():
fn = self._get_py3_code(self.format_code(code, pair), 'foo')
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
test_str.append(str(fn.schema))
self.assertExpectedStripMangled("\n".join(test_str))

@@ -12565,7 +12552,7 @@ def foo(x, # type: {input}
test_str = []

for pair in self.type_input_return_pairs():
fn = self._get_py3_code(self.format_code(code, pair), 'foo')
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
args = fn.schema.arguments
returns = fn.schema.returns
self.assertEqual(str(args[0].type), pair[1])
@@ -12620,7 +12607,7 @@ def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output

test_str = []
for pair in self.type_input_return_pairs():
fn = self._get_py3_code(self.format_code(code, pair), 'instance')
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
test_str.append(str(fn.foo.schema))
self.assertExpectedStripMangled("\n".join(test_str))

@@ -12636,7 +12623,7 @@ def foo(x, y):

test_str = []
for pair in self.type_input_return_pairs():
fn = self._get_py3_code(self.format_code(code, pair), 'foo')
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
test_str.append(str(fn.schema))
self.assertExpected("\n".join(test_str))

@@ -12654,7 +12641,7 @@ def foo(self, x, y):

test_str = []
for pair in self.type_input_return_pairs():
fn = self._get_py3_code(self.format_code(code, pair), 'instance')
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
test_str.append(str(fn.foo.schema))
self.assertExpectedStripMangled("\n".join(test_str))

6 changes: 4 additions & 2 deletions torch/_C/__init__.pyi.in
Expand Up @@ -247,7 +247,8 @@ def _create_function_from_trace(
input_tuple: Tuple[Any, ...],
var_lookup_fn: Callable[[Tensor], str],
strict: _bool,
force_outplace: _bool
force_outplace: _bool,
argument_names: List[str]
) -> Tuple[Graph, Stack]: ...
def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
@@ -799,7 +800,8 @@ def _create_graph_by_tracing(
var_name_lookup_fn: Callable[[Tensor], str],
strict: Any,
force_outplace: Any,
self: Any = None
self: Any = None,
argument_names: List[str] = []
) -> Tuple[Graph, Stack]: ...
def _tracer_warn_use_python(): ...
def _get_tracing_state() -> TracingState: ...
31 changes: 31 additions & 0 deletions torch/_jit_internal.py
Expand Up @@ -231,6 +231,37 @@ def can_compile_class(cls):
return all(has_code)


def get_callable_argument_names(fn):
"""
Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.
Returns an empty list when other types of arguments are present.
This is used by `torch.jit.trace` to assign meaningful argument names to
traced functions and modules.
Args:
fn: A callable.
Returns:
Argument names: List[str]
"""
# inspect.signature may fail, give up in that case.
try:
callable_signature = inspect.signature(fn)
except Exception:
return []

argument_names = []
for name, param in callable_signature.parameters.items():
# The other argument kinds (positional-only, *args, keyword-only, **kwargs)
# do not map to individual values with a keyword as name.
if not param.kind == param.POSITIONAL_OR_KEYWORD:
return []

argument_names.append(name)

return argument_names


def get_annotation_str(annotation):
"""
Convert an AST node containing a type annotation to the string present in the source
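For reference, a minimal sketch (not part of the diff) of how the new internal helper behaves, mirroring the tests in `test/jit/test_jit_utils.py` above:

```python
import torch
import torch._jit_internal as jit_internal

def keyword_friendly(x, y):      # POSITIONAL_OR_KEYWORD parameters only
    return x + y

def mixed(x, *args, **kwargs):   # other parameter kinds present
    return x

print(jit_internal.get_callable_argument_names(keyword_friendly))  # ['x', 'y']
print(jit_internal.get_callable_argument_names(mixed))             # []
```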
28 changes: 24 additions & 4 deletions torch/csrc/jit/frontend/tracer.cpp
Expand Up @@ -334,7 +334,9 @@ static IValue addInput(
if (state->hasValue(input)) {
input_tensor = input_tensor.view(input_tensor.sizes());
}
value->setDebugName(name);
if (!value->hasDebugName()) {
value->setDebugName(name);
}
state->setValue(input_tensor, value);
return input_tensor;
} else if (auto tuple_type = type->cast<TupleType>()) {
@@ -440,7 +442,8 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
std::function<std::string(const Variable&)> var_name_lookup_fn,
bool strict,
bool force_outplace,
Module* self) {
Module* self,
const std::vector<std::string>& argument_names) {
try {
// Start tracing, treating 'inputs' as inputs to the trace, which can be
// varied on subsequent invocations of the trace. Any other variables
Expand All @@ -459,9 +462,26 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
gatherParametersAndBuffers(state, self_value, *self, {"__module"});
}

for (IValue& input : inputs) {
input = addInput(state, input, input.type(), state->graph->addInput());
// When enough argument name hints are provided, use them as debug names
// for traced functions/modules.
// Here argument_names is allowed to have more names than needed because
// some arguments may have valid default values and therefore don't need
// example inputs.
if (argument_names.size() >= inputs.size()) {
for (size_t i = 0, e = inputs.size(); i < e; ++i) {
IValue& input = inputs[i];
input = addInput(
state,
input,
input.type(),
state->graph->addInput(argument_names[i]));
}
} else {
for (IValue& input : inputs) {
input = addInput(state, input, input.type(), state->graph->addInput());
}
}

auto graph = state->graph;

getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
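A small sketch (assumed usage, not part of the diff) of the case the tracer comment above allows for — more argument name hints than example inputs because trailing arguments have defaults; this mirrors `test_trace_partial_func_argument_names_captured` above:

```python
import torch

def fn(first_arg: torch.Tensor, second_arg=1) -> torch.Tensor:
    return first_arg + second_arg

# Only one example input is provided; `second_arg` keeps its default value,
# so only `first_arg` appears as a named input in the traced graph.
traced = torch.jit.trace(fn, (torch.ones(1),))
print(traced.graph)
```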
3 changes: 2 additions & 1 deletion torch/csrc/jit/frontend/tracer.h
@@ -214,7 +214,8 @@ TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> trace(
std::function<std::string(const Variable&)> var_name_lookup_fn,
bool strict = true,
bool force_outplace = false,
Module* self = nullptr);
Module* self = nullptr,
const std::vector<std::string>& argument_names = {});

TORCH_API void abandon();
