Add symbolic_override_first_arg_based (pytorch#4799)
* Add symbolic_override_first_arg_based

* flake fix

* comment

* remove comment (keep forgetting about this PR)
dzhulgakov authored and soumith committed Jan 30, 2018
1 parent c011c8b commit 5b43c22
Showing 3 changed files with 79 additions and 59 deletions.
9 changes: 6 additions & 3 deletions torch/nn/_functions/rnn.py
@@ -352,13 +352,16 @@ def forward(input, *fargs, **fkwargs):
 
         # Hack for the tracer that allows us to represent RNNs as single
         # nodes and export them to ONNX in this form
-        # It can be also used as a decorator at the higher level
         # Check the first argument explicitly to reduce the overhead of creating
-        # the lambda
+        # the lambda. We need special handling here because the forward()
+        # function gets reconstructed each and every time when RNN() is invoked
+        # and we don't want to pay the cost of decorator invocation
         import torch
         if torch._C._jit_is_tracing(input):
             import torch.onnx.symbolic
-            func = torch.onnx.symbolic.RNN_symbolic_builder(*args, **kwargs)(func)
+            decorator = torch.onnx.symbolic_override_first_arg_based(
+                torch.onnx.symbolic.RNN_symbolic_builder(*args, **kwargs))
+            func = decorator(func)
 
         return func(input, *fargs, **fkwargs)
 
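The guard above exists because forward() is rebuilt on every RNN() call, so the ONNX override is applied only while the tracer is active. A minimal sketch of that conditional-decoration pattern, with maybe_override as an illustrative helper that is not part of the commit:

import torch
import torch.onnx


def maybe_override(func, symbolic_fn, first_arg):
    # Wrap func with the ONNX override only while `first_arg` is being traced;
    # on the ordinary path the wrapper (and its per-call cost) is never created.
    if torch._C._jit_is_tracing(first_arg):
        func = torch.onnx.symbolic_override_first_arg_based(symbolic_fn)(func)
    return func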
95 changes: 56 additions & 39 deletions torch/onnx/__init__.py
@@ -374,6 +374,49 @@ def _node_getitem(self, k):
     return getattr(self, sel)(k)
 
 
+def _symbolic_override_wrapper_maker(symbolic_fn, first_arg_only, fn):
+
+    def wrapper(*args, **kwargs):
+        output = fn(*args, **kwargs)
+        # fast pass
+        if first_arg_only and not torch._C._jit_is_tracing(args[0]):
+            return output
+
+        flat_args = tuple(function._iter_variables(args))
+        if not any(map(torch._C._jit_is_tracing, flat_args)):
+            return output
+        flat_output_tensors = tuple(
+            v.data for v in function._iter_variables(output))
+        assert len(list(function._iter_variables_permissive(
+            list(kwargs.values())))) == 0, \
+            "Passing Variable through kwargs is not supported"
+
+        class ExportProxy(Function):
+            @staticmethod
+            def symbolic(g, *flat_args):
+                symbolic_args = function._unflatten(flat_args, args)
+                symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
+                return tuple(function._iter_jit_values(symbolic_output))
+
+            @staticmethod
+            def forward(ctx, *unused_args):
+                return flat_output_tensors
+
+            @staticmethod
+            def backward(ctx, *unused_args, **unused_kwargs):
+                raise RuntimeError(
+                    "symbolic_override is meant for inference export only")
+
+        flat_proxy_output = ExportProxy.apply(*flat_args)
+        return function._unflatten(flat_proxy_output, output)
+
+    # fn might be autograd.Function too, in this case wrapping doesn't work
+    if isinstance(fn, types.FunctionType):
+        wrapper = functools.wraps(fn)(wrapper)
+
+    return wrapper
+
+
 def symbolic_override(symbolic_fn):
     """
     Decorator to override ONNX export of the a function with specified subgraph.
@@ -400,45 +443,19 @@ def foo(x, y):
     ```
     """
 
-    def wrapper_maker(fn):
-
-        def wrapper(*args, **kwargs):
-            output = fn(*args, **kwargs)
-            flat_args = tuple(function._iter_variables(args))
-            if not any(map(torch._C._jit_is_tracing, flat_args)):
-                return output
-            flat_output_tensors = tuple(
-                v.data for v in function._iter_variables(output))
-            assert len(list(function._iter_variables_permissive(
-                list(kwargs.values())))) == 0, \
-                "Passing Variable through kwargs is not supported"
-
-            class ExportProxy(Function):
-                @staticmethod
-                def symbolic(g, *flat_args):
-                    symbolic_args = function._unflatten(flat_args, args)
-                    symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
-                    return tuple(function._iter_jit_values(symbolic_output))
-
-                @staticmethod
-                def forward(ctx, *unused_args):
-                    return flat_output_tensors
-
-                @staticmethod
-                def backward(ctx, *unused_args, **unused_kwargs):
-                    raise RuntimeError(
-                        "symbolic_override is meant for inference export only")
-
-            flat_proxy_output = ExportProxy.apply(*flat_args)
-            return function._unflatten(flat_proxy_output, output)
-
-        # fn might be autograd.Function too, in this case wrapping doesn't work
-        if isinstance(fn, types.FunctionType):
-            wrapper = functools.wraps(fn)(wrapper)
-
-        return wrapper
-
-    return wrapper_maker
+    return functools.partial(_symbolic_override_wrapper_maker, symbolic_fn, False)
+
+
+def symbolic_override_first_arg_based(symbolic_fn):
+    """
+    Decorator to override ONNX export of the a function with specified subgraph.
+    Equivalent to `symbolic_override` but checks only the first argument of the
+    function to figure out whether the tracing is on. Thus the first arg needs
+    to be a Variable.
+    """
+
+    return functools.partial(_symbolic_override_wrapper_maker, symbolic_fn, True)
 
 
 torch._C.Graph.op = _graph_op
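As a rough usage sketch of the refactored decorators (relu_symbolic, my_relu and my_relu_fast are illustrative, not from this commit): both variants wrap an ordinary Python function and reroute it through ExportProxy while tracing; the first-arg variant only inspects args[0], so that argument must be a Variable.

import torch
import torch.onnx


def relu_symbolic(g, x):
    # During export, emit a single ONNX Relu node in place of the traced body.
    return g.op("Relu", x)


@torch.onnx.symbolic_override(relu_symbolic)
def my_relu(x):
    # Any Variable among the flattened args triggers the override while tracing.
    return x.clamp(min=0)


@torch.onnx.symbolic_override_first_arg_based(relu_symbolic)
def my_relu_fast(x):
    # Only x is checked for tracing, so it has to be a Variable.
    return x.clamp(min=0)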
34 changes: 17 additions & 17 deletions torch/onnx/symbolic.py
@@ -537,7 +537,7 @@ def RNN_symbolic_builder(cell_type, *args, **kwargs):
     elif cell_type.startswith('RNN_'):
         return Elman_RNN_symbolic_builder(cell_type[4:], *args, **kwargs)
     else:
-        return _unimplemented("RNN", "cell type " + cell_type)
+        return lambda *args, **kwargs: _unimplemented("RNN", "cell type " + cell_type)
 
 
 def reform_weights(g, w, n, intervals):
@@ -583,14 +583,14 @@ def symbolic(g, input, all_weights, h0, **fkwargs):
 
 
 def LSTM_symbolic_builder(input_size, hidden_size, num_layers, batch_first, dropout, bidirectional, **kwargs):
-    if batch_first:
-        return _unimplemented("LSTM", "batch_first")
-    if dropout:
-        return _unimplemented("LSTM", "dropout")
-    if bidirectional:
-        return _unimplemented("LSTM", "bidirectional")
-
     def symbolic(g, input, all_weights, h0_and_c0, **fkwargs):
+        if batch_first:
+            return _unimplemented("LSTM", "batch_first")
+        if dropout:
+            return _unimplemented("LSTM", "dropout")
+        if bidirectional:
+            return _unimplemented("LSTM", "bidirectional")
+
         h0, c0 = h0_and_c0
 
         # TODO elide this argument to increase parametricity. This is
@@ -619,18 +619,18 @@ def symbolic(g, input, all_weights, h0_and_c0, **fkwargs):
         h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
         return prev_output, h_outs, None
 
-    return torch.onnx.symbolic_override(symbolic)
+    return symbolic
 
 
 def GRU_symbolic_builder(input_size, hidden_size, num_layers, batch_first, dropout, bidirectional, **kwargs):
-    if batch_first:
-        return _unimplemented("GRU", "batch_first")
-    if dropout:
-        return _unimplemented("GRU", "dropout")
-    if bidirectional:
-        return _unimplemented("GRU", "bidirectional")
-
     def symbolic(g, input, all_weights, h0, **fkwargs):
+        if batch_first:
+            return _unimplemented("GRU", "batch_first")
+        if dropout:
+            return _unimplemented("GRU", "dropout")
+        if bidirectional:
+            return _unimplemented("GRU", "bidirectional")
+
         # TODO elide this argument to increase parametricity. This is
         # nontrivial because we provide subsequent optional arguments,
         # and ONNX does not have a mechanism for skipping non-trailing
@@ -657,4 +657,4 @@ def symbolic(g, input, all_weights, h0, **fkwargs):
         h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
         return prev_output, h_outs
 
-    return torch.onnx.symbolic_override(symbolic)
+    return symbolic
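One consequence worth noting: rnn.py now passes the builder's result to symbolic_override_first_arg_based and only calls it during export, so every branch of a builder has to return a callable, which is why the unsupported cell type above is deferred into a lambda. A small sketch of that contract, with toy_rnn_builder and its symbolic purely illustrative:

from torch.onnx.symbolic import _unimplemented  # helper already used in this file


def toy_rnn_builder(cell_type, **kwargs):
    # The caller decorates whatever is returned here, so even the unsupported
    # case must be a callable; the _unimplemented report fires only if the
    # symbolic is actually invoked while exporting.
    if cell_type == 'LSTM':
        def symbolic(g, input, all_weights, h0_and_c0, **fkwargs):
            raise NotImplementedError("ONNX subgraph construction elided in this sketch")
        return symbolic
    return lambda *args, **kwargs: _unimplemented("RNN", "cell type " + cell_type)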
