From 76a3869fc66814e855818991e41729f2e32952b2 Mon Sep 17 00:00:00 2001
From: "Tugsbayasgalan (Tugsuu) Manlaibaatar"
Date: Wed, 21 Dec 2022 13:50:36 -0800
Subject: [PATCH] Support functionalization on torch.cond (#89966)

This PR adds a functionalization path for torch.cond. As this is a first
pass, we only functionalize very restrictive use cases. We explicitly
disallow the following:
- Output of either branch aliasing an input
- In-place mutation on inputs given to either branch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89966
Approved by: https://github.com/zou3519
---
 functorch/experimental/_cond.py        | 103 +++++++++++++++
 functorch/experimental/control_flow.py |   2 +-
 test/functorch/test_control_flow.py    | 173 +++++++++++++++++++++++++
 torch/_C/_functorch.pyi                |   6 +
 torch/_functorch/pyfunctorch.py        |  17 +++
 torch/csrc/functorch/init.cpp          |   7 +
 6 files changed, 307 insertions(+), 1 deletion(-)

diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py
index a3c1936560439..d435a86ab555f 100644
--- a/functorch/experimental/_cond.py
+++ b/functorch/experimental/_cond.py
@@ -1,14 +1,18 @@
+from dataclasses import dataclass
 import torch
+from torch.multiprocessing.reductions import StorageWeakRef
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
+from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
 from torch._ops import PyOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     get_isolated_graphmodule,
     get_proxy_slot,
     ProxyTorchDispatchMode,
+    make_fx,
     track_tensor_tree,
 )
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
@@ -19,6 +23,11 @@
 from torch.utils._pytree import tree_flatten
 
 
+@dataclass
+class UnsupportedAliasMutationException(RuntimeError):
+    reason: str
+
+
 """
 We're going to define a `cond` operation.
 In order to do this, we need implementations for each of the dispatch keys.
@@ -149,6 +158,100 @@ def cond_python_dispatcher(*args):
     return cond(*args)
 
 
+def _has_potential_branch_input_mutation(branch, fake_inputs):
+    """
+    Dispatch-trace the branch with fake inputs and check whether the
+    resulting graph has a mutable op on the input. This is a
+    bit restrictive, as the branch must be traceable.
+    """
+    try:
+        gm = make_fx(branch)(*fake_inputs)
+    except UnsupportedAliasMutationException:
+        # this can happen when a nested cond is
+        # functionalized
+        return True
+    except Exception as e:
+        raise e
+
+    input_nodes = set()
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            input_nodes.add(node)
+        if node.op == "call_function":
+            target = node.target
+            if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
+                for arg in node.args:
+                    if arg in input_nodes:
+                        return True
+
+    return False
+
+def _has_potential_branch_input_alias(branch, fake_inputs):
+    """
+    Dispatch-trace the branch with fake inputs and check whether the
+    resulting graph has an output aliasing the branch input. This is a
+    bit restrictive, as the branch must be traceable.
+ """ + try: + gm = make_fx(branch)(*fake_inputs) + except UnsupportedAliasMutationException: + # this can happen when nested cond is + # functionalized + return True + except Exception as e: + raise e + + input_storages = set() + for node in gm.graph.nodes: + if node.op == "placeholder": + input_storages.add(StorageWeakRef(node.meta['val']._typed_storage())) + + outs, _ = pytree.tree_flatten(gm(*fake_inputs)) + for out in outs: + if isinstance(out, torch.Tensor) and StorageWeakRef(out._typed_storage()) in input_storages: + return True + + return False + + + +@cond.py_impl(torch._C._functorch.TransformType.Functionalize) +def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs): + """ + Functionalization implementation for torch.cond. Currently: + 1. We don't allow any input mutation inside the branches + 2. Our check for above condition is not exhaustive + """ + reapply_views = interpreter.functionalize_add_back_views() + mode = 'mutations_and_views' if reapply_views else 'mutations' + # At this point, we will see functionalized tensors, so need to unwrap them first + unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views) + unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views) + + functional_true_fn = functionalize(true_fn, remove=mode) + functional_false_fn = functionalize(false_fn, remove=mode) + + with interpreter.lower(): + fake_tensor_mode = FakeTensorMode() + with fake_tensor_mode as ft_mode: + for branch in [functional_true_fn, functional_false_fn]: + def convert(x): + return ft_mode.fake_tensor_converter(ft_mode, x) + fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs) + if _has_potential_branch_input_mutation(branch, fake_inputs): + raise UnsupportedAliasMutationException("One of torch.cond branch " + "might be modifying the input!") + for branch in [true_fn, false_fn]: + def convert(x): + return ft_mode.fake_tensor_converter(ft_mode, x) + fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs) + if _has_potential_branch_input_alias(branch, fake_inputs): + raise UnsupportedAliasMutationException("One of torch.cond branch " + "might be aliasing the input!") + + cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs) + return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level()) + # TODO(voz): Make this automatic for keys, this is very ugly atm cond.fallthrough(DispatchKey.PythonTLSSnapshot) cond.fallthrough(DispatchKey.ADInplaceOrView) diff --git a/functorch/experimental/control_flow.py b/functorch/experimental/control_flow.py index fb235b10cc460..5d42598c757aa 100644 --- a/functorch/experimental/control_flow.py +++ b/functorch/experimental/control_flow.py @@ -1,2 +1,2 @@ from ._map import map # noqa: F401 -from ._cond import cond # noqa: F401 +from ._cond import cond, UnsupportedAliasMutationException # noqa: F401 diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index afa10a3de5ee5..2c984fdc88622 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -2,6 +2,8 @@ import torch from functorch.experimental import control_flow from functorch.experimental.control_flow import cond +from functorch.experimental.control_flow import UnsupportedAliasMutationException +from functorch.experimental import functionalize from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_utils import run_tests, 
@@ -73,6 +75,177 @@ def f(x, pred, pred2):
         self.assertEqual(result_false_true, torch.cos(x))
 
 
+    def test_cond_functionalized(self):
+        def true_fn(x):
+            y = x.sin()
+            y.add_(4)
+            return x.sin().max() + y.sum()
+
+        def false_fn(x):
+            return x.cos().min()
+
+        def f(x):
+            pred = x.shape[0] == 1
+            return cond(pred, true_fn, false_fn, [x])
+
+        example_inputs = (torch.ones(4, 5),)
+        functional_f = functionalize(f)
+        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
+
+        graph_module = make_fx(functionalize(f))(*example_inputs)
+        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
+
+        all_ops_in_true_branch = []
+        for node in graph_module.true_graph_0.graph.nodes:
+            if node.op == "call_function":
+                all_ops_in_true_branch.append(node.target)
+
+        self.assertFalse(any([op._schema.is_mutable for op in all_ops_in_true_branch]))
+
+    def test_cond_functionalized_nested(self):
+        def true_true_fn(x):
+            y = x.cos()
+            y.add_(4)
+            return x.sin().max() + y.sin().max()
+
+        def true_false_fn(x):
+            return x.cos().min()
+
+        def true_fn(x):
+            pred = x.shape[0] == 1
+            return cond(pred, true_true_fn, true_false_fn, [x])
+
+        def false_fn(x):
+            return x.sum()
+
+        def f(x):
+            pred = x.shape[0] == 1
+            return cond(pred, true_fn, false_fn, [x])
+
+        example_inputs = (torch.ones(4, 5),)
+        functional_f = functionalize(f)
+        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
+
+        graph_module = make_fx(functionalize(f))(*example_inputs)
+        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
+
+        gm_true_true_branch = graph_module.true_graph_0.true_graph_0
+
+        all_ops = []
+        for node in gm_true_true_branch.graph.nodes:
+            if node.op == "call_function":
+                all_ops.append(node.target)
+
+        self.assertFalse(any([op._schema.is_mutable for op in all_ops]))
+
+    def test_cond_functionalized_data_dependent_pred(self):
+        def true_fn(x):
+            return x.sin().sum()
+
+        def false_fn(x):
+            return x.cos().sum()
+
+        def f(x):
+            pred = x.nonzero().shape[0] == 1
+            return cond(pred, true_fn, false_fn, [x])
+
+        example_inputs = (torch.ones(4, 5),)
+        functional_f = functionalize(f)
+        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
+
+        graph_module = make_fx(functionalize(f))(*example_inputs)
+        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
+
+    def test_cond_functionalized_input_mutation_on_true_branch(self):
+        def true_fn(x):
+            view_x = x.view(x.shape)
+            view_x.add_(1)
+            return view_x.sin().sum()
+
+        def false_fn(x):
+            return x.cos().sum()
+
+        def f(x):
+            pred = x.shape[0] == 4
+            return cond(pred, true_fn, false_fn, [x])
+
+        example_inputs = (torch.ones(4, 5),)
+        functional_f = functionalize(f)
+        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
+            functional_f(*example_inputs)
+
+        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
+            make_fx(functionalize(f))(*example_inputs)
+
+    def test_cond_functionalized_input_mutation_on_false_branch(self):
+        def true_fn(x):
+            return x.sin().sum()
+
+        def false_fn(x):
+            view_x = x.view(x.shape)
+            view_x.add_(1)
+            return view_x.cos().sum()
+
+        def f(x):
+            pred = x.shape[0] == 4
+            return cond(pred, true_fn, false_fn, [x])
+
+        example_inputs = (torch.ones(5, 5),)
+        functional_f = functionalize(f)
+        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
+            functional_f(*example_inputs)
+
+        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
+            make_fx(functionalize(f))(*example_inputs)
+
+    def test_cond_functionalized_output_alias_input(self):
+        def true_fn(x):
+            return x
+
+        def false_fn(x):
+            view_x = x.view(x.shape)
+            return view_x
+
+        def f(x):
+            pred = x.shape[0] == 4
+            return cond(pred, true_fn, false_fn, [x])
+
+        example_inputs = (torch.ones(5, 5),)
+        functional_f = functionalize(f)
+
+        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
+            functional_f(*example_inputs)
+
+        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
+            make_fx(functionalize(f))(*example_inputs)
+
+    def test_cond_functionalized_nested_input_mutation(self):
+        def true_true_fn(x):
+            x.add_(4)
+            return x.sin().max()
+
+        def true_false_fn(x):
+            return x.cos().min()
+
+        def true_fn(x):
+            pred = x.shape[0] == 1
+            return cond(pred, true_true_fn, true_false_fn, [x])
+
+        def false_fn(x):
+            return x.sum()
+
+        def f(x):
+            pred = x.shape[0] == 1
+            return cond(pred, true_fn, false_fn, [x])
+
+        example_inputs = (torch.ones(4, 5),)
+        functional_f = functionalize(f)
+        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
+            functional_f(*example_inputs)
+
+        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
+            make_fx(functionalize(f))(*example_inputs)
+
     def test_cond_nested_traced_other_inputs(self):
         def true_nested(y):
             return y * y
diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi
index d0916e9f035b0..9ad8d73870a74 100644
--- a/torch/_C/_functorch.pyi
+++ b/torch/_C/_functorch.pyi
@@ -48,6 +48,12 @@ class CJvpInterpreterPtr:
     def lift(self, Tensor) -> Tensor: ...
     def prevFwdGradMode(self) -> bool: ...
 
+class CFunctionalizeInterpreterPtr:
+    def __init__(self, interpreter: CInterpreter): ...
+    def key(self) -> TransformType: ...
+    def level(self) -> int: ...
+    def functionalizeAddBackViews(self) -> bool: ...
+
 class CVmapInterpreterPtr:
     def __init__(self, interpreter: CInterpreter): ...
     def key(self) -> TransformType: ...
diff --git a/torch/_functorch/pyfunctorch.py b/torch/_functorch/pyfunctorch.py
index d742cd1b0392b..c005a20f1b53b 100644
--- a/torch/_functorch/pyfunctorch.py
+++ b/torch/_functorch/pyfunctorch.py
@@ -8,6 +8,7 @@
     RandomnessType,
     CInterpreter,
     CGradInterpreterPtr,
+    CFunctionalizeInterpreterPtr,
     CVmapInterpreterPtr,
     CJvpInterpreterPtr,
     pop_dynamic_layer_stack,
@@ -172,6 +173,20 @@ def prev_fwd_grad_mode(self):
         return self._cptr.prevFwdGradMode()
 
 
+class FunctionalizeInterpreter(FuncTorchInterpreter):
+    def __init__(self, cdata: CInterpreter):
+        assert cdata.key() == TransformType.Functionalize
+        self._cdata = cdata
+        self._cptr = CFunctionalizeInterpreterPtr(cdata)
+
+    def process(self, op, args, kwargs):
+        kernel = op.functorch_table[TransformType.Functionalize]
+        return kernel(self, *args, **kwargs)
+
+    def functionalize_add_back_views(self):
+        return self._cptr.functionalizeAddBackViews()
+
+
 def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
     key = cinterpreter.key()
     if key == TransformType.Grad:
@@ -180,6 +195,8 @@ def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
         return VmapInterpreter(cinterpreter)
     if key == TransformType.Jvp:
         return JvpInterpreter(cinterpreter)
+    if key == TransformType.Functionalize:
+        return FunctionalizeInterpreter(cinterpreter)
     raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")
 
 
diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp
index 2380a28ca2367..b97bad6c6de2d 100644
--- a/torch/csrc/functorch/init.cpp
+++ b/torch/csrc/functorch/init.cpp
@@ -519,6 +519,13 @@ void initFuncTorchBindings(PyObject* module) {
       .def("level", &VmapInterpreterPtr::level)
      .def("batchSize", &VmapInterpreterPtr::batchSize)
      .def("randomness", &VmapInterpreterPtr::randomness);
+  py::class_<FunctionalizeInterpreterPtr>(m, "CFunctionalizeInterpreterPtr")
+      .def(py::init<const Interpreter*>())
+      .def("key", &FunctionalizeInterpreterPtr::key)
+      .def("level", &FunctionalizeInterpreterPtr::level)
+      .def(
+          "functionalizeAddBackViews",
+          &FunctionalizeInterpreterPtr::functionalizeAddBackViews);
 }
 
 } // namespace impl
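For reference, a minimal usage sketch of the path added by this patch, mirroring the new tests in test_control_flow.py; the helper names f, g, and mutating_true_fn are illustrative and not part of the patch. Branches that neither mutate nor alias their inputs functionalize and trace cleanly, while a branch that mutates its input is one of the explicitly disallowed cases and raises UnsupportedAliasMutationException.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.experimental import functionalize
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException

def true_fn(x):
    return x.sin().sum()

def false_fn(x):
    return x.cos().sum()

def f(x):
    # pred is a plain Python bool here, as in the tests above
    return cond(x.shape[0] == 4, true_fn, false_fn, [x])

# Branches with no input mutation or aliasing functionalize and trace cleanly;
# the resulting branch subgraphs contain no mutable ops.
gm = make_fx(functionalize(f))(torch.ones(4, 5))

def mutating_true_fn(x):
    x.add_(1)  # in-place mutation of a branch input
    return x.sin().sum()

def g(x):
    return cond(x.shape[0] == 4, mutating_true_fn, false_fn, [x])

# Functionalization rejects the mutating branch instead of silently
# producing a wrong graph.
try:
    functionalize(g)(torch.ones(4, 5))
except UnsupportedAliasMutationException as exc:
    print("rejected:", exc)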