Skip to content

Commit

Permalink
Support functionalization on torch.cond (pytorch#89966)
Browse files Browse the repository at this point in the history
This PR adds a functionalization path for torch.cond. As this is the first pass, we only functionalize very restrictive use cases. We explicitly disallow the following:

- Output of each branch aliasing input
- In-place mutation on inputs given to each branch

Pull Request resolved: pytorch#89966
Approved by: https://github.com/zou3519
  • Loading branch information
tugsbayasgalan authored and pytorchmergebot committed Dec 22, 2022
1 parent d1123c9 commit 76a3869
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 1 deletion.
103 changes: 103 additions & 0 deletions functorch/experimental/_cond.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from dataclasses import dataclass
import torch
from torch.multiprocessing.reductions import StorageWeakRef

import torch.utils._pytree as pytree

from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
get_isolated_graphmodule,
get_proxy_slot,
ProxyTorchDispatchMode,
make_fx,
track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
Expand All @@ -19,6 +23,11 @@
from torch.utils._pytree import tree_flatten


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    """Signal that a ``torch.cond`` branch cannot be functionalized.

    ``cond_functionalize`` raises this when a branch might mutate or alias
    one of its inputs; the branch-inspection helpers catch it to detect the
    same problem inside a nested ``cond``.

    Attributes:
        reason: Human-readable description of the unsupported pattern.
    """

    reason: str

    def __post_init__(self):
        # The dataclass-generated __init__ only assigns ``reason`` and never
        # calls RuntimeError.__init__.  Without this, constructing the
        # exception with the keyword form (``reason="..."``) leaves ``args``
        # empty and ``str(exc)`` returns an empty string.
        super().__init__(self.reason)


"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
Expand Down Expand Up @@ -149,6 +158,100 @@ def cond_python_dispatcher(*args):
return cond(*args)


def _has_potential_branch_input_mutation(branch, fake_inputs):
    """
    Dispatch-trace ``branch`` with ``fake_inputs`` and report whether the
    traced graph contains a mutable op that consumes one of the graph inputs.

    This is a conservative check: any mutable op that takes an input node as
    an operand counts, and the branch must be traceable by ``make_fx``.

    Args:
        branch: Callable to trace.
        fake_inputs: Positional arguments to trace ``branch`` with.

    Returns:
        True if an input might be mutated (or if tracing hit an
        ``UnsupportedAliasMutationException``), False otherwise.

    Raises:
        Any other exception raised while tracing ``branch`` propagates
        unchanged.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond is
        # functionalized
        return True

    input_nodes = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_nodes.add(node)
        elif node.op == "call_function":
            target = node.target
            if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
                # all_input_nodes covers positional args, keyword args
                # (e.g. ``out=``) and nodes nested inside list arguments,
                # so an input cannot slip past by not being a direct
                # positional operand.
                if any(operand in input_nodes for operand in node.all_input_nodes):
                    return True

    return False

def _has_potential_branch_input_alias(branch, fake_inputs):
    """
    Dispatch-trace ``branch`` with ``fake_inputs`` and report whether any
    tensor output of the traced graph shares storage with a graph input.

    The branch must be traceable by ``make_fx``.

    Args:
        branch: Callable to trace.
        fake_inputs: Positional arguments to trace ``branch`` with.

    Returns:
        True if an output might alias an input (or if tracing hit an
        ``UnsupportedAliasMutationException``), False otherwise.

    Raises:
        Any other exception raised while tracing ``branch`` propagates
        unchanged.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond is
        # functionalized
        return True

    input_storages = set()
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue
        # Only tensor placeholders have a storage to compare against.  The
        # caller converts only torch.Tensor leaves to fake tensors, so a
        # placeholder may correspond to a non-tensor input whose ``val``
        # metadata is missing or is not a tensor; skip those instead of
        # failing on the lookup.
        val = node.meta.get('val')
        if isinstance(val, torch.Tensor):
            input_storages.add(StorageWeakRef(val._typed_storage()))

    outs, _ = pytree.tree_flatten(gm(*fake_inputs))
    for out in outs:
        if isinstance(out, torch.Tensor) and StorageWeakRef(out._typed_storage()) in input_storages:
            return True

    return False



@cond.py_impl(torch._C._functorch.TransformType.Functionalize)
def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
    """
    Functionalize rule for torch.cond.

    Both branches are wrapped with ``functionalize`` and ``cond`` is re-run
    on the unwrapped operands one interpreter level down. Before that, each
    branch is traced with fake tensors and rejected when:
      1. a functionalized branch might still mutate one of its inputs, or
      2. an original branch might return something aliasing one of its inputs.
    Neither check is exhaustive.

    Raises:
        UnsupportedAliasMutationException: if either check above fires.
    """
    reapply_views = interpreter.functionalize_add_back_views()
    if reapply_views:
        mode = 'mutations_and_views'
    else:
        mode = 'mutations'

    # The operands arrive as functionalized tensors; strip the wrappers first.
    unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
    unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)

    functional_true_fn = functionalize(true_fn, remove=mode)
    functional_false_fn = functionalize(false_fn, remove=mode)

    with interpreter.lower():
        with FakeTensorMode() as ft_mode:
            def to_fake(tensor):
                return ft_mode.fake_tensor_converter(ft_mode, tensor)

            # Mutation is checked on the functionalized branches.
            for functional_branch in (functional_true_fn, functional_false_fn):
                fake_inputs = pytree.tree_map_only(torch.Tensor, to_fake, unwrapped_inputs)
                if _has_potential_branch_input_mutation(functional_branch, fake_inputs):
                    raise UnsupportedAliasMutationException(
                        "One of torch.cond branch might be modifying the input!"
                    )

            # Aliasing is checked on the original, un-functionalized branches.
            for original_branch in (true_fn, false_fn):
                fake_inputs = pytree.tree_map_only(torch.Tensor, to_fake, unwrapped_inputs)
                if _has_potential_branch_input_alias(original_branch, fake_inputs):
                    raise UnsupportedAliasMutationException(
                        "One of torch.cond branch might be aliasing the input!"
                    )

        cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs)
        return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())

# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
Expand Down
2 changes: 1 addition & 1 deletion functorch/experimental/control_flow.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from ._map import map # noqa: F401
from ._cond import cond # noqa: F401
from ._cond import cond, UnsupportedAliasMutationException # noqa: F401
173 changes: 173 additions & 0 deletions test/functorch/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from functorch.experimental.control_flow import UnsupportedAliasMutationException
from functorch.experimental import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

from torch.testing._internal.common_utils import run_tests, TestCase
Expand Down Expand Up @@ -73,6 +75,177 @@ def f(x, pred, pred2):

self.assertEqual(result_false_true, torch.cos(x))

def test_cond_functionalized(self):
    """Functionalized ``cond`` matches eager and traces a mutation-free true branch."""
    def true_fn(x):
        # In-place update of an intermediate (not of the input itself).
        shifted = x.sin()
        shifted.add_(4)
        return x.sin().max() + shifted.sum()

    def false_fn(x):
        return x.cos().min()

    def f(x):
        return cond(x.shape[0] == 1, true_fn, false_fn, [x])

    args = (torch.ones(4, 5),)

    eager_functional = functionalize(f)
    self.assertEqual(eager_functional(*args), f(*args))

    traced = make_fx(functionalize(f))(*args)
    self.assertEqual(traced(*args), f(*args))

    true_branch_targets = [
        node.target
        for node in traced.true_graph_0.graph.nodes
        if node.op == "call_function"
    ]

    # No op left in the traced true branch may have a mutable schema.
    mutable_flags = [target._schema.is_mutable for target in true_branch_targets]
    self.assertFalse(any(mutable_flags))

def test_cond_functionalized_nested(self):
    """A ``cond`` nested in a branch is functionalized too; its traced graph has no mutable ops."""
    def true_true_fn(x):
        # In-place update of an intermediate (not of the input itself).
        shifted = x.cos()
        shifted.add_(4)
        return x.sin().max() + shifted.sin().max()

    def true_false_fn(x):
        return x.cos().min()

    def true_fn(x):
        return cond(x.shape[0] == 1, true_true_fn, true_false_fn, [x])

    def false_fn(x):
        return x.sum()

    def f(x):
        return cond(x.shape[0] == 1, true_fn, false_fn, [x])

    args = (torch.ones(4, 5),)

    eager_functional = functionalize(f)
    self.assertEqual(eager_functional(*args), f(*args))

    traced = make_fx(functionalize(f))(*args)
    self.assertEqual(traced(*args), f(*args))

    inner_true_graph = traced.true_graph_0.true_graph_0
    inner_targets = [
        node.target
        for node in inner_true_graph.graph.nodes
        if node.op == "call_function"
    ]

    mutable_flags = [target._schema.is_mutable for target in inner_targets]
    self.assertFalse(any(mutable_flags))

def test_cond_functionalized_data_dependent_pred(self):
    """Functionalized ``cond`` matches eager when the predicate is derived from ``x.nonzero()``."""
    def true_fn(x):
        return x.sin().sum()

    def false_fn(x):
        return x.cos().sum()

    def f(x):
        # The predicate depends on the tensor's values, not only on its shape.
        nonzero_count = x.nonzero().shape[0]
        return cond(nonzero_count == 1, true_fn, false_fn, [x])

    args = (torch.ones(4, 5),)

    eager_functional = functionalize(f)
    self.assertEqual(eager_functional(*args), f(*args))

    traced = make_fx(functionalize(f))(*args)
    self.assertEqual(traced(*args), f(*args))

def test_cond_functionalized_input_mutation_on_true_branch(self):
    """Mutating the input through a view in the true branch is rejected."""
    def true_fn(x):
        alias = x.view(x.shape)
        alias.add_(1)
        return alias.sin().sum()

    def false_fn(x):
        return x.cos().sum()

    def f(x):
        return cond(x.shape[0] == 4, true_fn, false_fn, [x])

    args = (torch.ones(4, 5),)
    eager_functional = functionalize(f)

    # Both the eager functionalized call and the traced one must refuse.
    with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
        eager_functional(*args)

    with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
        make_fx(functionalize(f))(*args)

def test_cond_functionalized_input_mutation_on_false_branch(self):
    """Mutating the input through a view in the false branch is rejected."""
    def true_fn(x):
        return x.sin().sum()

    def false_fn(x):
        alias = x.view(x.shape)
        alias.add_(1)
        return alias.cos().sum()

    def f(x):
        return cond(x.shape[0] == 4, true_fn, false_fn, [x])

    args = (torch.ones(5, 5),)
    eager_functional = functionalize(f)

    # Both the eager functionalized call and the traced one must refuse.
    with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
        eager_functional(*args)

    with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
        make_fx(functionalize(f))(*args)

def test_cond_functionalized_output_alias_input(self):
    """Branches that return the input itself, or a view of it, are rejected as aliasing."""
    def true_fn(x):
        return x

    def false_fn(x):
        alias = x.view(x.shape)
        return alias

    def f(x):
        return cond(x.shape[0] == 4, true_fn, false_fn, [x])

    args = (torch.ones(5, 5),)
    eager_functional = functionalize(f)

    # The aliasing-specific message is expected, not the mutation one.
    with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
        eager_functional(*args)

    with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
        make_fx(functionalize(f))(*args)

def test_cond_functionalized_nested_input_mutation(self):
    """Input mutation inside a nested ``cond`` branch is rejected at the outer ``cond``."""
    def true_true_fn(x):
        # Direct in-place update of the branch input.
        x.add_(4)
        return x.sin().max()

    def true_false_fn(x):
        return x.cos().min()

    def true_fn(x):
        return cond(x.shape[0] == 1, true_true_fn, true_false_fn, [x])

    def false_fn(x):
        return x.sum()

    def f(x):
        return cond(x.shape[0] == 1, true_fn, false_fn, [x])

    args = (torch.ones(4, 5),)
    eager_functional = functionalize(f)

    # Both the eager functionalized call and the traced one must refuse.
    with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
        eager_functional(*args)

    with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
        make_fx(functionalize(f))(*args)

def test_cond_nested_traced_other_inputs(self):
def true_nested(y):
return y * y
Expand Down
6 changes: 6 additions & 0 deletions torch/_C/_functorch.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ class CJvpInterpreterPtr:
def lift(self, Tensor) -> Tensor: ...
def prevFwdGradMode(self) -> bool: ...

# Type stub for the pybind11 class registered as "CFunctionalizeInterpreterPtr"
# in torch/csrc/functorch/init.cpp. It is constructed from a CInterpreter and
# exposes the key(), level() and functionalizeAddBackViews() bindings.
class CFunctionalizeInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def key(self) -> TransformType: ...
    def level(self) -> int: ...
    # Read by FunctionalizeInterpreter.functionalize_add_back_views(); the
    # cond functionalize rule uses the result to pick between removing
    # 'mutations_and_views' (True) and only 'mutations' (False).
    def functionalizeAddBackViews(self) -> bool: ...

class CVmapInterpreterPtr:
def __init__(self, interpreter: CInterpreter): ...
def key(self) -> TransformType: ...
Expand Down
17 changes: 17 additions & 0 deletions torch/_functorch/pyfunctorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
RandomnessType,
CInterpreter,
CGradInterpreterPtr,
CFunctionalizeInterpreterPtr,
CVmapInterpreterPtr,
CJvpInterpreterPtr,
pop_dynamic_layer_stack,
Expand Down Expand Up @@ -172,6 +173,20 @@ def prev_fwd_grad_mode(self):
return self._cptr.prevFwdGradMode()


class FunctionalizeInterpreter(FuncTorchInterpreter):
    """Python-side wrapper around a functorch interpreter whose key is ``Functionalize``."""

    def __init__(self, cdata: CInterpreter):
        # Only a Functionalize interpreter may be wrapped by this class.
        assert cdata.key() == TransformType.Functionalize
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        """Call the Functionalize rule registered in ``op.functorch_table``, passing this interpreter first."""
        return op.functorch_table[TransformType.Functionalize](self, *args, **kwargs)

    def functionalize_add_back_views(self):
        """Return the ``functionalizeAddBackViews`` flag of the wrapped interpreter."""
        return self._cptr.functionalizeAddBackViews()


def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
key = cinterpreter.key()
if key == TransformType.Grad:
Expand All @@ -180,6 +195,8 @@ def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
return VmapInterpreter(cinterpreter)
if key == TransformType.Jvp:
return JvpInterpreter(cinterpreter)
if key == TransformType.Functionalize:
return FunctionalizeInterpreter(cinterpreter)
raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")


Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/functorch/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,13 @@ void initFuncTorchBindings(PyObject* module) {
.def("level", &VmapInterpreterPtr::level)
.def("batchSize", &VmapInterpreterPtr::batchSize)
.def("randomness", &VmapInterpreterPtr::randomness);
py::class_<FunctionalizeInterpreterPtr>(m, "CFunctionalizeInterpreterPtr")
.def(py::init<const Interpreter*>())
.def("key", &FunctionalizeInterpreterPtr::key)
.def("level", &FunctionalizeInterpreterPtr::level)
.def(
"functionalizeAddBackViews",
&FunctionalizeInterpreterPtr::functionalizeAddBackViews);
}

} // namespace impl
Expand Down

0 comments on commit 76a3869

Please sign in to comment.