Skip to content

Commit

Permalink
Use a graph to find the output port id for the activation quantizatio…
Browse files Browse the repository at this point in the history
…n target point (#2802)

### Changes

Before the changes in the PR, it was possible to insert the quantize
operation after some layer only on its first output (`output port_id =
0`). The PR adds changes that allow the insertion of the quantize
operation for operations with multiple outputs (they have different
output port IDs), but only when one is used.

### Reason for changes

The selected FQ operations are not inserted unless changes are made.

![MatMul_problem](https://github.com/openvinotoolkit/nncf/assets/77268007/45137cd3-9920-41a6-83a6-28b15848835f)

### Related tickets

- 146088

### Tests

pre-commit scope
  • Loading branch information
andrey-churkin committed Jul 18, 2024
1 parent d113c2a commit 380d2bb
Show file tree
Hide file tree
Showing 14 changed files with 94 additions and 40 deletions.
36 changes: 34 additions & 2 deletions nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,20 +348,52 @@ def get_input_edges(self, node: NNCFNode) -> List[NNCFGraphEdge]:
edges.extend(self._get_edges(from_node, node))
return sorted(edges, key=lambda x: x.input_port_id)

def get_input_edge_by_port_id(self, node: NNCFNode, port_id: int) -> NNCFGraphEdge:
"""
Returns the input edge for a given node, where edge.input_port_id == port_id is True.
:param node: The node for which to retrieve the input edge.
:param port_id: The ID of the input port to filter the edges.
:return: An input edge connected to the specified input port ID of the
given node.
"""
edges = [e for e in self.get_input_edges(node) if e.input_port_id == port_id]
if len(edges) == 0:
raise nncf.ValidationError(
f"Node {node.node_name} does not contain input edge connected to {port_id} port ID."
)

if len(edges) > 1:
raise nncf.InternalError(
"Unsupported graph. More than one edge was found for a given node by the specified input port ID."
)
return edges[0]

def get_output_edges(self, node: NNCFNode) -> List[NNCFGraphEdge]:
"""
Returns edges of output tensors sorted by output port ID.
:param node: Producer node.
:return: List of output edges for the node sorted by output port ID.
:return: List of output edges for the node sorted by output port ID.
"""

output_nodes = self.get_next_nodes(node)
edges = []
for to_node in output_nodes:
edges.extend(self._get_edges(node, to_node))
return sorted(edges, key=lambda x: x.output_port_id)

def get_output_edges_by_port_id(self, node: NNCFNode, port_id: int) -> List[NNCFGraphEdge]:
"""
Returns a list of output edges for a given node, filtered by the specified
output port ID (edge.output_port_id == port_id).
:param node: The node for which to retrieve the output edges.
:param port_id: The ID of the output port to filter the edges.
:return: A list of the output edges connected to the specified output port ID
of the given node.
"""
return [e for e in self.get_output_edges(node) if e.output_port_id == port_id]

def _get_edges(self, from_node: NNCFNode, to_node: NNCFNode) -> List[NNCFGraphEdge]:
edges = []
edge = self.get_edge(from_node, to_node)
Expand Down
15 changes: 8 additions & 7 deletions nncf/onnx/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,8 @@ def is_port_quantized(node: NNCFNode, nncf_graph: NNCFGraph, port_id: int) -> bo
:param port_id: Input port id of a node.
:return: True if a port_id is quantized - have ONNXDequantizeLinearMetatype as a parent node.
"""
input_nodes = [edge.from_node for edge in nncf_graph.get_input_edges(node)]
if len(input_nodes) > port_id:
weight_node = input_nodes[port_id]
return weight_node.metatype == ONNXDequantizeLinearMetatype
return False
edge = nncf_graph.get_input_edge_by_port_id(node, port_id)
return edge.from_node.metatype == ONNXDequantizeLinearMetatype


def get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int:
Expand Down Expand Up @@ -172,9 +169,13 @@ def _get_activation_tensor_shape(
:return: None, if there is no shape info, otherwise - tensor shape.
"""
if target_point.type == TargetType.PRE_LAYER_OPERATION:
shape = nncf_graph.get_input_edges(node)[target_point.port_id].tensor_shape
edge = nncf_graph.get_input_edge_by_port_id(node, target_point.port_id)
shape = edge.tensor_shape
elif target_point.type == TargetType.POST_LAYER_OPERATION:
shape = nncf_graph.get_output_edges(node)[target_point.port_id].tensor_shape
# NOTE: Assumes that all output edges for the `node` with `output_port_id`
# equal to `target_point.port_id` should have the same `tensor_shape` value.
edges = nncf_graph.get_output_edges_by_port_id(node, target_point.port_id)
shape = edges[0].tensor_shape
else:
raise NotImplementedError(f"Unsupported target point type {target_point.type}.")
if not shape: # ONNX model can not have a shape of a edge, even after shape inference.
Expand Down
4 changes: 2 additions & 2 deletions nncf/openvino/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ def get_node_with_bias_value(add_node: NNCFNode, nncf_graph: NNCFGraph) -> Optio
const_port_ids = add_node.layer_attributes.get_const_port_ids()
assert len(const_port_ids) == 1
bias_port_id = const_port_ids[0]
bias_constant = nncf_graph.get_input_edges(add_node)[bias_port_id].from_node
bias_constant = nncf_graph.get_input_edge_by_port_id(add_node, bias_port_id).from_node

if bias_constant.metatype == OVConvertMetatype:
bias_constant = nncf_graph.get_input_edges(bias_constant)[0].from_node
bias_constant = nncf_graph.get_input_edge_by_port_id(bias_constant, 0).from_node

return bias_constant if bias_constant.metatype == OVConstantMetatype else None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return False
const_port_ids = node.layer_attributes.get_const_port_ids()
assert len(const_port_ids) == 1
weight_node = nncf_graph.get_input_edges(node)[const_port_ids[0]].from_node
weight_node = nncf_graph.get_input_edge_by_port_id(node, const_port_ids[0]).from_node
return weight_node.metatype in FAKE_QUANTIZE_OPERATIONS

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return False
const_port_ids = node.layer_attributes.get_const_port_ids()
assert len(const_port_ids) == 1
weight_node = nncf_graph.get_input_edges(node)[const_port_ids[0]].from_node
weight_node = nncf_graph.get_input_edge_by_port_id(node, const_port_ids[0]).from_node
return weight_node.metatype in FAKE_QUANTIZE_OPERATIONS

@staticmethod
Expand Down
37 changes: 27 additions & 10 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,21 +646,24 @@ def _add_weight_quantization_target_point(
Adds weight quantization target point to the set of existing points.
:param quantization_point: SingleConfigQuantizationPoint for the needed layer.
:param model: Model in the original framework.
:param nncf_graph: The built NNCFGraph of the model.
"""
weight_quantization_target_points = self._get_weight_quantization_target_points(quantization_point, nncf_graph)
for weight_quantization_target_point in weight_quantization_target_points:
self._quantization_target_points_to_qconfig[weight_quantization_target_point] = quantization_point.qconfig

def _add_activation_quantization_target_point(self, quantization_point: SingleConfigQuantizationPoint) -> None:
def _add_activation_quantization_target_point(
self, quantization_point: SingleConfigQuantizationPoint, nncf_graph: NNCFGraph
) -> None:
"""
Adds activation quantization target point to the set of existing points.
:param nncf_graph: NNCFGraph instance for working with the graph and nodes.
:param quantization_point: SingleConfigQuantizationPoint for the needed layer.
:param nncf_graph: NNCFGraph instance for working with the graph and nodes.
"""
activation_quantization_target_point = self._get_activation_quantization_target_point(quantization_point)
activation_quantization_target_point = self._get_activation_quantization_target_point(
quantization_point, nncf_graph
)
self._quantization_target_points_to_qconfig[activation_quantization_target_point] = quantization_point.qconfig

def _get_weight_quantization_target_points(
Expand All @@ -684,12 +687,13 @@ def _get_weight_quantization_target_points(
return weight_quantization_target_points

def _get_activation_quantization_target_point(
self, quantization_point: SingleConfigQuantizationPoint
self, quantization_point: SingleConfigQuantizationPoint, nncf_graph: NNCFGraph
) -> SingleConfigQuantizationPoint:
"""
Returns activation quantization target point to the set of existing points.
:param quantization_point: SingleConfigQuantizationPoint for the needed layer.
:param nncf_graph: NNCFGraph instance for working with the graph and nodes.
:return: SingleConfigQuantizationPoint for the needed layer.
"""
node_name = quantization_point.insertion_point.target_node_name
Expand All @@ -701,7 +705,18 @@ def _get_activation_quantization_target_point(
)
# If quantization of node's output or Model Input node
else:
output_port_id = 0
# NOTE: Assumes that the operation has output edges only from one output port because
# we haven't encountered a model with operations that have multiple output edges with different
# output port IDs. Currently, such models are not supported. Usually, `output_port_id = 0` is used.
# However, there are operations, such as LSTMSequence, where the `output_port_id` changes from case
# to case. Therefore, the code below is required to dynamically determine the `output_port_id` where
# the quantize operation should be inserted."
node = nncf_graph.get_node_by_name(node_name)
unique_output_port_ids = set(e.output_port_id for e in nncf_graph.get_output_edges(node))
if len(unique_output_port_ids) > 1:
raise nncf.InternalError(f"Cannot determine the output_port_id for the operation: {node_name}")
output_port_id = next(iter(unique_output_port_ids))

activation_quantization_target_point = self._backend_entity.target_point(
TargetType.POST_LAYER_OPERATION, node_name, output_port_id
)
Expand Down Expand Up @@ -743,7 +758,7 @@ def _find_quantization_target_points(
if quantization_point.is_weight_quantization_point():
self._add_weight_quantization_target_point(quantization_point, nncf_graph)
elif quantization_point.is_activation_quantization_point():
self._add_activation_quantization_target_point(quantization_point)
self._add_activation_quantization_target_point(quantization_point, nncf_graph)
else:
raise nncf.InternalError("Incorrect quantization point")
return self._quantization_target_points_to_qconfig, self._unified_scale_groups
Expand Down Expand Up @@ -783,7 +798,9 @@ def _collect_unified_groups(

# Only activation quantizers can be unified
if quantization_point.is_activation_quantization_point():
activation_target_point = self._get_activation_quantization_target_point(quantization_point)
activation_target_point = self._get_activation_quantization_target_point(
quantization_point, nncf_graph
)
unified_scale_group.append(activation_target_point)
else:
weight_target_points = self._get_weight_quantization_target_points(quantization_point, nncf_graph)
Expand Down Expand Up @@ -1096,8 +1113,8 @@ def _is_node_after_producers(node):
# In the case of the two quantizers without the branching after them,
# it needs to check that all quantizers follows after producer nodes.
if _is_node_after_producers(fq_1_producer) and _is_node_after_producers(fq_2_producer):
fq_1_prod_shape = np.prod(nncf_graph.get_output_edges(fq_1_producer)[0].tensor_shape)
fq_2_prod_shape = np.prod(nncf_graph.get_output_edges(fq_2_producer)[0].tensor_shape)
fq_1_prod_shape = np.prod(nncf_graph.get_output_edges_by_port_id(fq_1_producer, 0)[0].tensor_shape)
fq_2_prod_shape = np.prod(nncf_graph.get_output_edges_by_port_id(fq_2_producer, 0)[0].tensor_shape)

# Then it needs to remove quantizer with the smallest shape.
if fq_1_prod_shape >= fq_2_prod_shape:
Expand Down
10 changes: 8 additions & 2 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,16 @@ def create_convert_insertion_command(
def get_target_point_shape(nncf_graph: NNCFGraph, node: NNCFNode, target_point: OVTargetPoint) -> Tuple[int, ...]:
if target_point.is_weight_target_point():
return node.layer_attributes.constant_attributes[target_point.port_id]["shape"]

if target_point.type == TargetType.PRE_LAYER_OPERATION:
return nncf_graph.get_input_edges(node)[target_point.port_id].tensor_shape
edge = nncf_graph.get_input_edge_by_port_id(node, target_point.port_id)
return edge.tensor_shape
elif target_point.type == TargetType.POST_LAYER_OPERATION:
return nncf_graph.get_output_edges(node)[target_point.port_id].tensor_shape
# NOTE: Assumes that all output edges for the `node` with `output_port_id`
# equal to `target_point.port_id` should have the same `tensor_shape` value.
edges = nncf_graph.get_output_edges_by_port_id(node, target_point.port_id)
return edges[0].tensor_shape

raise NotImplementedError(f"Unsupported target point type {target_point.type}.")

@staticmethod
Expand Down
8 changes: 3 additions & 5 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ def _group_nodes_by_source(self, nodes_to_smooth: List[Dict], nncf_graph: NNCFGr
for node_data in nodes_to_smooth:
node_to_smooth = node_data["node_to_smooth"]
input_act_port = node_data["input_act_port"]

source_node = nncf_graph.get_input_edges(node_to_smooth)[input_act_port].from_node
source_node = nncf_graph.get_input_edge_by_port_id(node_to_smooth, input_act_port).from_node
edge = nncf_graph.get_edge(source_node, node_to_smooth)
# Such group_id (with node, ports, and shape as a hash) allows us to be confident
# that all sensitive parameters are equal for successor nodes are equal.
Expand Down Expand Up @@ -288,8 +287,7 @@ def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: List[
continue

activation_port_id = self._backend_entity.get_activations_port_id(node_with_weight, nncf_graph)
input_edges = nncf_graph.get_input_edges(node_with_weight)
activation_node = input_edges[activation_port_id].from_node
activation_node = nncf_graph.get_input_edge_by_port_id(node_with_weight, activation_port_id).from_node

# Skipping agnostic layers as inputs to propagate quantizer
# Only for Convolution layers
Expand Down Expand Up @@ -367,7 +365,7 @@ def _calculate_input_reduction_axes(self, nncf_graph: NNCFGraph, node: NNCFNode,
:param input_port: Specified input port id.
:return: Calculated reduction axes.
"""
shape = nncf_graph.get_input_edges(node)[input_port].tensor_shape
shape = nncf_graph.get_input_edge_by_port_id(node, input_port).tensor_shape
reduction_axes = tuple([])
if len(shape) > 1:
channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
weight_port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node)
weight_node = nncf_graph.get_input_edges(node)[weight_port_id].from_node
weight_node = nncf_graph.get_input_edge_by_port_id(node, weight_port_id).from_node
return len(nncf_graph.get_next_nodes(weight_node)) > 1

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -
:return: Tuple with the activation node and port id.
"""
activation_port = self._backend_entity.get_activation_port_id(node, nncf_graph)
activation_edge = nncf_graph.get_input_edges(node)[activation_port]
activation_edge = nncf_graph.get_input_edge_by_port_id(node, activation_port)
activation_node = activation_edge.from_node
port_id = activation_edge.output_port_id
return activation_node, port_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@
]
]
},
"LSTMSequence/fq_output_0": {
"LSTMSequence/fq_output_1": {
"input_low": 0.0,
"input_high": 0.23005808889865875,
"input_high": 0.12186191976070404,
"output_low": 0.0,
"output_high": 0.23005808889865875
"output_high": 0.12186191976070404
},
"LSTMSequence/fq_weights_5": {
"input_low": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@
]
]
},
"LSTMSequence/fq_output_0": {
"LSTMSequence/fq_output_1": {
"input_low": 0.0,
"input_high": 0.23005808889865875,
"input_high": 0.12186191976070404,
"output_low": 0.0,
"output_high": 0.23005808889865875
"output_high": 0.12186191976070404
},
"LSTMSequence/fq_weights_5": {
"input_low": [
Expand Down
2 changes: 1 addition & 1 deletion tests/openvino/native/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def _create_ov_model(self):
x, initial_hidden_state, initial_cell_state, seq_len, W, R, B, 128, "FORWARD", name="LSTMSequence"
)
data = self._rng.random((1, 1, 128, 3)).astype(np.float32)
matmul = opset.matmul(lstm.output(0), data, transpose_a=False, transpose_b=False, name="MatMul")
matmul = opset.matmul(lstm.output(1), data, transpose_a=False, transpose_b=False, name="MatMul")

result = opset.result(matmul, name="Result")
result.get_output_tensor(0).set_names(set(["Result"]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def test_get_stat_collector(
port_id = None if TargetType.POST_LAYER_OPERATION else 0
ip = ActivationQuantizationInsertionPoint(params.target_node_name, port_id)
qp = SingleConfigQuantizationPoint(ip, q_config, [params.target_node_name])
min_max_algo._add_activation_quantization_target_point(qp)
min_max_algo._add_activation_quantization_target_point(qp, conv_sum_aggregation_nncf_graph.nncf_graph)
else:
ip = WeightQuantizationInsertionPoint(params.target_node_name)
qp = SingleConfigQuantizationPoint(ip, q_config, [params.target_node_name])
Expand Down

0 comments on commit 380d2bb

Please sign in to comment.