diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py index ee0a74355bb..144f5d68238 100644 --- a/nncf/torch/dynamic_graph/wrappers.py +++ b/nncf/torch/dynamic_graph/wrappers.py @@ -28,6 +28,8 @@ from nncf.torch.dynamic_graph.trace_tensor import trace_tensors from nncf.torch.layer_utils import _NNCFModuleMixin from nncf.torch.layers import ITERATION_MODULES +from nncf.torch.return_types import maybe_unwrap_from_torch_return_type +from nncf.torch.return_types import maybe_wrap_to_torch_return_type _IGNORED_SCOPES = [] @@ -188,8 +190,10 @@ def _execute_op( if is_debug() and node is not None: ctx.register_node_call(node) - result = trace_tensors(result, node, ctx) - result = ctx.execute_post_hooks(op_address, result) + unwrapped_result = maybe_unwrap_from_torch_return_type(result) + unwrapped_result = trace_tensors(unwrapped_result, node, ctx) + unwrapped_result = ctx.execute_post_hooks(op_address, unwrapped_result) + result = maybe_wrap_to_torch_return_type(unwrapped_result, result) return result diff --git a/nncf/torch/return_types.py b/nncf/torch/return_types.py new file mode 100644 index 00000000000..1b8c8e4fd75 --- /dev/null +++ b/nncf/torch/return_types.py @@ -0,0 +1,59 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Optional, Tuple, Type, Union + +import torch + + +def __get_supported_torch_return_types() -> Tuple[Type[tuple], ...]: + """ + Collects types from torch.return_type which can be wrapped/unwrapped by NNCF. + NNCF can wrap/unwrap only return types that have two attributes, one of them + should be the `values` attribute. + + :return: List of types from torch.return_type which can be wrapped/unwrapped by NNCF. + """ + retval = [t for _, t in inspect.getmembers(torch.return_types) if inspect.isclass(t) and hasattr(t, "values")] + return tuple(t for t in retval if t.n_fields == 2) + + +_TORCH_RETURN_TYPES = __get_supported_torch_return_types() + + +def maybe_unwrap_from_torch_return_type(tensor: Any) -> torch.Tensor: + """ + Attempts to unwrap the tensor value from one of torch.return_types instances + in case torch operation output is wrapped by a torch return_type. + + :param tensor: Torch tensor or torch return type instance to unwrap values from. + :return: Unwrapped torch tensor. + """ + if isinstance(tensor, _TORCH_RETURN_TYPES): + return tensor.values + return tensor + + +def maybe_wrap_to_torch_return_type(tensor: torch.Tensor, wrapped_input: Optional[Union[tuple, torch.Tensor]]) -> Any: + """ + Wraps tensor to wrapped_input wrapper in case wrapped_input is wrapped by a torch.return_value container. + + :param tensor: Torch tensor to wrap. + :param wrapped_tensor: Instance of the tensor before it was unwrapped. + :return: Wrapped tensor in case wrapped_input is wrapped by a torch.return_value container else the tensor. + """ + + if isinstance(wrapped_input, _TORCH_RETURN_TYPES): + # We assume that return_type has only two attributes, the first one is `value`. + # This assumption is checked by `test_unwrap_wrap_torch_return_type`. + return wrapped_input.__class__((tensor, wrapped_input[1])) + return tensor diff --git a/tests/torch/test_nncf_network.py b/tests/torch/test_nncf_network.py index 5b8f5b63a01..eb3f5fd32d9 100644 --- a/tests/torch/test_nncf_network.py +++ b/tests/torch/test_nncf_network.py @@ -847,3 +847,30 @@ def test_access_to_input_info(): input_info = ExampleInputInfo.from_example_input(example_input) nncf_model = NNCFNetwork(model, input_info) nncf_model.nncf.input_infos + + +class ModelWithMax(torch.nn.Module): + INPUT_SIZE = [1, 1, 32, 32] + + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.max(x, dim=-1, keepdim=True) + assert isinstance(x, torch.return_types.max) + return x.values + + +def test_torch_return_types_unwrapped_for_post_hook(): + model = ModelWithMax() + nncf_model = NNCFNetwork(model, FillerInputInfo([FillerInputElement(SimplestModel.INPUT_SIZE)])) + node_to_op_address_mapping = nncf_model.nncf.get_node_to_op_address_mapping() + insertion_point = PTInsertionPoint( + TargetType.OPERATOR_POST_HOOK, node_to_op_address_mapping["ModelWithMax/max_0"], 0 + ) + + def fn_to_check_input_type(input): + assert isinstance(input, torch.Tensor) + + nncf_model.nncf.insert_at_point(insertion_point, [fn_to_check_input_type]) + nncf_model.nncf.rebuild_graph() diff --git a/tests/torch/test_return_types.py b/tests/torch/test_return_types.py new file mode 100644 index 00000000000..c2bf812f0e0 --- /dev/null +++ b/tests/torch/test_return_types.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from nncf.torch.return_types import _TORCH_RETURN_TYPES +from nncf.torch.return_types import maybe_unwrap_from_torch_return_type +from nncf.torch.return_types import maybe_wrap_to_torch_return_type + + +@pytest.mark.parametrize("return_type", _TORCH_RETURN_TYPES) +def test_unwrap_wrap_torch_return_type(return_type): + wrapped_tensor = return_type((torch.tensor(0), torch.tensor(1))) + assert wrapped_tensor.values == torch.tensor(0) + unwrapped_tensor = maybe_unwrap_from_torch_return_type(wrapped_tensor) + assert unwrapped_tensor == torch.tensor(0) + + updated_wrapped_tensor = maybe_wrap_to_torch_return_type(unwrapped_tensor, wrapped_tensor) + assert updated_wrapped_tensor == wrapped_tensor + + +@pytest.mark.parametrize( + "input_", [torch.tensor(0), [torch.tensor(0), torch.tensor(1)], (torch.tensor(0), torch.tensor(1))] +) +def test_wrap_unwrap_do_nothing_to_tensor(input_): + wrapped_input = maybe_unwrap_from_torch_return_type(input_) + assert wrapped_input is input_ + unwrapped_input = maybe_wrap_to_torch_return_type(input_, wrapped_input) + assert unwrapped_input is input_