Improve symbolic hack a bit (pytorch#4143)
dzhulgakov authored and apaszke committed Dec 16, 2017
1 parent 2e08885 commit cab5921
Showing 4 changed files with 114 additions and 64 deletions.
50 changes: 22 additions & 28 deletions torch/autograd/function.py
@@ -252,7 +252,7 @@ def __init__(self, inplace=False):
         self.inplace = inplace


-def _nested_map(condition, fn):
+def _nested_map(condition, fn, condition_msg=None):
     def _map(obj):
         if condition(obj):
             return fn(obj)
@@ -261,12 +261,16 @@ def _map(obj):
         elif isinstance(obj, (list, tuple)):
             return type(obj)(_map(x) for x in obj)
         else:
-            raise ValueError("NestedIOFunction doesn't know how to process "
-                             "an input object of type " + torch.typename(obj))
+            raise ValueError("Auto nesting doesn't know how to process "
+                             "an input object of type " + torch.typename(obj) +
+                             (". Accepted types: " + condition_msg +
+                              ", or lists/tuples of them"
+                              if condition_msg else ""))

     return _map


-def _iter_filter(condition):
+def _iter_filter(condition, skip_unknown=False, condition_msg=None):
     def _iter(obj):
         if condition(obj):
             yield obj
@@ -276,9 +280,13 @@ def _iter(obj):
             for o in obj:
                 for var in _iter(o):
                     yield var
-        else:
-            raise ValueError("NestedIOFunction doesn't know how to process "
-                             "an input object of type " + torch.typename(obj))
+        elif not skip_unknown:
+            raise ValueError("Auto nesting doesn't know how to process "
+                             "an input object of type " + torch.typename(obj) +
+                             (". Accepted types: " + condition_msg +
+                              ", or lists/tuples of them"
+                              if condition_msg else ""))

     return _iter


@@ -297,27 +305,13 @@ def unflatten_helper(input, proto):
     return unflatten_helper(input, proto)[0]


-# Return suitable 'prototype' that doesn't hold
-# references possibly big options from 'obj'
-def _to_proto(obj):
-    def helper(obj):
-        if isinstance(obj, torch.autograd.Variable):
-            return "HOLE"
-        elif obj is None:
-            return None
-        elif isinstance(obj, (list, tuple)):
-            type_ = type(obj)
-            return type_(helper(o) for o in obj)
-        else:
-            raise ValueError("NestedIOFunction doesn't know how to process "
-                             "an input object of type " + torch.typename(obj))
-    return helper(obj)
-
-
-_iter_variables = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable))
-_iter_tensors = _iter_filter(torch.is_tensor)
-_iter_None_tensors = _iter_filter(lambda o: o is None or torch.is_tensor(o))
-_map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable), lambda o: o.data)
+_iter_variables = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable), condition_msg="Variables")
+_iter_variables_permissive = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable), skip_unknown=True)
+_iter_jit_values = _iter_filter(lambda o: isinstance(o, torch._C.Value), condition_msg="jit's Values")
+_iter_tensors = _iter_filter(torch.is_tensor, condition_msg="Tensors")
+_iter_None_tensors = _iter_filter(lambda o: o is None or torch.is_tensor(o), condition_msg="Tensors or None")
+_map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable),
+                                   lambda o: o.data, condition_msg="Variables")


 class NestedIOFunction(Function):
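Note: the effect of the new condition_msg plumbing is easiest to see standalone. Below is a hedged sketch that copies the _iter_filter added above (minus branches truncated from this view) and triggers the improved error message; the trailing _iter_tensors demo is illustrative only, not part of the commit.

```python
import torch

# Simplified copy of the _iter_filter from this diff, shown standalone.
# With skip_unknown=True (as in _iter_variables_permissive above), unknown
# types are silently skipped instead of raising.
def _iter_filter(condition, skip_unknown=False, condition_msg=None):
    def _iter(obj):
        if condition(obj):
            yield obj
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                for var in _iter(o):
                    yield var
        elif not skip_unknown:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))
    return _iter

_iter_tensors = _iter_filter(torch.is_tensor, condition_msg="Tensors")

try:
    list(_iter_tensors([torch.rand(2), "oops"]))
except ValueError as e:
    # Auto nesting doesn't know how to process an input object of type str.
    # Accepted types: Tensors, or lists/tuples of them
    print(e)
```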
5 changes: 2 additions & 3 deletions torch/jit/__init__.py
@@ -1,7 +1,6 @@
-import torch.autograd.function as function
 import torch._C
 from torch import Tensor
-from torch.autograd import Variable
+from torch.autograd import Variable, function
 from torch.nn import Module, ParameterList, Parameter
 from torch._six import raise_from
 from collections import defaultdict, OrderedDict
@@ -287,7 +286,7 @@ def clone_input(a):
         else:
             return a.clone()
     return function._nested_map(lambda o: isinstance(o, Variable) or torch.is_tensor(o),
-                                clone_input)(args)
+                                clone_input, condition_msg="Variables")(args)


 # This is purely for developer debugging. We are not going to advertise it.
44 changes: 15 additions & 29 deletions torch/nn/_functions/rnn.py
@@ -1,6 +1,5 @@
 import warnings
-from torch.autograd import Function, NestedIOFunction, Variable
-from torch.autograd.function import _iter_variables, _unflatten
+from torch.autograd import NestedIOFunction
 import torch.backends.cudnn as cudnn
 from .. import functional as F
 from .thnn import rnnFusedPointwise as fusedBackend
@@ -343,32 +342,18 @@ def backward_extended(self, grad_output, grad_hy):
         return grad_input, grad_weight, grad_hx


-def hack_onnx_rnn(fargs, output, args, kwargs):
-    input, all_weights, hx = fargs
-    output_tensors = tuple(v.data for v in _iter_variables(output))
-    flat_weights = tuple(_iter_variables(all_weights))
-    flat_hx = tuple(_iter_variables(hx))
-
-    class RNNSymbolic(Function):
-        @staticmethod
-        def symbolic(g, *fargs):
-            # NOTE: fargs contains Variable inputs (input + weight + hidden)
-            # NOTE: args/kwargs contain RNN parameters
-            raise RuntimeError("hack_onnx_rnn NYI")
-
-        @staticmethod
-        def forward(ctx, *fargs):
-            return output_tensors
-
-        @staticmethod
-        def backward(ctx, *gargs, **gkwargs):
-            raise RuntimeError("FIXME: Traced RNNs don't support backward")
-
-    flat_output = RNNSymbolic.apply(*((input,) + flat_weights + flat_hx))
-    return _unflatten(flat_output, output)
+def RNN_symbolic_builder(*args, **kwargs):
+    def symbolic(g, input, all_weights, hx, **kwargs):
+        # Something can go here, e.g.
+        # return g.op('LSTM', input, *all_weights[0], outputs=2)
+        raise RuntimeError("RNN symbolic NYI")
+
+    import torch.onnx
+    return torch.onnx.symbolic_override(symbolic)


 def RNN(*args, **kwargs):

     def forward(input, *fargs, **fkwargs):
         if cudnn.is_acceptable(input.data):
             func = CudnnRNN(*args, **kwargs)
@@ -377,11 +362,12 @@ def forward(input, *fargs, **fkwargs):

         # Hack for the tracer that allows us to represent RNNs as single
         # nodes and export them to ONNX in this form
+        # It can be also used as a decorator at the higher level
+        # Check the first argument explicitly to reduce the overhead of creating
+        # the lambda
         if torch._C._jit_is_tracing(input):
-            assert not fkwargs
-            output = func(input, *fargs)
-            return hack_onnx_rnn((input,) + fargs, output, args, kwargs)
-        else:
-            return func(input, *fargs, **fkwargs)
+            func = RNN_symbolic_builder(*args, **kwargs)(func)
+
+        return func(input, *fargs, **fkwargs)

     return forward
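Note: a hedged sketch of the wrap-on-trace pattern that forward() now uses, with RNN_symbolic_builder swapped for a hypothetical neg_symbolic_builder so the example stays self-contained. The run helper and Neg op choice are invented for illustration; the calls mirror the 2017-era API used in the diff.

```python
import torch
import torch.onnx


# Hypothetical stand-in for RNN_symbolic_builder: builds a decorator that
# replaces the traced subgraph of the wrapped call with one ONNX Neg node.
def neg_symbolic_builder():
    def symbolic(g, x):
        return g.op('Neg', x)
    return torch.onnx.symbolic_override(symbolic)


def run(x):
    def func(t):
        return -t  # eager implementation, used when not tracing

    # Mirrors RNN(): only wrap func when the tracer is active for this input
    if torch._C._jit_is_tracing(x):
        func = neg_symbolic_builder()(func)
    return func(x)
```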
79 changes: 75 additions & 4 deletions torch/onnx/__init__.py
@@ -10,14 +10,13 @@
 import torch.serialization
 import re
 import collections
-import string
-import json
-import math
 import contextlib
-import numbers
 import warnings
 from torch._utils import _range
+import functools
+import types
 from torch._six import string_classes
+from torch.autograd import Function, function


 @contextlib.contextmanager
@@ -84,6 +83,11 @@ def export(model, args, f, export_params=True, verbose=False, training=False,


 def _optimize_trace(trace, aten):
+    # run dce first to eliminate dead parts of the graph that might have been
+    # left behind by things like symbolic_override
+    torch._C._jit_pass_dce(trace)
+    torch._C._jit_pass_lint(trace)
+
     torch._C._jit_pass_peephole(trace)
     torch._C._jit_pass_lint(trace)
     torch._C._jit_pass_onnx(trace, aten)
@@ -367,6 +371,73 @@ def _node_getitem(self, k):
     return getattr(self, sel)(k)


+def symbolic_override(symbolic_fn):
+    """
+    Decorator to override ONNX export of a function with a specified subgraph.
+
+    Effectively allows attaching a symbolic() implementation to an arbitrary
+    Python function or autograd.Function. Requirements for the decorated
+    function:
+    - it must be a non-member function or an autograd.Function
+    - positional inputs are Variables/Tensors or (nested) lists or tuples of
+      them (similar requirement to NestedIOFunction)
+    - outputs are similarly Variables/Tensors or (nested) lists or tuples of
+      them
+    - keyword arguments are of non-tensor type
+
+    Example usage:
+
+    ```
+    def symb(g, x, y):
+        return g.op('Sum', x, y[0], y[1])
+
+    @symbolic_override(symb)
+    def foo(x, y):
+        return x + y[0] + y[1]
+    ```
+    """
+
+    def wrapper_maker(fn):
+
+        def wrapper(*args, **kwargs):
+            output = fn(*args, **kwargs)
+            flat_args = tuple(function._iter_variables(args))
+            if not any(map(torch._C._jit_is_tracing, flat_args)):
+                return output
+            flat_output_tensors = tuple(
+                v.data for v in function._iter_variables(output))
+            assert len(list(function._iter_variables_permissive(
+                list(kwargs.values())))) == 0, \
+                "Passing Variable through kwargs is not supported"
+
+            class ExportProxy(Function):
+                @staticmethod
+                def symbolic(g, *flat_args):
+                    symbolic_args = function._unflatten(flat_args, args)
+                    symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
+                    return tuple(function._iter_jit_values(symbolic_output))
+
+                @staticmethod
+                def forward(ctx, *unused_args):
+                    return flat_output_tensors
+
+                @staticmethod
+                def backward(ctx, *unused_args, **unused_kwargs):
+                    raise RuntimeError(
+                        "symbolic_override is meant for inference export only")
+
+            flat_proxy_output = ExportProxy.apply(*flat_args)
+            return function._unflatten(flat_proxy_output, output)
+
+        # fn might be autograd.Function too, in this case wrapping doesn't work
+        if isinstance(fn, types.FunctionType):
+            wrapper = functools.wraps(fn)(wrapper)
+
+        return wrapper
+
+    return wrapper_maker


 torch._C.Graph.op = _graph_op
 torch._C.Graph.at = _graph_at
 torch._C.Graph.constant = _graph_constant
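Note: combined with the _optimize_trace change above (dce removes the dead eager subgraph the ExportProxy leaves behind), the decorator can be exercised end to end roughly as follows. This is a hedged sketch: Model and the commented export call are assumptions for illustration, and only symbolic_override itself comes from this commit.

```python
import torch
import torch.onnx
from torch.autograd import Variable


def symb(g, x, y):
    # One ONNX Sum node stands in for the whole traced body of foo()
    return g.op('Sum', x, y[0], y[1])


@torch.onnx.symbolic_override(symb)
def foo(x, y):
    return x + y[0] + y[1]


class Model(torch.nn.Module):
    def forward(self, x, y0, y1):
        return foo(x, [y0, y1])


# Under the tracer, ExportProxy records foo() as a single node, so the
# exported graph contains one Sum instead of two Adds (2017-era API):
# inputs = tuple(Variable(torch.rand(3)) for _ in range(3))
# torch.onnx.export(Model(), inputs, 'model.onnx', verbose=True)
```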
