Skip to content

Commit

Permalink
[PTSQ 2/4][Torch] Nested and successive hooks (openvinotoolkit#2310)
Browse files Browse the repository at this point in the history
Part 2/4 of SQ support openvinotoolkit#2279

### Changes

* Nested and successive hooks are supported

### Reason for changes

* To make it possible to insert nested and successive hooks into Torch
models

### Related tickets

124563

### Tests

* tests/torch/test_model_transformer.py checks that sequential and nested
hooks are inserted correctly, and that temporary hooks are applied
correctly
  • Loading branch information
daniil-lyakhov committed Dec 15, 2023
1 parent 8cd810c commit 0c389c3
Show file tree
Hide file tree
Showing 14 changed files with 409 additions and 54 deletions.
15 changes: 6 additions & 9 deletions nncf/torch/dynamic_graph/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

import threading
import weakref
from collections import defaultdict
from collections import deque
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional
from typing import Callable, DefaultDict, List, Optional

import torch

Expand Down Expand Up @@ -91,8 +92,8 @@ class TracingContext:
def __init__(self):
self.graph = DynamicGraph()

self._post_hooks = {}
self._pre_hooks: Dict[PreHookId, List[Callable]] = {}
self._post_hooks: DefaultDict[OperationAddress, List[Callable]] = defaultdict(list)
self._pre_hooks: DefaultDict[PreHookId, List[Callable]] = defaultdict(list)
self._num_nested_hooks = 0

self._threading = CopySafeThreadingVars()
Expand Down Expand Up @@ -261,9 +262,7 @@ def pop_scope(self):

def register_pre_hooks(self, fn_list: List[Callable], op_address: OperationAddress, input_port_id: int):
pre_hook_id = PreHookId(op_address, input_port_id)
if pre_hook_id in self._pre_hooks:
raise KeyError("Pre hook for context {} is already registered".format(str(pre_hook_id)))
self._pre_hooks[pre_hook_id] = fn_list
self._pre_hooks[pre_hook_id].extend(fn_list)

def execute_pre_hooks(self, op_address: OperationAddress, op_inputs: OperatorInput) -> OperatorInput:
in_op = getattr(self, "in_operator", False)
Expand All @@ -282,9 +281,7 @@ def execute_pre_hooks(self, op_address: OperationAddress, op_inputs: OperatorInp
return op_inputs

def register_post_hooks(self, fn_list: List[Callable], op_address: OperationAddress):
if op_address in self._post_hooks:
raise KeyError("Post hook for context {} is already registered".format(str(op_address)))
self._post_hooks[op_address] = fn_list
self._post_hooks[op_address].extend(fn_list)

def execute_post_hooks(self, op_address: OperationAddress, outputs):
in_op = getattr(self, "in_operator", False)
Expand Down
12 changes: 10 additions & 2 deletions nncf/torch/graph/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,19 @@ class GraphBuilder:
def __init__(self, custom_forward_fn: Callable[[torch.nn.Module], Any]):
self.custom_forward_fn = custom_forward_fn

def build_dynamic_graph(
    self,
    model: torch.nn.Module,
    context_to_use: Optional[TracingContext] = None,
    as_eval: bool = False,
) -> DynamicGraph:
    """
    Trace the model with this builder's custom forward function and return the traced graph.

    :param model: Model to trace.
    :param context_to_use: Optional tracing context, forwarded to the tracer unchanged.
    :param as_eval: Forwarded to the tracer unchanged; presumably requests tracing in
        eval mode (confirm against GraphTracer.trace_graph).
    :return: The result of `GraphTracer.trace_graph` for the given arguments.
    """
    return GraphTracer(self.custom_forward_fn).trace_graph(model, context_to_use, as_eval)

def build_graph(
self, model: torch.nn.Module, context_to_use: Optional[TracingContext] = None, as_eval: bool = False
) -> PTNNCFGraph:
tracer = GraphTracer(self.custom_forward_fn)
dynamic_graph = tracer.trace_graph(model, context_to_use, as_eval)
dynamic_graph = self.build_dynamic_graph(model, context_to_use, as_eval)
return GraphConverter.convert(dynamic_graph)


Expand Down
67 changes: 46 additions & 21 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import OrderedDict
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from enum import IntEnum
from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
Expand All @@ -28,6 +29,7 @@
from nncf.common.graph import NNCFNodeName
from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME
from nncf.common.graph.definitions import MODEL_OUTPUT_OP_NAME
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.insertion_point_graph import InsertionPointGraph
Expand Down Expand Up @@ -122,6 +124,17 @@ class ExtraCompressionModuleType(Enum):
EXTERNAL_OP = 1


@dataclass
class PTGraphPair:
    """
    Container for two dependent graph representations:
    a DynamicGraph and the NNCFGraph built out of that DynamicGraph.
    """

    # Traced graph the NNCF graph is derived from.
    dynamic_graph: DynamicGraph
    # Converted counterpart of `dynamic_graph`; the two are expected to be kept in sync.
    nncf_graph: NNCFGraph


class NNCFNetworkInterface(torch.nn.Module):
"""
The single object that is added to the original model object as an attribute to provide a namespace for
Expand Down Expand Up @@ -266,14 +279,14 @@ def __init__(
_orig_context.add_node_comparators(scopes_without_shape_matching, ShapeIgnoringTensorMetaComparator())

if isinstance(model, NNCFNetwork):
self._original_dynamic_graph = model.nncf._original_dynamic_graph
self._original_graph = model.nncf._original_graph
self._original_graphs_pair = model.nncf._original_graphs_pair
else:
self._original_dynamic_graph = GraphTracer(_orig_graph_build_forward_fn).trace_graph(
original_dynamic_graph = GraphTracer(_orig_graph_build_forward_fn).trace_graph(
model, _orig_context, as_eval=True
)
self._original_graph = GraphConverter.convert(self._original_dynamic_graph)
self._compressed_graph: PTNNCFGraph = None
original_graph = GraphConverter.convert(original_dynamic_graph)
self._original_graphs_pair = PTGraphPair(dynamic_graph=original_dynamic_graph, nncf_graph=original_graph)
self._compressed_graphs_pair: PTGraphPair = None

self._compressed_context = TracingContext()

Expand Down Expand Up @@ -406,15 +419,15 @@ def insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]):
raise RuntimeError("Unsupported insertion type: {}".format(point.insertion_type))

def get_graph(self) -> PTNNCFGraph:
    """
    Return the NNCF graph of the compressed model, rebuilding it first when it is missing.

    The graph is rebuilt when the compressed tracing context holds no nodes, or when
    no compressed graph pair (or no NNCF graph inside it) is available yet.

    :return: The `nncf_graph` stored in the compressed graph pair.
    """
    graphs_pair = self._compressed_graphs_pair
    # The pair may still be None, so it has to be checked before `.nncf_graph`
    # is dereferenced; otherwise a non-empty context without a pair raises
    # AttributeError instead of triggering a rebuild.
    if (
        self._compressed_context.graph.get_nodes_count() == 0
        or graphs_pair is None
        or graphs_pair.nncf_graph is None
    ):
        self.rebuild_graph()
    return self._compressed_graphs_pair.nncf_graph

def get_dynamic_graph(self) -> DynamicGraph:
    """
    Return the graph object held by the compressed tracing context.

    :return: The `graph` attribute of `self._compressed_context`.
    """
    tracing_context = self._compressed_context
    return tracing_context.graph

def get_original_graph(self) -> PTNNCFGraph:
return self._original_graph
return self._original_graphs_pair.nncf_graph

def get_tracing_context(self) -> TracingContext:
return self._compressed_context
Expand Down Expand Up @@ -467,7 +480,8 @@ def get_weighted_original_graph_nodes(self, nncf_module_names: List[str] = None)
module_name = nncf_module_scope[-1].calling_module_class_name
if module_name not in nncf_module_names:
continue
nodes_in_scope = self._original_graph.get_op_nodes_in_scope(nncf_module_scope)
nncf_graph: PTNNCFGraph = self._original_graphs_pair.nncf_graph
nodes_in_scope = nncf_graph.get_op_nodes_in_scope(nncf_module_scope)
for node in nodes_in_scope:
if node.metatype in OPERATORS_WITH_WEIGHTS_METATYPES:
retval.add(node)
Expand All @@ -482,7 +496,11 @@ def rebuild_graph(self, *input_args):
builder = GraphBuilder(dummy_forward_fn)

with training_mode_switcher(self._model_ref, is_training=False):
self._compressed_graph = builder.build_graph(self._model_ref, self._compressed_context)
compressed_traced_graph = builder.build_dynamic_graph(self._model_ref, self._compressed_context)
compressed_graph = GraphConverter.convert(compressed_traced_graph)
self._compressed_graphs_pair = PTGraphPair(
dynamic_graph=compressed_traced_graph, nncf_graph=compressed_graph
)

def is_scope_in_nncf_module_scope(self, scope: Scope) -> bool:
norm_nncf_scopes = []
Expand Down Expand Up @@ -581,7 +599,7 @@ def do_dummy_forward(self, force_eval: bool = False):
if train_mode:
self._model_ref.train()

def get_insertion_point_graph(self) -> InsertionPointGraph:
def get_original_insertion_point_graph(self) -> InsertionPointGraph:
# Set up a pre- and post-hooks on almost every op in PyTorch
nncf_graph = self.get_original_graph()
pre_hooks: List[PreHookInsertionPoint] = []
Expand Down Expand Up @@ -613,7 +631,7 @@ def get_insertion_point_graph(self) -> InsertionPointGraph:
weighted_node_names = [weighted_node.node_name for weighted_node in weighted_nodes]

ip_graph = InsertionPointGraph(
self._original_graph,
self._original_graphs_pair.nncf_graph,
weight_modifiable_node_names=weighted_node_names,
allowed_pre_hook_insertion_points=pre_hooks,
allowed_post_hook_insertion_points=post_hooks,
Expand All @@ -625,18 +643,18 @@ def get_module_by_scope(self, scope: Scope) -> Optional[torch.nn.Module]:
return get_module_by_scope(curr_module, scope)

def get_containing_module(self, node_name: NNCFNodeName) -> torch.nn.Module:
if self._compressed_graph is not None:
if self._compressed_graphs_pair is not None:
try:
scope = self._compressed_graph.get_scope_by_node_name(node_name)
scope = self._compressed_graphs_pair.nncf_graph.get_scope_by_node_name(node_name)
except RuntimeError:
nncf_logger.debug(
f"Node {node_name} not found in compressed graph when trying to determine "
f"the containing module, trying the original graph to see if the node was "
f"present there during graph building"
)
scope = self._original_graph.get_scope_by_node_name(node_name)
scope = self._original_graphs_pair.nncf_graph.get_scope_by_node_name(node_name)
else:
scope = self._original_graph.get_scope_by_node_name(node_name)
scope = self._original_graphs_pair.nncf_graph.get_scope_by_node_name(node_name)
return self.get_module_by_scope(scope)

def get_flops_per_module(self) -> Dict[NNCFNodeName, int]:
Expand All @@ -650,7 +668,7 @@ def get_hook(name):
return functools.partial(compute_FLOPs_hook, dict_to_save=flops_count_dict, module_node_name=name)

hook_list = []
for nncf_node in self._original_graph.get_all_nodes():
for nncf_node in self._original_graphs_pair.nncf_graph.get_all_nodes():
node_module = self.get_containing_module(nncf_node.node_name)
hook_list.append(node_module.register_forward_hook(get_hook(nncf_node.node_name)))
model.nncf.do_dummy_forward(force_eval=True)
Expand Down Expand Up @@ -716,13 +734,20 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable)
return result

def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]:
# The IDs of corresponding nodes of the original dynamic graph and original NNCF graph
# must be equal for this to work.
"""
Returns map of NNCFGraph node names vs DynamicGraph operation addresses.
:return: NNCFGraph node names vs DynamicGraph operation addresses map.
"""
graph_pair = self._compressed_graphs_pair
if graph_pair is None:
graph_pair = self._original_graphs_pair

retval = {}
for node in self._original_dynamic_graph.get_all_nodes():
for node in graph_pair.dynamic_graph.get_all_nodes():
node_id = node.node_id
op_address = node.op_exec_context.op_address
nncf_node = self._original_graph.get_node_by_id(node_id)
nncf_node = graph_pair.nncf_graph.get_node_by_id(node_id)
retval[nncf_node.node_name] = op_address
return retval

Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/quantization/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def __init__(
def generate_setup(self) -> SingleConfigQuantizerSetup:
quantizable_module_nodes = self.get_quantizable_module_nodes()

insertion_point_graph = self._target_model.nncf.get_insertion_point_graph()
insertion_point_graph = self._target_model.nncf.get_original_insertion_point_graph()
if self._debug_interface:
self._debug_interface.visualize_insertion_point_graph(insertion_point_graph)
from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationSolver
Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ def device(self) -> torch.device:
return self._tensor.device

def is_empty(self) -> bool:
return self.tensor.size == 0
return self.tensor.numel() == 0
8 changes: 4 additions & 4 deletions tests/common/experimental/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ class BadStatContainer:

class TemplateTestStatisticCollector:
@abstractmethod
def get_nncf_tensor_cls(self):
def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor:
pass

@abstractmethod
Expand Down Expand Up @@ -366,10 +366,10 @@ def test_empty_tensors_register(self, inplace, any_not_empty):
collector.register_statistic_branch("A", reducer, aggregator)
input_name = "input_name"
full_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: self.get_nncf_tensor_cls()(np.array([100]))}, [(hash(reducer), [input_name])]
{input_name: self.get_nncf_tensor(np.array([100]))}, [(hash(reducer), [input_name])]
)
empty_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: self.get_nncf_tensor_cls()(np.array([]))}, [(hash(reducer), [input_name])]
{input_name: self.get_nncf_tensor(np.array([]))}, [(hash(reducer), [input_name])]
)

stats = collector.get_statistics()
Expand All @@ -385,7 +385,7 @@ def test_empty_tensors_register(self, inplace, any_not_empty):
assert aggregator._collected_samples == 2
stats = collector.get_statistics()
assert len(stats) == 1
assert stats["A"] == self.get_nncf_tensor_cls()([100])
assert stats["A"] == self.get_nncf_tensor([100])
return

assert len(aggregator._container) == 0
Expand Down
12 changes: 7 additions & 5 deletions tests/common/test_statistics_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ class BCStatsCollectors(Enum):


class TemplateTestStatisticsAggregator:
@staticmethod
@abstractmethod
def get_min_max_algo_backend_cls(self) -> Type[MinMaxAlgoBackend]:
def get_min_max_algo_backend_cls() -> Type[MinMaxAlgoBackend]:
pass

@abstractmethod
Expand All @@ -73,8 +74,9 @@ def get_statistics_aggregator(self, dataset):
def get_dataset(self, samples):
pass

@staticmethod
@abstractmethod
def get_target_point(self, target_type: TargetType) -> TargetPoint:
def get_target_point(target_type: TargetType) -> TargetPoint:
pass

@abstractmethod
Expand Down Expand Up @@ -387,7 +389,6 @@ def test_statistics_aggregator_min_max(
inplace_statistics,
is_backend_support_custom_estimators,
):
inplace_statistics = False
model = self.get_backend_model(dataset_samples)
quantizer_config = QuantizerConfig(
mode=test_parameters.quantization_mode, per_channel=test_parameters.per_channel
Expand Down Expand Up @@ -631,10 +632,11 @@ def filter_func(point):
assert ref.shape == val.shape
assert np.allclose(val, ref)

@classmethod
def create_statistics_point(
self, model, q_config, target_point, subset_size, algorithm_name, inplace_statistics, range_estimator
cls, model, q_config, target_point, subset_size, algorithm_name, inplace_statistics, range_estimator
):
algo_backend = self.get_min_max_algo_backend_cls()
algo_backend = cls.get_min_max_algo_backend_cls()
nncf_graph = NNCFGraphFactory.create(model)
tensor_collector = algo_backend.get_statistic_collector(
range_estimator,
Expand Down
6 changes: 4 additions & 2 deletions tests/onnx/test_statistics_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@


class TestStatisticsAggregator(TemplateTestStatisticsAggregator):
def get_min_max_algo_backend_cls(self) -> Type[ONNXMinMaxAlgoBackend]:
@staticmethod
def get_min_max_algo_backend_cls() -> Type[ONNXMinMaxAlgoBackend]:
return ONNXMinMaxAlgoBackend

def get_bias_correction_algo_backend_cls(self) -> Type[ONNXBiasCorrectionAlgoBackend]:
Expand Down Expand Up @@ -65,7 +66,8 @@ def transform_fn(data_item):

return Dataset(samples, transform_fn)

def get_target_point(self, target_type: TargetType):
@staticmethod
def get_target_point(target_type: TargetType):
target_node_name = IDENTITY_NODE_NAME
port_id = 0
if target_type == TargetType.OPERATION_WITH_WEIGHTS:
Expand Down
6 changes: 4 additions & 2 deletions tests/openvino/native/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

from typing import Type

import numpy as np
import pytest

from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistics import MeanTensorStatistic
from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic
from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic
Expand All @@ -26,8 +28,8 @@


class TestOVStatisticCollector(TemplateTestStatisticCollector):
def get_nncf_tensor_cls(self):
return OVNNCFTensor
def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor:
return OVNNCFTensor(value)

@pytest.fixture
def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]:
Expand Down
6 changes: 4 additions & 2 deletions tests/openvino/native/test_statistics_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def get_StatisticAgregatorTestModel(input_shape, kernel):


class TestStatisticsAggregator(TemplateTestStatisticsAggregator):
def get_min_max_algo_backend_cls(self) -> Type[OVMinMaxAlgoBackend]:
@staticmethod
def get_min_max_algo_backend_cls() -> Type[OVMinMaxAlgoBackend]:
return OVMinMaxAlgoBackend

def get_bias_correction_algo_backend_cls(self) -> Type[OVBiasCorrectionAlgoBackend]:
Expand Down Expand Up @@ -82,7 +83,8 @@ def get_target_point_cls(self):
def get_dataset(self, samples):
return Dataset(samples, lambda data: {INPUT_NAME: data})

def get_target_point(self, target_type: TargetType) -> TargetPoint:
@staticmethod
def get_target_point(target_type: TargetType) -> TargetPoint:
target_node_name = INPUT_NAME
port_id = 0
if target_type == TargetType.OPERATION_WITH_WEIGHTS:
Expand Down
Loading

0 comments on commit 0c389c3

Please sign in to comment.