Skip to content

Commit

Permalink
[TORCH] Unwrap and wrap torch.return_type before and after posthooks (o…
Browse files Browse the repository at this point in the history
…penvinotoolkit#2290)

### Changes

* Post hooks inputs are being unwrapped from `torch.return_type` types
to `torch.tensor` type.
* Pre hooks outputs are being wrapped from `torch.tensor` to
`torch.return_type` in case torch input were wrapped in the first place.

### Reason for changes

To enable post hook insertion after torch operations which return
`torch.tensor_type` values instead of `torch.tensor` values.
 
### Related tickets



### Tests

* tests/torch/test_nncf_network.py is updated
* tests/torch/test_return_types.py is introduced
  • Loading branch information
daniil-lyakhov committed Nov 30, 2023
1 parent 4c4aeac commit db786a8
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 2 deletions.
8 changes: 6 additions & 2 deletions nncf/torch/dynamic_graph/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from nncf.torch.dynamic_graph.trace_tensor import trace_tensors
from nncf.torch.layer_utils import _NNCFModuleMixin
from nncf.torch.layers import ITERATION_MODULES
from nncf.torch.return_types import maybe_unwrap_from_torch_return_type
from nncf.torch.return_types import maybe_wrap_to_torch_return_type

_IGNORED_SCOPES = []

Expand Down Expand Up @@ -188,8 +190,10 @@ def _execute_op(
if is_debug() and node is not None:
ctx.register_node_call(node)

result = trace_tensors(result, node, ctx)
result = ctx.execute_post_hooks(op_address, result)
unwrapped_result = maybe_unwrap_from_torch_return_type(result)
unwrapped_result = trace_tensors(unwrapped_result, node, ctx)
unwrapped_result = ctx.execute_post_hooks(op_address, unwrapped_result)
result = maybe_wrap_to_torch_return_type(unwrapped_result, result)
return result


Expand Down
59 changes: 59 additions & 0 deletions nncf/torch/return_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Optional, Tuple, Type, Union

import torch


def __get_supported_torch_return_types() -> Tuple[Type[tuple], ...]:
"""
Collects types from torch.return_type which can be wrapped/unwrapped by NNCF.
NNCF can wrap/unwrap only return types that have two attributes, one of them
should be the `values` attribute.
:return: List of types from torch.return_type which can be wrapped/unwrapped by NNCF.
"""
retval = [t for _, t in inspect.getmembers(torch.return_types) if inspect.isclass(t) and hasattr(t, "values")]
return tuple(t for t in retval if t.n_fields == 2)


_TORCH_RETURN_TYPES = __get_supported_torch_return_types()


def maybe_unwrap_from_torch_return_type(tensor: Any) -> torch.Tensor:
"""
Attempts to unwrap the tensor value from one of torch.return_types instances
in case torch operation output is wrapped by a torch return_type.
:param tensor: Torch tensor or torch return type instance to unwrap values from.
:return: Unwrapped torch tensor.
"""
if isinstance(tensor, _TORCH_RETURN_TYPES):
return tensor.values
return tensor


def maybe_wrap_to_torch_return_type(tensor: torch.Tensor, wrapped_input: Optional[Union[tuple, torch.Tensor]]) -> Any:
"""
Wraps tensor to wrapped_input wrapper in case wrapped_input is wrapped by a torch.return_value container.
:param tensor: Torch tensor to wrap.
:param wrapped_tensor: Instance of the tensor before it was unwrapped.
:return: Wrapped tensor in case wrapped_input is wrapped by a torch.return_value container else the tensor.
"""

if isinstance(wrapped_input, _TORCH_RETURN_TYPES):
# We assume that return_type has only two attributes, the first one is `value`.
# This assumption is checked by `test_unwrap_wrap_torch_return_type`.
return wrapped_input.__class__((tensor, wrapped_input[1]))
return tensor
27 changes: 27 additions & 0 deletions tests/torch/test_nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,3 +847,30 @@ def test_access_to_input_info():
input_info = ExampleInputInfo.from_example_input(example_input)
nncf_model = NNCFNetwork(model, input_info)
nncf_model.nncf.input_infos


class ModelWithMax(torch.nn.Module):
INPUT_SIZE = [1, 1, 32, 32]

def __init__(self):
super().__init__()

def forward(self, x):
x = torch.max(x, dim=-1, keepdim=True)
assert isinstance(x, torch.return_types.max)
return x.values


def test_torch_return_types_unwrapped_for_post_hook():
model = ModelWithMax()
nncf_model = NNCFNetwork(model, FillerInputInfo([FillerInputElement(SimplestModel.INPUT_SIZE)]))
node_to_op_address_mapping = nncf_model.nncf.get_node_to_op_address_mapping()
insertion_point = PTInsertionPoint(
TargetType.OPERATOR_POST_HOOK, node_to_op_address_mapping["ModelWithMax/max_0"], 0
)

def fn_to_check_input_type(input):
assert isinstance(input, torch.Tensor)

nncf_model.nncf.insert_at_point(insertion_point, [fn_to_check_input_type])
nncf_model.nncf.rebuild_graph()
37 changes: 37 additions & 0 deletions tests/torch/test_return_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

from nncf.torch.return_types import _TORCH_RETURN_TYPES
from nncf.torch.return_types import maybe_unwrap_from_torch_return_type
from nncf.torch.return_types import maybe_wrap_to_torch_return_type


@pytest.mark.parametrize("return_type", _TORCH_RETURN_TYPES)
def test_unwrap_wrap_torch_return_type(return_type):
wrapped_tensor = return_type((torch.tensor(0), torch.tensor(1)))
assert wrapped_tensor.values == torch.tensor(0)
unwrapped_tensor = maybe_unwrap_from_torch_return_type(wrapped_tensor)
assert unwrapped_tensor == torch.tensor(0)

updated_wrapped_tensor = maybe_wrap_to_torch_return_type(unwrapped_tensor, wrapped_tensor)
assert updated_wrapped_tensor == wrapped_tensor


@pytest.mark.parametrize(
"input_", [torch.tensor(0), [torch.tensor(0), torch.tensor(1)], (torch.tensor(0), torch.tensor(1))]
)
def test_wrap_unwrap_do_nothing_to_tensor(input_):
wrapped_input = maybe_unwrap_from_torch_return_type(input_)
assert wrapped_input is input_
unwrapped_input = maybe_wrap_to_torch_return_type(input_, wrapped_input)
assert unwrapped_input is input_

0 comments on commit db786a8

Please sign in to comment.