
Fix calculation of shared layers FLOPs (openvinotoolkit#677)
* proper calculation of FLOPs of shared layers

* pytorch demo

* shapes calculation using graph edges

* changes related to wrappers renaming

* fix rebase

* change tests

* add test

* fix tests

* tests: sort nodes in cluster before comparing

* comments: remove redundant attribute

* fix rebase
evgeniya-egupova committed May 28, 2021
1 parent d8a15c0 commit 1e47aea
Showing 12 changed files with 276 additions and 273 deletions.
2 changes: 1 addition & 1 deletion beta/nncf/tensorflow/graph/converter.py
@@ -221,7 +221,7 @@ def _update_graph_with_raw_nodes(graph: NNCFGraph,
     for original_name, instances in raw_nodes.items():
         for i, attributes in instances.items():
             node_name = get_expanded_node_name(original_name, i, attributes['is_shared'])
-            graph.add_node(node_name, original_name=original_name, **attributes)
+            graph.add_node(node_name, **attributes)

             if attributes['is_output']:
                 # Aligning the structure of auxiliary output nodes is only necessary for NNCFGraph
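Note: the hunks in this commit rely on NNCF's expanded node naming for shared layers. A minimal sketch of that convention, assuming the `name^index` format implied by `get_original_node_name` (a helper removed later in this commit), which splits node names on '^'; both helper bodies here are illustrative, not the library's actual implementation:

    # Sketch: how a shared Keras layer maps to several expanded graph nodes.
    # Assumes the '<layer_name>^<instance_index>' convention implied by the
    # split on '^' in get_original_node_name(); bodies are illustrative only.

    def get_expanded_node_name(original_name: str, instance_index: int, is_shared: bool) -> str:
        # A layer called N times appears as N nodes: 'conv^0', 'conv^1', ...
        return f'{original_name}^{instance_index}' if is_shared else original_name

    def get_original_name_and_instance_index(node_name: str):
        parts = node_name.split('^')
        return parts[0], int(parts[1]) if len(parts) > 1 else 0

    assert get_expanded_node_name('conv2d', 1, True) == 'conv2d^1'
    assert get_original_name_and_instance_index('conv2d^1') == ('conv2d', 1)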
8 changes: 4 additions & 4 deletions beta/nncf/tensorflow/pruning/base_algorithm.py
@@ -47,10 +47,10 @@


 class PrunedLayerInfo:
-    def __init__(self, layer_name: str, node_id: int):
+    def __init__(self, node_name: str, layer_name, node_id: int):
+        self.node_name = node_name
         self.layer_name = layer_name
         self.nncf_node_id = node_id
-        self.key = self.layer_name


 class BasePruningAlgoBuilder(TFCompressionAlgorithmBuilder):
@@ -111,13 +111,14 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
         transformations = TFTransformationLayout()
         shared_layers = set()

-        self._pruned_layer_groups_info = Clusterization('layer_name')
+        self._pruned_layer_groups_info = Clusterization('node_name')

         for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
             group_minfos = []
             for node in group.nodes:
                 layer_name = get_layer_identifier(node)
                 layer = model.get_layer(layer_name)
+                group_minfos.append(PrunedLayerInfo(node.node_name, layer_name, node.node_id))

                 # Add output_mask to nodes to run mask_propagation
                 # and detect spec_nodes that will be pruned.
@@ -141,7 +142,6 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
                         self._get_insertion_command_binary_mask(
                             layer_name, node.metatype.bias_attr_name)
                     )
-                group_minfos.append(PrunedLayerInfo(layer_name, node.node_id))

             cluster = NodesCluster(i, group_minfos, [n.node_id for n in group.nodes])
             self._pruned_layer_groups_info.add_cluster(cluster)
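Why the cluster key changes from 'layer_name' to 'node_name': with a shared layer, several graph nodes carry the same layer_name, so only node_name is unique per node. A small illustration using the constructor from the diff above (values are hypothetical):

    # Two pruned nodes produced by one shared Keras layer: distinct node_name,
    # identical layer_name. Keying the clusterization by node_name keeps both.
    a = PrunedLayerInfo(node_name='conv2d^0', layer_name='conv2d', node_id=3)
    b = PrunedLayerInfo(node_name='conv2d^1', layer_name='conv2d', node_id=7)
    assert a.layer_name == b.layer_name and a.node_name != b.node_name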
86 changes: 44 additions & 42 deletions beta/nncf/tensorflow/pruning/filter_pruning/algorithm.py
@@ -23,7 +23,8 @@
 from beta.nncf.tensorflow.graph.metatypes.common import GENERAL_CONV_LAYER_METATYPES
 from beta.nncf.tensorflow.graph.metatypes.matcher import get_keras_layer_metatype
 from beta.nncf.tensorflow.graph.utils import collect_wrapped_layers
-from beta.nncf.tensorflow.graph.utils import get_original_name
+from beta.nncf.tensorflow.graph.utils import get_layer_identifier
+from beta.nncf.tensorflow.graph.utils import get_original_name_and_instance_index
 from beta.nncf.tensorflow.graph.utils import unwrap_layer
 from beta.nncf.tensorflow.layers.data_layout import get_input_channel_axis
 from beta.nncf.tensorflow.layers.wrapper import NNCFWrapper
@@ -175,40 +176,43 @@ def _init_pruned_layers_params(self):

         # 3. Initialize pruning quotas
         for cluster in self._pruned_layer_groups_info.get_all_clusters():
-            self._pruning_quotas[cluster.id] = floor(self._layers_out_channels[cluster.nodes[0].layer_name]
+            self._pruning_quotas[cluster.id] = floor(self._layers_out_channels[cluster.nodes[0].node_name]
                                                      * self.pruning_quota)

     def _flops_count_init(self):
         """
         Collects input/output shapes of convolutional and dense layers,
         calculates corresponding layerwise FLOPs
         """
-        for layer in self._model.layers:
-            layer_metatype = get_keras_layer_metatype(layer)
+        for node in self._original_graph.get_nodes_by_metatypes(GENERAL_CONV_LAYER_METATYPES):
+            node_name, node_index = get_original_name_and_instance_index(node.node_name)
+            layer = self._model.get_layer(node_name)
             layer_ = unwrap_layer(layer)

-            if layer_metatype in GENERAL_CONV_LAYER_METATYPES:
-                channel_axis = get_input_channel_axis(layer)
-                dims_slice = slice(channel_axis - layer_.rank, channel_axis) \
-                    if layer.data_format == 'channels_last' else slice(channel_axis + 1, None)
-                in_shape = layer.get_input_shape_at(0)[dims_slice]
-                out_shape = layer.get_output_shape_at(0)[dims_slice]
+            channel_axis = get_input_channel_axis(layer)
+            dims_slice = slice(channel_axis - layer_.rank, channel_axis) \
+                if layer.data_format == 'channels_last' else slice(channel_axis + 1, None)
+            in_shape = layer.get_input_shape_at(node_index)[dims_slice]
+            out_shape = layer.get_output_shape_at(node_index)[dims_slice]

-                if not is_valid_shape(in_shape) or not is_valid_shape(out_shape):
-                    raise RuntimeError(f'Input/output shape is not defined for layer `{layer.name}` ')
+            if not is_valid_shape(in_shape) or not is_valid_shape(out_shape):
+                raise RuntimeError(f'Input/output shape is not defined for layer `{layer.name}` ')

-                self._layers_in_shapes[layer.name] = in_shape
-                self._layers_out_shapes[layer.name] = out_shape
+            self._layers_in_shapes[node.node_name] = in_shape
+            self._layers_out_shapes[node.node_name] = out_shape

-            elif layer_metatype in LINEAR_LAYER_METATYPES:
-                in_shape = layer.get_input_shape_at(0)[1:]
-                out_shape = layer.get_output_shape_at(0)[1:]
+        for node in self._original_graph.get_nodes_by_metatypes(LINEAR_LAYER_METATYPES):
+            node_name, node_index = get_original_name_and_instance_index(node.node_name)
+            layer = self._model.get_layer(node_name)

-                if not is_valid_shape(in_shape) or not is_valid_shape(out_shape):
-                    raise RuntimeError(f'Input/output shape is not defined for layer `{layer.name}` ')
+            in_shape = layer.get_input_shape_at(node_index)[1:]
+            out_shape = layer.get_output_shape_at(node_index)[1:]

-                self._layers_in_shapes[layer.name] = in_shape
-                self._layers_out_shapes[layer.name] = out_shape
+            if not is_valid_shape(in_shape) or not is_valid_shape(out_shape):
+                raise RuntimeError(f'Input/output shape is not defined for layer `{layer.name}` ')
+
+            self._layers_in_shapes[node.node_name] = in_shape
+            self._layers_out_shapes[node.node_name] = out_shape

         self._nodes_flops = count_flops_for_nodes(self._original_graph,
                                                   self._layers_in_shapes,
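The per-instance lookups above use a stock Keras mechanism: a layer invoked N times owns N inbound nodes, and get_input_shape_at(i) / get_output_shape_at(i) return the shapes of the i-th invocation. A minimal standalone check (model and names are illustrative):

    import tensorflow as tf

    # One Conv2D object applied twice: a single layer with two inbound nodes,
    # each carrying different input/output shapes.
    conv = tf.keras.layers.Conv2D(8, 3, name='shared_conv')
    x0 = tf.keras.Input((32, 32, 3))
    x1 = tf.keras.Input((64, 64, 3))
    model = tf.keras.Model([x0, x1], [conv(x0), conv(x1)])

    print(conv.get_input_shape_at(0))   # (None, 32, 32, 3)  -> first call
    print(conv.get_input_shape_at(1))   # (None, 64, 64, 3)  -> second call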
@@ -227,8 +231,7 @@ def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_rate: float):
         # 1. Calculate masks
         for group in self._pruned_layer_groups_info.get_all_clusters():
             # a. Calculate the cumulative importance for all filters in the group
-            cumulative_filters_importance = \
-                self._calculate_filters_importance_in_group(group, wrapped_layers)
+            cumulative_filters_importance = self._calculate_filters_importance_in_group(group)
             filters_num = len(cumulative_filters_importance)

             # b. Calculate threshold
@@ -250,7 +253,7 @@ def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_rate: float):
         nncf_sorted_nodes = self._original_graph.topological_sort()
         for layer in wrapped_layers:
             nncf_node = [n for n in nncf_sorted_nodes
-                         if layer.name == get_original_name(n.node_name)][0]
+                         if layer.name == get_layer_identifier(n)][0]
             if nncf_node.data['output_mask'] is not None:
                 self._set_operation_masks([layer], nncf_node.data['output_mask'])

@@ -272,8 +275,7 @@ def _set_binary_masks_for_pruned_layers_globally(self, pruning_rate: float):
         # 1. Calculate masks
         # a. Calculate importances for all groups of filters
         for group in self._pruned_layer_groups_info.get_all_clusters():
-            cumulative_filters_importance = \
-                self._calculate_filters_importance_in_group(group, wrapped_layers)
+            cumulative_filters_importance = self._calculate_filters_importance_in_group(group)
             filter_importances[group.id] = cumulative_filters_importance

         # b. Calculate one threshold for all weights
@@ -295,7 +297,7 @@ def _set_binary_masks_for_pruned_layers_globally(self, pruning_rate: float):
         nncf_sorted_nodes = self._original_graph.topological_sort()
         for layer in wrapped_layers:
             nncf_node = [n for n in nncf_sorted_nodes
-                         if layer.name == get_original_name(n.node_name)][0]
+                         if layer.name == get_layer_identifier(n)][0]
             if nncf_node.data['output_mask'] is not None:
                 self._set_operation_masks([layer], nncf_node.data['output_mask'])

@@ -313,17 +315,15 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self,
         nncf_sorted_nodes = self._original_graph.topological_sort()
         for layer in wrapped_layers:
             nncf_node = [n for n in nncf_sorted_nodes
-                         if layer.name == get_original_name(n.node_name)][0]
+                         if layer.name == get_layer_identifier(n)][0]
             nncf_node.data['output_mask'] = tf.ones(get_filters_num(layer))

         # 1. Calculate importances for all groups of filters. Initialize masks.
         filter_importances = []
         group_indexes = []
         filter_indexes = []
         for group in self._pruned_layer_groups_info.get_all_clusters():
-            cumulative_filters_importance = \
-                self._calculate_filters_importance_in_group(group, wrapped_layers)
-
+            cumulative_filters_importance = self._calculate_filters_importance_in_group(group)
             filter_importances.extend(cumulative_filters_importance)
             filters_num = len(cumulative_filters_importance)
             group_indexes.extend([group.id] * filters_num)
@@ -344,7 +344,7 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self,
                 # Update input/output shapes of pruned nodes
                 group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
                 for node in group.nodes:
-                    tmp_out_channels[node.layer_name] -= 1
+                    tmp_out_channels[node.node_name] -= 1
                 for node_name in self._next_nodes[group_id]:
                     tmp_in_channels[node_name] -= 1

@@ -370,7 +370,7 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self,
         nncf_sorted_nodes = self._original_graph.topological_sort()
         for layer in wrapped_layers:
             nncf_node = [n for n in nncf_sorted_nodes
-                         if layer.name == get_original_name(n.node_name)][0]
+                         if layer.name == get_layer_identifier(n)][0]
             if nncf_node.data['output_mask'] is not None:
                 self._set_operation_masks([layer], nncf_node.data['output_mask'])
         return
@@ -422,26 +422,28 @@ def _calculate_flops_in_uniformly_pruned_model(self, pruning_rate):
                                               linear_op_metatypes=LINEAR_LAYER_METATYPES).values())
         return flops

-    def _calculate_filters_importance_in_group(self, group: NodesCluster,
-                                               wrapped_layers: List[tf.keras.layers.Layer]):
+    def _calculate_filters_importance_in_group(self, group: NodesCluster):
         """
         Calculates cumulative filters importance in the group.
         :param group: Nodes cluster
-        :param wrapped_layers: List of keras nodes wrapped by NNCFWrapper
         :return a list of filter importance scores
         """
-        group_layer_names = [node.layer_name for node in group.nodes]
-        group_filters_num = tf.constant([get_filters_num(wrapped_layer)
-                                         for wrapped_layer in wrapped_layers
-                                         if wrapped_layer.name in group_layer_names])
+        group_layers = [self._model.get_layer(node.layer_name) for node in group.nodes]
+        group_filters_num = tf.constant([get_filters_num(layer) for layer in group_layers])
         filters_num = group_filters_num[0]
         assert tf.reduce_all(group_filters_num == filters_num)

         cumulative_filters_importance = tf.zeros(filters_num)
         # Calculate cumulative importance for all filters in this group
+        shared_nodes = []
         for minfo in group.nodes:
-            layer = [layer for layer in wrapped_layers if layer.name == minfo.layer_name][0]
-            filters_importance = self._layer_filter_importance(layer)
+            layer_name = minfo.layer_name
+            if layer_name in shared_nodes:
+                continue
+            nncf_node = self._original_graph.get_node_by_id(minfo.nncf_node_id)
+            if nncf_node.data['is_shared']:
+                shared_nodes.append(layer_name)
+            filters_importance = self._layer_filter_importance(self._model.get_layer(layer_name))
+            cumulative_filters_importance += filters_importance

         return cumulative_filters_importance
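The loop above counts each shared layer's weights exactly once: the first node of a shared layer records its layer_name, and later nodes backed by the same layer are skipped. A toy version of the same dedup pattern (names and the mocked importance of 1.0 are illustrative):

    # group_nodes: (layer_name, is_shared) pairs for one pruning group.
    group_nodes = [('conv_a', False), ('shared_conv', True), ('shared_conv', True)]

    seen_shared = []
    total = 0.0
    for layer_name, is_shared in group_nodes:
        if layer_name in seen_shared:
            continue  # shared layer already counted; don't double its importance
        if is_shared:
            seen_shared.append(layer_name)
        total += 1.0

    assert total == 2.0  # 'shared_conv' contributes once, not twice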
10 changes: 10 additions & 0 deletions nncf/common/graph/graph.py
@@ -219,6 +219,16 @@ def get_input_edges(self, node: NNCFNode) -> Dict[Tuple[str, str], dict]:

         return OrderedDict((edge, self._nx_graph.edges[edge]) for edge in input_edges)

+    def get_output_edges(self, node: NNCFNode) -> Dict[Tuple[str, str], dict]:
+        """
+        Returns edges of output tensors with description. Unordered.
+        :param node: Producer node.
+        :return: Dictionary of output edges for the node.
+        """
+        nx_node_key = self._node_id_to_key_dict[node.node_id]
+        return {edge: self._nx_graph.edges[edge] for edge in self._nx_graph.out_edges(nx_node_key)}
+
     def traverse_graph(self,
                        curr_node: NNCFNode,
                        traverse_function: Callable[[NNCFNode, List[Any]], Tuple[bool, List[Any]]],
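The new get_output_edges is a thin wrapper over networkx: out_edges(key) yields (src, dst) tuples and G.edges[edge] returns each edge's attribute dict, which is where shape information can live (per the "shapes calculation using graph edges" commit note). A standalone sketch of that underlying behavior; the 'activation_shape' attribute name is illustrative:

    import networkx as nx

    g = nx.DiGraph()
    g.add_edge('conv2d^0', 'relu', activation_shape=(1, 56, 56, 64))
    g.add_edge('conv2d^0', 'add', activation_shape=(1, 56, 56, 64))

    out = {edge: g.edges[edge] for edge in g.out_edges('conv2d^0')}
    print(out)
    # {('conv2d^0', 'relu'): {'activation_shape': (1, 56, 56, 64)},
    #  ('conv2d^0', 'add'): {'activation_shape': (1, 56, 56, 64)}}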
39 changes: 13 additions & 26 deletions nncf/common/pruning/utils.py
@@ -14,7 +14,7 @@
 import math

 from functools import partial
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type

 import numpy as np

@@ -193,10 +193,6 @@ def get_previous_conv(graph: NNCFGraph, nncf_node: NNCFNode,
     return None


-def get_original_node_name(node_name: str):
-    return node_name.split('^')[0]
-
-
 def get_conv_in_out_channels(graph: NNCFGraph):
     """
     Collects the number of input and output channels for each convolution in the graph.
@@ -210,8 +206,7 @@ def get_conv_in_out_channels(graph: NNCFGraph):
     in_channels, out_channels = {}, {}
     for node in graph.get_all_nodes():
         if isinstance(node.module_attributes, ConvolutionModuleAttributes):
-            name = node.ia_op_exec_context.scope_in_model if hasattr(node, 'ia_op_exec_context') \
-                else get_original_node_name(node.node_name)
+            name = node.node_name
             if name in in_channels and name in out_channels:
                 continue
             in_channels[name] = node.module_attributes.in_channels
@@ -235,14 +230,10 @@ def get_cluster_next_nodes(graph: NNCFGraph, pruned_groups_info,
         cluster_nodes = set()
         for cluster_node in cluster.nodes:
             nncf_cluster_node = graph.get_node_by_id(cluster_node.nncf_node_id)
-            nncf_cluster_node_scope = nncf_cluster_node.ia_op_exec_context.scope_in_model \
-                if hasattr(nncf_cluster_node, 'ia_op_exec_context') \
-                else get_original_node_name(nncf_cluster_node.node_name)
-            cluster_nodes.add(nncf_cluster_node_scope)
+            cluster_nodes.add(nncf_cluster_node.node_name)
             curr_next_nodes = get_next_nodes_of_types(graph, nncf_cluster_node, prunable_types)

-            next_nodes_idxs = [n.ia_op_exec_context.scope_in_model if hasattr(n, 'ia_op_exec_context')
-                               else get_original_node_name(n.node_name) for n in curr_next_nodes]
+            next_nodes_idxs = [n.node_name for n in curr_next_nodes]
             next_nodes_cluster = next_nodes_cluster.union(next_nodes_idxs)
         next_nodes[cluster.id] = list(next_nodes_cluster - cluster_nodes)
     return next_nodes
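The final set difference matters for shared layers: a cluster member can itself be a successor of another member, and its input channels are already handled inside the group. A toy illustration with hypothetical node names:

    cluster_nodes = {'conv2d^0', 'conv2d^1'}          # nodes inside the pruning group
    next_nodes_found = {'conv2d^1', 'conv_next'}      # successors collected per member
    print(sorted(next_nodes_found - cluster_nodes))   # ['conv_next']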
Expand All @@ -262,8 +253,8 @@ def count_flops_for_nodes(graph: NNCFGraph,
fully connected layers. E.g {node_name: (height, width)}
:param output_shapes: Dictionary of output dimension shapes for convolutions and
fully connected layers. E.g {node_name: (height, width)}
:param conv_op_types: List of metatypes defining convolution operations.
:param linear_op_types: List of metatypes defining linear/fully connected operations.
:param conv_op_metatypes: List of metatypes defining convolution operations.
:param linear_op_metatypes: List of metatypes defining linear/fully connected operations.
:param input_channels: Dictionary of input channels number in convolutions.
If not specified, taken from the graph. {node_name: channels_num}
:param output_channels: Dictionary of output channels number in convolutions.
@@ -274,26 +265,22 @@
     input_channels = input_channels or {}
     output_channels = output_channels or {}
     for node in graph.get_nodes_by_metatypes(conv_op_metatypes):
-        name = node.ia_op_exec_context.scope_in_model if hasattr(node, 'ia_op_exec_context') \
-            else get_original_node_name(node.node_name)
-        if name in flops:
-            continue
+        name = node.node_name
         num_in_channels = input_channels.get(name, node.module_attributes.in_channels)
         num_out_channels = output_channels.get(name, node.module_attributes.out_channels)
         flops[name] = 2 * np.prod(node.module_attributes.kernel_size) * \
                       num_in_channels * num_out_channels * np.prod(output_shapes[name])

     for node in graph.get_nodes_by_metatypes(linear_op_metatypes):
-        name = node.ia_op_exec_context.scope_in_model if hasattr(node, 'ia_op_exec_context') \
-            else get_original_node_name(node.node_name)
+        name = node.node_name
        flops[name] = 2 * np.prod(input_shapes[name]) * np.prod(output_shapes[name])

     return flops
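A quick arithmetic check of the two formulas above, with illustrative sizes. Note that since the `if name in flops: continue` dedup is removed and entries are now keyed per expanded node_name, a conv applied twice contributes two such entries, which is the fix this commit's title refers to:

    import numpy as np

    # Conv: 3x3 kernel, 64 -> 128 channels, 56x56 output feature map.
    flops_conv = 2 * np.prod((3, 3)) * 64 * 128 * np.prod((56, 56))
    print(flops_conv)    # 462422016 = 2 * 9 * 64 * 128 * 3136

    # Linear: 2048 -> 1000 features.
    flops_linear = 2 * np.prod((2048,)) * np.prod((1000,))
    print(flops_linear)  # 4096000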


 def calculate_in_out_channels_in_uniformly_pruned_model(pruning_groups, pruning_rate: float,
-                                                        full_input_channels: Dict[Union[str, 'Scope'], int],
-                                                        full_output_channels: Dict[Union[str, 'Scope'], int],
+                                                        full_input_channels: Dict[str, int],
+                                                        full_output_channels: Dict[str, int],
                                                         pruning_groups_next_nodes: Dict[int, List[str]]):
     """
     Imitates filters pruning by removing `pruning_rate` percent of output filters in each pruning group
@@ -311,16 +298,16 @@ def calculate_in_out_channels_in_uniformly_pruned_model(pruning_groups, pruning_rate: float,
     tmp_out_channels = full_output_channels.copy()

     for group in pruning_groups:
-        layer_name = group.nodes[0].key
-        assert all(tmp_out_channels[layer_name] == tmp_out_channels[node.key] for node in
+        layer_name = group.nodes[0].node_name
+        assert all(tmp_out_channels[layer_name] == tmp_out_channels[node.node_name] for node in
                    group.nodes)
         # Prune all nodes in cluster (by output channels)
         old_out_channels = full_output_channels[layer_name]
         num_of_sparse_elems = get_rounded_pruned_element_number(old_out_channels, pruning_rate)
         new_out_channels_num = old_out_channels - num_of_sparse_elems

         for node in group.nodes:
-            tmp_out_channels[node.key] = new_out_channels_num
+            tmp_out_channels[node.node_name] = new_out_channels_num

         # Prune in_channels in all next nodes of cluster
         for node_name in pruning_groups_next_nodes[group.id]:
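A rough sketch of the bookkeeping this function performs, assuming for simplicity that get_rounded_pruned_element_number just rounds channels * pruning_rate (the real helper's rounding may differ):

    # One pruning group (id 0): two shared-conv nodes with 64 output channels.
    full_output_channels = {'conv2d^0': 64, 'conv2d^1': 64}
    full_input_channels = {'conv_next': 64}
    pruning_groups_next_nodes = {0: ['conv_next']}

    pruning_rate = 0.5
    num_of_sparse_elems = round(64 * pruning_rate)  # assumed rounding; = 32

    tmp_out = {k: v - num_of_sparse_elems for k, v in full_output_channels.items()}
    tmp_in = {k: v - num_of_sparse_elems for k, v in full_input_channels.items()}
    print(tmp_out, tmp_in)
    # {'conv2d^0': 32, 'conv2d^1': 32} {'conv_next': 32}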
