Enable tracing of tensor factories with an out argument
Summary: Pull Request resolved: pytorch#12051

Differential Revision: D10044890

Pulled By: apaszke

fbshipit-source-id: 2d794bf408875600bc71f354f0b4961d6b715094
apaszke authored and facebook-github-bot committed Sep 26, 2018
1 parent b535aec commit 18f9c07
Showing 2 changed files with 39 additions and 8 deletions.
test/test_jit.py: 7 changes (6 additions & 1 deletion)
@@ -1224,13 +1224,18 @@ def run(**kwargs):
 
             def fn(x):
                 return x + torch.ones(2, 3, **kwargs)
-            input = torch.ones(2, 3, **kwargs)
+
+            input_kwargs = kwargs.copy()
+            if 'out' in input_kwargs:
+                del input_kwargs['out']
+            input = torch.ones(2, 3, **input_kwargs)
             self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
             # check we recorded 'ones' and did not just record a constant
             tfn = torch.jit.trace(fn, input)
             self.assertTrue("ones" in str(tfn.graph))
         run()
         run(dtype=torch.int, inputs_require_grads=False)
+        run(out=torch.tensor([]))
         if RUN_CUDA:
             run(device="cuda:0")
         if RUN_CUDA_MULTI_GPU:
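The new `run(out=torch.tensor([]))` case drives the traced function through a factory call that writes into a preallocated out tensor. Outside the test harness, the behavior being checked looks roughly like the following standalone sketch of the test above (the function name and buffer are illustrative only):

    import torch

    def fn(x):
        # factory call with an out= argument; with this change the trace records
        # the 'ones' op itself rather than baking in a constant tensor
        return x + torch.ones(2, 3, out=torch.tensor([]))

    traced = torch.jit.trace(fn, torch.ones(2, 3))
    assert "ones" in str(traced.graph)  # the factory shows up as a real graph node
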
tools/autograd/gen_variable_type.py: 40 changes (33 additions & 7 deletions)
@@ -144,7 +144,7 @@
 jit::tracer::ensureUnique("${name}", ${mutable_input});
 """)
 
-ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${input}", ${input});""")
+ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""")
 
 POST_RECORD_TRACE = CodeTemplate("""\
 if (tracer_state) {
@@ -154,6 +154,18 @@
 """)
 
 
+FACTORY_FUNCTION_NAMES = None
+
+
+def find_factory_functions(declarations):
+    global FACTORY_FUNCTION_NAMES
+    FACTORY_FUNCTION_NAMES = set()
+
+    for declaration in declarations:
+        if any(arg['simple_type'] == 'TensorOptions' for arg in declaration['arguments']):
+            FACTORY_FUNCTION_NAMES.add(declaration['api_name'])
+
+
 def should_trace(declaration):
     # Operations involving Storage or Type are not traceable at the moment
     if any(arg['simple_type'] in {'Storage', 'Type'} for arg in declaration['arguments']):
@@ -185,17 +197,30 @@ def record_trace_outputs(declaration):
 
 def format_trace(declaration):
     local = {}
+    local['trace_name'] = trace_name = uninplace_api_name(declaration['api_name'])
+
+    # *_out functions take the result as a first argument, but since we're
+    # going to de-inplace the call, we need to remove it from the argument list
+    trace_inputs = declaration['arguments']
+    if declaration['name'].endswith('_out'):
+        trace_inputs = trace_inputs[1:]
+    trace_input_spec = [(i['name'], i['name']) for i in trace_inputs]
+
+    # factories are a bit special because their out-of-place overloads
+    # take an extra TensorOptions argument, which is missing in the _out function
+    has_factory_name = trace_name in FACTORY_FUNCTION_NAMES
+    is_out_overload = any(arg['name'] == 'result' for arg in declaration['arguments'])
+    if has_factory_name and is_out_overload:
+        trace_input_spec.append(('result', 'result.options()'))
+
+    local['add_trace_inputs'] = \
+        '\n'.join(ADD_TRACE_INPUT.substitute(name=name, input=value) for name, value in trace_input_spec)
 
-    add_trace_inputs = []
-    for argument in declaration['arguments']:
-        add_trace_inputs.append(ADD_TRACE_INPUT.substitute(input=argument['name']))
-    local['add_trace_inputs'] = '\n'.join(add_trace_inputs)
-    local['inplace_guard'] = ''
     # Record inplace operations as out-of-place operations (e.g.,
     # not add_ but add)
     # TODO: Add a proper concept of side effects to the IR, and
     # properly record inplace operations.
-    local['trace_name'] = uninplace_api_name(declaration['api_name'])
+    local['inplace_guard'] = ''
     if local['trace_name'] != declaration['api_name']:
         local['inplace_guard'] = INPLACE_GUARD.substitute(name=declaration['api_name'],
                                                           mutable_input=declaration['arguments'][0]['name'])
@@ -214,6 +239,7 @@ def gen_variable_type(out, aten_declarations, template_path):
     implementation of each function dispatches to the base tensor type to
     compute the output. The grad_fn is attached to differentiable functions.
     """
+    find_factory_functions(aten_declarations)
 
     VARIABLE_TYPE_H = CodeTemplate.from_file(template_path + '/VariableType.h')
     VARIABLE_TYPE_CPP = CodeTemplate.from_file(template_path + '/VariableType.cpp')
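Taken together, `find_factory_functions` flags every overload that accepts a `TensorOptions` argument, and `format_trace` then treats the matching `_out` overload specially: the leading `result` argument is dropped from the traced inputs and re-added as `result.options()`, recovering the TensorOptions that the `_out` signature lacks. A rough illustration of that classification on hand-written stand-in declarations (not real Declarations.yaml entries; `deinplace` only approximates the codegen's `uninplace_api_name`):

    # Hand-written stand-ins for two ATen declarations; field values are illustrative.
    declarations = [
        {'name': 'ones', 'api_name': 'ones',
         'arguments': [{'name': 'size', 'simple_type': 'IntList'},
                       {'name': 'options', 'simple_type': 'TensorOptions'}]},
        {'name': 'ones_out', 'api_name': 'ones_out',
         'arguments': [{'name': 'result', 'simple_type': 'Tensor'},
                       {'name': 'size', 'simple_type': 'IntList'}]},
    ]

    # Mirrors find_factory_functions: a TensorOptions argument marks a factory.
    factory_names = {d['api_name'] for d in declarations
                     if any(a['simple_type'] == 'TensorOptions' for a in d['arguments'])}

    def deinplace(api_name):
        # Stand-in for uninplace_api_name: 'ones_out' -> 'ones', 'add_' -> 'add'.
        if api_name.endswith('_out'):
            return api_name[:-len('_out')]
        return api_name[:-1] if api_name.endswith('_') else api_name

    decl = declarations[1]                    # the ones_out overload
    trace_name = deinplace(decl['api_name'])  # traced as the out-of-place 'ones'

    # _out overloads take 'result' first, so it is dropped from the traced inputs ...
    trace_inputs = decl['arguments'][1:] if decl['name'].endswith('_out') else decl['arguments']
    trace_input_spec = [(a['name'], a['name']) for a in trace_inputs]

    # ... and for factories the missing TensorOptions is read back off 'result'.
    if trace_name in factory_names and any(a['name'] == 'result' for a in decl['arguments']):
        trace_input_spec.append(('result', 'result.options()'))

    print(trace_input_spec)  # [('size', 'size'), ('result', 'result.options()')]
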
