Skip to content

Commit

Permalink
[Torch] SE block HW fusing pattern is presented (openvinotoolkit#2177)
Browse files Browse the repository at this point in the history
### Changes

* SE block ignored pattern is presented for Torch backend

![image](https://github.com/openvinotoolkit/nncf/assets/74656388/2bbf0d3c-ac9d-4db0-9c28-ff6b1bf299af)
* The SE block ignored pattern for the OV backend is adjusted
* `PATTERN_NODE_TO_EXCLUDE` (a non-pattern node with a type, NON_PATTERN_NODE_WITH_TYPE) is introduced in the graph matcher


### Reason for changes

To align the FQ layout between the TORCH and OV backends for the efficientnet_b0 model.
To fix the FQ placement for the SE block:
Before: 

![image](https://github.com/openvinotoolkit/nncf/assets/74656388/7e8b27a3-fafc-41f9-ace2-1945bce32b65)
After:

![image](https://github.com/openvinotoolkit/nncf/assets/74656388/e4479c90-bd4f-47fb-b2c2-acd76a685030)


### Related tickets

121647

### Tests
* test_non_pattern_graph_with_type for the graph matcher

Post training quantization manual build 180

|   | Model | Backend | Metric name | Metric value | Metric diff | Num FQ | RAM MiB | Quant. time | Total time | Status |
|---|-------|---------|-------------|--------------|-------------|--------|---------|-------------|------------|--------|
| 1 | timm/efficientnet_b0 | OV | Acc@1 | 0.7688 | -0.0042 | 92 | 1033 | 0:00:25 | 1:03:58 | |
| 2 | timm/efficientnet_b0 | TORCH | Acc@1 | 0.7680 | -0.0049 | 92 | 892 | 0:00:48 | 1:04:27 | |


* pattern manager tests
  • Loading branch information
daniil-lyakhov committed Oct 24, 2023
1 parent f415d90 commit 5336405
Show file tree
Hide file tree
Showing 14 changed files with 875 additions and 687 deletions.
15 changes: 11 additions & 4 deletions nncf/common/graph/graph_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

from nncf.common.graph.patterns import GraphPattern

ATTRS_TO_SKIP = [GraphPattern.LABEL_ATTR, GraphPattern.PATTERN_NODE_TO_EXCLUDE]


def _are_nodes_matched(node_1, node_2) -> bool:
for attr in node_2:
if attr == GraphPattern.LABEL_ATTR:
if attr in ATTRS_TO_SKIP:
continue
if attr == GraphPattern.METATYPE_ATTR:
# GraphPattern.ANY_PATTERN_NODE_TYPE and GraphPattern.NON_PATTERN_NODE_TYPE
Expand Down Expand Up @@ -103,7 +105,8 @@ def _is_subgraph_matching_strict(graph: nx.DiGraph, pattern: nx.DiGraph, subgrap

def _copy_subgraph_excluding_non_pattern_node(subgraph: Dict[str, str], pattern_graph: GraphPattern) -> Dict[str, str]:
"""
Copies a matching subgraph excluding the nodes having GraphPattern.NON_PATTERN_NODE_TYPE.
Copies a matching subgraph excluding the nodes having GraphPattern.NON_PATTERN_NODE_TYPE
or GraphPattern.PATTERN_NODE_TO_EXCLUDE.
:param subgraph: Subgraph
:param pattern_graph: A graph consists of patterns to match.
Expand All @@ -113,8 +116,12 @@ def _copy_subgraph_excluding_non_pattern_node(subgraph: Dict[str, str], pattern_
for node_from_graph, node_from_pattern in subgraph.items():
pattern_node = pattern_graph.graph.nodes[node_from_pattern]
pattern_node_types = pattern_node.get(GraphPattern.METATYPE_ATTR, [])
if GraphPattern.NON_PATTERN_NODE_TYPE not in pattern_node_types:
output[node_from_graph] = node_from_pattern
if GraphPattern.NON_PATTERN_NODE_TYPE in pattern_node_types:
continue
if pattern_node.get(GraphPattern.PATTERN_NODE_TO_EXCLUDE, False):
continue
output[node_from_graph] = node_from_pattern

return output


Expand Down
3 changes: 2 additions & 1 deletion nncf/common/graph/patterns/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class GraphPattern:
NODE_TYPE_ATTR = "metatype"
ANY_PATTERN_NODE_TYPE = "ANY_PATTERN_NODE"
NON_PATTERN_NODE_TYPE = "NON_PATTERN_NODE"
PATTERN_NODE_TO_EXCLUDE = "PATTERN_NODE_TO_EXCLUDE"

def __init__(self):
self._graph = nx.DiGraph()
Expand Down Expand Up @@ -298,7 +299,6 @@ class HWFusedPatternNames(Enum):
NORMALIZE_L2_MULTIPLY = PatternDesc("normalize_l2_multiply")
SCALE_SHIFT = PatternDesc("scale_shift")
SHIFT_SCALE = PatternDesc("shift_scale")
SE_BLOCK = PatternDesc("se_block")
SOFTMAX_DIV = PatternDesc("softmax_div")

# ACTIVATIONS
Expand Down Expand Up @@ -396,5 +396,6 @@ class IgnoredPatternNames(Enum):
model_types=[ModelType.TRANSFORMER],
devices=[TargetDevice.ANY, TargetDevice.CPU, TargetDevice.GPU, TargetDevice.VPU],
)
SE_BLOCK = PatternDesc("se_block")
FC_BN_HSWISH_ACTIVATION = PatternDesc("fc_bn_hswish_activation")
EQUAL_LOGICALNOT = PatternDesc("equal_logicalnot")
42 changes: 0 additions & 42 deletions nncf/openvino/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,48 +147,6 @@ def create_shift_scale() -> GraphPattern:
return pattern


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SE_BLOCK)
def create_se_block() -> GraphPattern:
    """
    Creates the SE block pattern for the OpenVINO backend.

    The pattern is the chain
    REDUCE_MEAN -> LINEAR -> ADD_BIAS -> (RELU | PRELU | SWISH) -> LINEAR -> ADD_BIAS -> SIGMOID -> MULTIPLY,
    where a single node of GraphPattern.NON_PATTERN_NODE_TYPE feeds both REDUCE_MEAN and MULTIPLY.

    :return: The SE block pattern.
    """
    label_key = GraphPattern.LABEL_ATTR
    type_key = GraphPattern.METATYPE_ATTR

    pattern = GraphPattern()
    source_node = pattern.add_node(**{label_key: "ANY", type_key: GraphPattern.NON_PATTERN_NODE_TYPE})

    # Attributes of the chain nodes, listed in the order the nodes are added to the pattern.
    chain_attrs = [
        {label_key: "REDUCE_MEAN", type_key: om.OVReduceMeanMetatype},
        {type_key: LINEAR_OPERATIONS, label_key: "LINEAR"},
        {label_key: "ADD_BIAS", type_key: om.OVAddMetatype},
        {
            label_key: "RELU, PRELU, SWISH",
            type_key: [om.OVReluMetatype, om.OVPReluMetatype, om.OVSwishMetatype],
        },
        {type_key: LINEAR_OPERATIONS, label_key: "LINEAR"},
        {label_key: "ADD_BIAS", type_key: om.OVAddMetatype},
        {label_key: "SIGMOID", type_key: om.OVSigmoidMetatype},
        {label_key: "MULTIPLY", type_key: om.OVMultiplyMetatype},
    ]
    chain = [pattern.add_node(**attrs) for attrs in chain_attrs]

    # The source node enters the chain at REDUCE_MEAN and rejoins it at MULTIPLY.
    pattern.add_edge(source_node, chain[0])
    for producer, consumer in zip(chain, chain[1:]):
        pattern.add_edge(producer, consumer)
    pattern.add_edge(source_node, chain[-1])
    return pattern


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SOFTMAX_DIV)
def create_softmax_div() -> GraphPattern:
pattern = GraphPattern()
Expand Down
51 changes: 51 additions & 0 deletions nncf/openvino/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nncf.common.graph.patterns.patterns import IgnoredPatternNames
from nncf.common.utils.registry import Registry
from nncf.openvino.graph.metatypes import openvino_metatypes as om
from nncf.openvino.graph.metatypes.groups import LINEAR_OPERATIONS

OPENVINO_IGNORED_PATTERNS = Registry("IGNORED_PATTERNS")

Expand Down Expand Up @@ -108,3 +109,53 @@ def create_equal_logicalnot() -> GraphPattern:

pattern.add_edge(equal_node, logical_not_node)
return pattern


@OPENVINO_IGNORED_PATTERNS.register(IgnoredPatternNames.SE_BLOCK)
def create_se_block() -> GraphPattern:
    """
    Creates the SE block ignored pattern for the OpenVINO backend.

    The pattern is the chain
    REDUCE_MEAN -> LINEAR -> ADD_BIAS -> (RELU | PRELU | SWISH) -> LINEAR -> ADD_BIAS -> SIGMOID -> MULTIPLY,
    where a single node of GraphPattern.NON_PATTERN_NODE_TYPE feeds both REDUCE_MEAN and MULTIPLY.
    The REDUCE_MEAN and MULTIPLY nodes are flagged with GraphPattern.PATTERN_NODE_TO_EXCLUDE.

    :return: The SE block ignored pattern.
    """
    label_key = GraphPattern.LABEL_ATTR
    type_key = GraphPattern.METATYPE_ATTR
    exclude_key = GraphPattern.PATTERN_NODE_TO_EXCLUDE

    pattern = GraphPattern()
    source_node = pattern.add_node(**{label_key: "ANY", type_key: GraphPattern.NON_PATTERN_NODE_TYPE})

    # Attributes of the chain nodes, listed in the order the nodes are added to the pattern.
    chain_attrs = [
        {label_key: "REDUCE_MEAN", type_key: om.OVReduceMeanMetatype, exclude_key: True},
        {type_key: LINEAR_OPERATIONS, label_key: "LINEAR"},
        {label_key: "ADD_BIAS", type_key: om.OVAddMetatype},
        {
            label_key: "RELU, PRELU, SWISH",
            type_key: [om.OVReluMetatype, om.OVPReluMetatype, om.OVSwishMetatype],
        },
        {type_key: LINEAR_OPERATIONS, label_key: "LINEAR"},
        {label_key: "ADD_BIAS", type_key: om.OVAddMetatype},
        {label_key: "SIGMOID", type_key: om.OVSigmoidMetatype},
        {label_key: "MULTIPLY", type_key: om.OVMultiplyMetatype, exclude_key: True},
    ]
    chain = [pattern.add_node(**attrs) for attrs in chain_attrs]

    # The source node enters the chain at REDUCE_MEAN and rejoins it at MULTIPLY.
    pattern.add_edge(source_node, chain[0])
    for producer, consumer in zip(chain, chain[1:]):
        pattern.add_edge(producer, consumer)
    pattern.add_edge(source_node, chain[-1])
    return pattern
125 changes: 125 additions & 0 deletions nncf/torch/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from nncf.common.graph.patterns.patterns import GraphPattern
from nncf.common.graph.patterns.patterns import IgnoredPatternNames
from nncf.common.utils.registry import Registry
from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS
from nncf.torch.graph.pattern_operations import LINEAR_OPERATIONS

PT_IGNORED_PATTERNS = Registry("IGNORED_PATTERNS")

Expand Down Expand Up @@ -88,3 +90,126 @@ def create_multihead_attention_output() -> GraphPattern:
transpose_aliases=transpose_aliases,
)
return pattern


# pylint:disable=too-many-statements
@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.SE_BLOCK)
def create_se_block() -> GraphPattern:
    """
    Creates the SE block ignored pattern for the Torch backend.

    The result is a pattern with four alternatives built from the same chain
    MEAN -> [RESHAPE] -> LINEAR -> [ADD_BIAS] -> activation -> LINEAR -> [ADD_BIAS] -> SIGMOID -> [RESHAPE] -> MUL,
    where the bracketed nodes are optional: the alternatives are the plain chain, the chain with both
    bias nodes, the chain with both reshape nodes and the chain with bias and reshape nodes together.
    In every alternative a single node of GraphPattern.NON_PATTERN_NODE_TYPE feeds both the MEAN and
    the MUL nodes, and the MEAN and MUL nodes are flagged with GraphPattern.PATTERN_NODE_TO_EXCLUDE.

    :return: The SE block ignored pattern.
    """
    MEAN_OPERATIONS = {
        GraphPattern.LABEL_ATTR: "REDUCE_MEAN",
        GraphPattern.METATYPE_ATTR: ["avg_pool2d", "adaptive_avg_pool2d", "avg_pool3d", "adaptive_avg_pool3d", "mean"],
        GraphPattern.PATTERN_NODE_TO_EXCLUDE: True,
    }
    SIGMOID_OPERATIONS = {
        GraphPattern.LABEL_ATTR: "SIGMOID",
        GraphPattern.METATYPE_ATTR: ["sigmoid", "hardsigmoid"],
    }
    MUL_OPERATION = {
        GraphPattern.LABEL_ATTR: "MUL",
        GraphPattern.METATYPE_ATTR: "__mul__",
        GraphPattern.PATTERN_NODE_TO_EXCLUDE: True,
    }
    RESHAPE_NODES = {
        GraphPattern.LABEL_ATTR: "RESHAPE",
        GraphPattern.METATYPE_ATTR: ["reshape", "view", "flatten", "unsqueeze"],
    }

    def build_se_block(with_bias: bool, with_reshape: bool) -> GraphPattern:
        """
        Builds a single SE block alternative.

        :param with_bias: Whether an ADD_BIAS node follows each of the two linear nodes.
        :param with_reshape: Whether a RESHAPE node follows the MEAN node and precedes the MUL node.
        :return: The pattern of the requested SE block variant.
        """
        pattern = GraphPattern()
        any_node = pattern.add_node(label="NON_PATTERN_NODE", type=GraphPattern.NON_PATTERN_NODE_TYPE)

        # Nodes are added strictly in chain order, so node creation order matches the chain layout.
        chain = [pattern.add_node(**MEAN_OPERATIONS)]
        if with_reshape:
            chain.append(pattern.add_node(**RESHAPE_NODES))
        chain.append(pattern.add_node(**LINEAR_OPERATIONS))
        if with_bias:
            chain.append(pattern.add_node(label="ADD_BIAS", type=["__add__", "__sub__"]))
        chain.append(pattern.add_node(**ATOMIC_ACTIVATIONS_OPERATIONS))
        chain.append(pattern.add_node(**LINEAR_OPERATIONS))
        if with_bias:
            chain.append(pattern.add_node(label="ADD_BIAS", type=["__add__", "__sub__"]))
        chain.append(pattern.add_node(**SIGMOID_OPERATIONS))
        if with_reshape:
            chain.append(pattern.add_node(**RESHAPE_NODES))
        chain.append(pattern.add_node(**MUL_OPERATION))

        # The non-pattern node enters the chain at the MEAN node and rejoins it at the MUL node.
        pattern.add_edge(any_node, chain[0])
        for producer, consumer in zip(chain, chain[1:]):
            pattern.add_edge(producer, consumer)
        pattern.add_edge(any_node, chain[-1])
        return pattern

    main_pattern = GraphPattern()
    # Alternatives are registered in the order: plain, with bias, with reshape, with bias and reshape.
    for with_bias, with_reshape in ((False, False), (True, False), (False, True), (True, True)):
        main_pattern.add_pattern_alternative(build_se_block(with_bias=with_bias, with_reshape=with_reshape))
    return main_pattern
25 changes: 25 additions & 0 deletions tests/common/graph/test_graph_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,28 @@ def test_not_match_edges_inside_pattern():
pattern.add_edge(node_1, node_3)
matches = find_subgraphs_matching_pattern(ref_graph, pattern)
assert matches == [["1", "2", "3"]]


def test_non_pattern_graph_with_type():
    """
    Checks the handling of a pattern node flagged with GraphPattern.PATTERN_NODE_TO_EXCLUDE.

    The flagged node still has to be matched by its type: when the reference graph has no node
    of type "a" nothing is found, and when it has one, the match is reported without that node.
    """
    for should_match in (False, True):
        # Reference graph: the chain "0" -> "1" -> "2" -> "3"; only the type of node "1" varies.
        node_types = {"0": "0", "1": "a" if should_match else "0", "2": "b", "3": "c"}
        ref_graph = nx.DiGraph()
        for node_name, node_type in node_types.items():
            ref_graph.add_node(node_name, **{GraphPattern.METATYPE_ATTR: node_type})
        for producer, consumer in (("0", "1"), ("1", "2"), ("2", "3")):
            ref_graph.add_edge(producer, consumer)

        # Pattern: the chain a -> b -> c, where the "a" node is flagged for exclusion.
        pattern = GraphPattern()
        excluded_node = pattern.add_node(
            **{GraphPattern.METATYPE_ATTR: "a", GraphPattern.PATTERN_NODE_TO_EXCLUDE: True}
        )
        middle_node = pattern.add_node(**{GraphPattern.METATYPE_ATTR: "b"})
        last_node = pattern.add_node(**{GraphPattern.METATYPE_ATTR: "c"})
        pattern.add_edge(excluded_node, middle_node)
        pattern.add_edge(middle_node, last_node)

        found = find_subgraphs_matching_pattern(ref_graph, pattern)
        if should_match:
            assert found == [["2", "3"]]
        else:
            assert not found
Loading

0 comments on commit 5336405

Please sign in to comment.