Skip to content

Commit

Permalink
Fix merging of pruning groups when node has several convolution paren…
Browse files Browse the repository at this point in the history
…ts (#915)
  • Loading branch information
evgeniya-egupova committed Sep 6, 2021
1 parent 49994f4 commit 356822b
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 22 deletions.
22 changes: 11 additions & 11 deletions nncf/common/pruning/pruning_node_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from nncf.common.pruning.utils import get_sources_of_node
from nncf.common.pruning.utils import get_first_nodes_of_type
from nncf.common.pruning.utils import get_last_nodes_of_type
from nncf.common.pruning.utils import get_previous_conv
from nncf.common.pruning.utils import get_previous_convs
from nncf.common.pruning.utils import is_grouped_conv
from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry
from nncf.common.pruning.model_analysis import ModelAnalyzer
Expand Down Expand Up @@ -140,11 +140,12 @@ def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]:
cluster_id = pruned_nodes_clusterization.get_cluster_containing_element(node.node_id).id

if is_depthwise_conv(node):
previous_conv = get_previous_conv(graph, node, self._prune_operations, stop_propagation_ops)
if previous_conv:
previous_conv_cluster_id = pruned_nodes_clusterization.get_cluster_containing_element(
previous_conv.node_id).id
pruned_nodes_clusterization.merge_clusters(cluster_id, previous_conv_cluster_id)
previous_convs = get_previous_convs(graph, node, self._prune_operations, stop_propagation_ops)
previous_clusters = [
pruned_nodes_clusterization.get_cluster_containing_element(node.node_id).id
for node in previous_convs
]
pruned_nodes_clusterization.merge_list_of_clusters([cluster_id] + previous_clusters)

# 5. Merge nodes into one cluster if some module forwards several times
multiforward_nodes = self._get_multiforward_nodes(graph)
Expand All @@ -154,16 +155,15 @@ def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]:
pruned_nodes_clusterization.merge_list_of_clusters(clusters_to_merge)

# Merge previous convolutions into one cluster
previous_convs = []
all_previous_convs = []
for node in list_of_nodes:
nncf_node = graph.get_node_by_id(node.node_id)
previous_conv = get_previous_conv(graph, nncf_node, self._prune_operations, stop_propagation_ops)
if previous_conv:
previous_convs.append(previous_conv)
previous_convs = get_previous_convs(graph, nncf_node, self._prune_operations, stop_propagation_ops)
all_previous_convs.extend(previous_convs)

previous_clusters = [
pruned_nodes_clusterization.get_cluster_containing_element(node.node_id).id
for node in previous_convs
for node in all_previous_convs
]
pruned_nodes_clusterization.merge_list_of_clusters(previous_clusters)

Expand Down
15 changes: 6 additions & 9 deletions nncf/common/pruning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,16 @@ def get_last_nodes_of_type(graph: NNCFGraph, op_types: List[str]) -> List[NNCFNo
return last_nodes_of_type


def get_previous_conv(graph: NNCFGraph, nncf_node: NNCFNode,
pruning_types: List[str], stop_propagation_ops: List[str]) -> Optional[NNCFNode]:
def get_previous_convs(graph: NNCFGraph, nncf_node: NNCFNode,
pruning_types: List[str], stop_propagation_ops: List[str]) -> Optional[NNCFNode]:
"""
Returns source convolution of the node. If the node has another source type or there is
more than one source - returns None.
Returns source convolutions of the node.
:return: Source convolution of node. If the node has another source type or there is more
than one source - returns None.
:return: List of source convolutions of node.
"""
sources = get_sources_of_node(nncf_node, graph, pruning_types + stop_propagation_ops)
if len(sources) == 1 and sources[0].node_type in pruning_types:
return sources[0]
return None
sources = [source for source in sources if source.node_type in pruning_types]
return sources


def get_conv_in_out_channels(graph: NNCFGraph):
Expand Down
26 changes: 26 additions & 0 deletions tests/torch/pruning/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,32 @@ def forward(self, x):
x = copy.copy(x)
return x


class DepthwiseConvolutionModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = create_conv(1, 512, 1, 1, 1)
self.conv4 = create_conv(1024, 512, 2, 1, 1)
for i in range(512):
self.conv1.weight.data[i] += i
self.conv4.weight.data[i] += i
self.conv2 = create_conv(512, 1024, 3, 1, 1)
self.conv3 = create_conv(512, 1024, 3, 1, 1)
self.depthwise_conv = create_depthwise_conv(1024, 5, 1, 1)
for i in range(1024):
self.conv2.weight.data[i] += i
self.conv3.weight.data[i] += i
self.depthwise_conv.weight.data[i] += i

def forward(self, x):
x = self.conv1(x)
x1 = self.conv2(x)
x2 = self.conv3(x)
x = x1 + x2
x = self.depthwise_conv(x)
return self.conv4(x)


def get_basic_pruning_config(input_sample_size=None) -> NNCFConfig:
if input_sample_size is None:
input_sample_size = [1, 1, 4, 4]
Expand Down
16 changes: 14 additions & 2 deletions tests/torch/pruning/test_model_pruning_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from collections import Counter
from typing import Callable
from typing import Dict
from typing import List
Expand All @@ -35,6 +37,7 @@
from nncf.torch.pruning.filter_pruning.algo import FilterPruningBuilder
from tests.torch.helpers import create_compressed_model_and_algo_for_test
from tests.torch.helpers import create_nncf_model_and_single_algo_builder
from tests.torch.pruning.helpers import DepthwiseConvolutionModel
from tests.torch.pruning.helpers import PruningTestModelEltwise
from tests.torch.pruning.helpers import PruningTestModelSharedConvs
from tests.torch.pruning.helpers import TestModelBranching
Expand Down Expand Up @@ -134,6 +137,14 @@ def __init__(self, model: Type[torch.nn.Module],
pruned_groups=[['PruningTestModelSharedConvs/NNCFConv2d[conv2]/conv2d_0',
'PruningTestModelSharedConvs/NNCFConv2d[conv2]/conv2d_1']],
pruned_groups_by_node_id=[[3, 4]],
prune_params=(False, False, False)),
GroupPruningModulesTestStruct(model=DepthwiseConvolutionModel,
non_pruned_module_nodes=['DepthwiseConvolutionModel/NNCFConv2d[conv1]/conv2d_0',
'DepthwiseConvolutionModel/NNCFConv2d[conv4]/conv2d_0'],
pruned_groups=[['DepthwiseConvolutionModel/NNCFConv2d[conv2]/conv2d_0',
'DepthwiseConvolutionModel/NNCFConv2d[conv3]/conv2d_0',
'DepthwiseConvolutionModel/NNCFConv2d[depthwise_conv]/conv2d_0']],
pruned_groups_by_node_id=[[2, 3, 5]],
prune_params=(False, False, False))
]

Expand Down Expand Up @@ -174,7 +185,7 @@ def test_groups(test_input_info_struct_: GroupPruningModulesTestStruct):
cluster_modules = [n.module for n in cluster.elements]
group_modules = [compressed_model.get_containing_module(node_name) for node_name in group]

assert cluster_modules == group_modules
assert Counter(cluster_modules) == Counter(group_modules)


def test_pruning_node_selector(test_input_info_struct_: GroupPruningModulesTestStruct):
Expand Down Expand Up @@ -215,7 +226,8 @@ def test_pruning_node_selector(test_input_info_struct_: GroupPruningModulesTestS
cluster_node_ids = [n.node_id for n in cluster.elements]
cluster_node_ids.sort()

assert cluster_node_ids == group_by_id
assert Counter(cluster_node_ids) == Counter(group_by_id)


class GroupSpecialModulesTestStruct:
def __init__(self, model: Callable, eltwise_clusters):
Expand Down

0 comments on commit 356822b

Please sign in to comment.