Weights and flops calculation (openvinotoolkit#738)
* add weights counter. Format statistics printing. Fix flops number update

Co-authored-by: Andrey Churkin <andrey.churkin@intel.com>
evgeniya-egupova and andrey-churkin committed Jun 17, 2021
1 parent 91df215 commit 631cebe
Showing 5 changed files with 192 additions and 93 deletions.
13 changes: 12 additions & 1 deletion nncf/common/pruning/statistics.py
@@ -101,28 +101,39 @@ def __init__(self,
model_statistics: PrunedModelStatistics,
full_flops: int,
current_flops: int,
full_params_num: int,
current_params_num: int,
target_pruning_level: float):
"""
Initializes statistics of the filter pruning algorithm.
:param model_statistics: Statistics of the pruned model.
:param full_flops: Full FLOPS.
:param current_flops: Current FLOPS.
:param full_params_num: Full number of weights.
:param current_params_num: Current number of weights.
:param target_pruning_level: A target level of the pruning
for the algorithm for the current epoch.
"""
self._giga = 1e9
self._mega = 1e6
self.model_statistics = model_statistics
self.full_flops = full_flops
self.current_flops = current_flops
self.flops_pruning_level = 1 - self.current_flops / self.full_flops
self.full_params_num = full_params_num
self.current_params_num = current_params_num
self.target_pruning_level = target_pruning_level

def to_str(self) -> str:
algorithm_string = create_table(
header=['Statistic\'s name', 'Value'],
rows=[
['FLOPS pruning level', self.flops_pruning_level],
['FLOPS current / full', f'{self.current_flops} / {self.full_flops}'],
['GFLOPS current / full', f'{self.current_flops / self._giga:.3f} /'
f' {self.full_flops / self._giga:.3f}'],
['MParams current / full', f'{self.current_params_num / self._mega:.3f} /'
f' {self.full_params_num / self._mega:.3f}'],
['A target level of the pruning for the algorithm for the current epoch', self.target_pruning_level],
]
)
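For context, here is a minimal illustrative sketch (not part of this commit) of how the reported values relate to the raw counters; the numbers are made up and plain printing stands in for NNCF's create_table helper:

# Hypothetical counters; real values come from count_flops_and_weights().
full_flops, current_flops = 8_200_000_000, 4_900_000_000
full_params_num, current_params_num = 25_600_000, 14_300_000

flops_pruning_level = 1 - current_flops / full_flops
print(f'FLOPS pruning level: {flops_pruning_level:.3f}')
print(f'GFLOPS current / full: {current_flops / 1e9:.3f} / {full_flops / 1e9:.3f}')
print(f'MParams current / full: {current_params_num / 1e6:.3f} / {full_params_num / 1e6:.3f}')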
70 changes: 55 additions & 15 deletions nncf/common/pruning/utils.py
@@ -242,44 +242,84 @@ def get_cluster_next_nodes(graph: NNCFGraph, pruned_groups_info: Clusterization[
return next_nodes


def count_flops_for_nodes(graph: NNCFGraph,
input_shapes: Dict[NNCFNodeName, List[int]],
output_shapes: Dict[NNCFNodeName, List[int]],
conv_op_metatypes: List[Type[OperatorMetatype]],
linear_op_metatypes: List[Type[OperatorMetatype]],
input_channels: Dict[NNCFNodeName, int] = None,
output_channels: Dict[NNCFNodeName, int] = None) -> Dict[NNCFNodeName, int]:
def count_flops_and_weights(graph: NNCFGraph,
input_shapes: Dict[NNCFNodeName, List[int]],
output_shapes: Dict[NNCFNodeName, List[int]],
conv_op_metatypes: List[Type[OperatorMetatype]],
linear_op_metatypes: List[Type[OperatorMetatype]],
input_channels: Dict[NNCFNodeName, int] = None,
output_channels: Dict[NNCFNodeName, int] = None) -> Tuple[int, int]:
"""
Counts the number of FLOPs in the model for convolution and fully connected layers.
Counts the number of weights and FLOPs in the model for convolution and fully connected layers.
:param graph: NNCFGraph.
:param input_shapes: Dictionary of input dimension shapes for convolutions and
fully connected layers. E.g {node_name: (height, width)}
fully connected layers. E.g {node_name: (height, width)}
:param output_shapes: Dictionary of output dimension shapes for convolutions and
fully connected layers. E.g {node_name: (height, width)}
fully connected layers. E.g {node_name: (height, width)}
:param conv_op_metatypes: List of metatypes defining convolution operations.
:param linear_op_metatypes: List of metatypes defining linear/fully connected operations.
:param input_channels: Dictionary of input channels number in convolutions.
If not specified, taken from the graph. {node_name: channels_num}
If not specified, taken from the graph. {node_name: channels_num}
:param output_channels: Dictionary of output channels number in convolutions.
If not specified, taken from the graph. {node_name: channels_num}
If not specified, taken from the graph. {node_name: channels_num}
:return: Number of FLOPs for the model,
number of weights (params) in the model.
"""
flops_per_node, weights_per_node = count_flops_and_weights_per_node(graph,
input_shapes, output_shapes,
conv_op_metatypes, linear_op_metatypes,
input_channels, output_channels)
return sum(flops_per_node.values()), sum(weights_per_node.values())


def count_flops_and_weights_per_node(graph: NNCFGraph,
input_shapes: Dict[NNCFNodeName, List[int]],
output_shapes: Dict[NNCFNodeName, List[int]],
conv_op_metatypes: List[Type[OperatorMetatype]],
linear_op_metatypes: List[Type[OperatorMetatype]],
input_channels: Dict[NNCFNodeName, int] = None,
output_channels: Dict[NNCFNodeName, int] = None) -> \
Tuple[Dict[NNCFNodeName, int], Dict[NNCFNodeName, int]]:
"""
Counts the number of weights and FLOPs per node in the model for convolution and fully connected layers.
:param graph: NNCFGraph.
:param input_shapes: Dictionary of input dimension shapes for convolutions and
fully connected layers. E.g {node_name: (height, width)}
:param output_shapes: Dictionary of output dimension shapes for convolutions and
fully connected layers. E.g {node_name: (height, width)}
:param conv_op_metatypes: List of metatypes defining convolution operations.
:param linear_op_metatypes: List of metatypes defining linear/fully connected operations.
:param input_channels: Dictionary of input channels number in convolutions.
If not specified, taken from the graph. {node_name: channels_num}
:param output_channels: Dictionary of output channels number in convolutions.
If not specified, taken from the graph. {node_name: channels_num}
:return: Dictionary of the FLOPs number per node {node_name: flops_num},
dictionary of the weights number per node {node_name: weights_num}.
"""
flops = {}
weights = {}
input_channels = input_channels or {}
output_channels = output_channels or {}
for node in graph.get_nodes_by_metatypes(conv_op_metatypes):
name = node.node_name
num_in_channels = input_channels.get(name, node.layer_attributes.in_channels)
num_out_channels = output_channels.get(name, node.layer_attributes.out_channels)
flops[name] = 2 * np.prod(node.layer_attributes.kernel_size) * \
flops_numpy = 2 * np.prod(node.layer_attributes.kernel_size) * \
num_in_channels * num_out_channels * np.prod(output_shapes[name])
weights_numpy = np.prod(node.layer_attributes.kernel_size) * num_in_channels * num_out_channels
flops[name] = flops_numpy.astype(int).item()
weights[name] = weights_numpy.astype(int).item()

for node in graph.get_nodes_by_metatypes(linear_op_metatypes):
name = node.node_name
flops[name] = 2 * np.prod(input_shapes[name]) * np.prod(output_shapes[name])
flops_numpy = 2 * np.prod(input_shapes[name]) * np.prod(output_shapes[name])
weights_numpy = np.prod(input_shapes[name]) * np.prod(output_shapes[name])
flops[name] = flops_numpy.astype(int).item()
weights[name] = weights_numpy.astype(int).item()

return flops
return flops, weights
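As an aside (not from the diff), a hedged sketch of the per-layer formulas above applied to hypothetical layers; the channel counts and shapes are made-up values:

import numpy as np

# Hypothetical 3x3 convolution: 64 -> 128 channels, 28x28 output feature map.
kernel_size = (3, 3)
num_in_channels, num_out_channels = 64, 128
conv_output_shape = (28, 28)  # as stored in output_shapes[node_name]

conv_flops = 2 * np.prod(kernel_size) * num_in_channels * num_out_channels * np.prod(conv_output_shape)
conv_weights = np.prod(kernel_size) * num_in_channels * num_out_channels  # biases are not counted

# Hypothetical fully connected layer: 512 inputs -> 10 outputs.
fc_input_shape, fc_output_shape = (512,), (10,)
fc_flops = 2 * np.prod(fc_input_shape) * np.prod(fc_output_shape)
fc_weights = np.prod(fc_input_shape) * np.prod(fc_output_shape)

print(int(conv_flops), int(conv_weights))  # 115605504 73728
print(int(fc_flops), int(fc_weights))      # 10240 5120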


def calculate_in_out_channels_in_uniformly_pruned_model(pruning_groups: List[Cluster[PrunedLayerInfoBase]],
92 changes: 66 additions & 26 deletions nncf/tensorflow/pruning/filter_pruning/algorithm.py
@@ -27,7 +27,8 @@
from nncf.common.pruning.schedulers import PRUNING_SCHEDULERS
from nncf.common.pruning.statistics import FilterPruningStatistics
from nncf.common.pruning.utils import calculate_in_out_channels_in_uniformly_pruned_model
from nncf.common.pruning.utils import count_flops_for_nodes
from nncf.common.pruning.utils import count_flops_and_weights
from nncf.common.pruning.utils import count_flops_and_weights_per_node
from nncf.common.pruning.utils import get_cluster_next_nodes
from nncf.common.pruning.utils import get_conv_in_out_channels
from nncf.common.pruning.utils import get_rounded_pruned_element_number
@@ -110,6 +111,7 @@ def __init__(self,
self.pruning_quota = 0.9

self._nodes_flops = {} # type: Dict[NNCFNodeName, int]
self._nodes_params_num = {} # type: Dict[NNCFNodeName, int]
self._layers_in_channels = {}
self._layers_out_channels = {}
self._layers_in_shapes = {}
@@ -120,6 +122,8 @@ def __init__(self,
self._flops_count_init()
self.full_flops = sum(self._nodes_flops.values())
self.current_flops = self.full_flops
self.full_params_num = sum(self._nodes_params_num.values())
self.current_params_num = self.full_params_num

self._weights_normalizer = tensor_l2_normalizer # for all weights in common case
self._filter_importance = FILTER_IMPORTANCE_FUNCTIONS.get(params.get('weight_importance', 'L2'))
@@ -140,10 +144,11 @@ def loss(self) -> CompressionLoss:

def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
model_statistics = self._calculate_pruned_model_stats()

self._update_benchmark_statistics()
target_pruning_level = self.scheduler.current_pruning_level

stats = FilterPruningStatistics(model_statistics, self.full_flops, self.current_flops, target_pruning_level)
stats = FilterPruningStatistics(model_statistics, self.full_flops, self.current_flops,
self.full_params_num, self.current_params_num, target_pruning_level)

nncf_stats = NNCFStatistics()
nncf_stats.register('filter_pruning', stats)
@@ -225,11 +230,12 @@ def _flops_count_init(self):
self._layers_in_shapes[node.node_name] = in_shape
self._layers_out_shapes[node.node_name] = out_shape

self._nodes_flops = count_flops_for_nodes(self._original_graph,
self._layers_in_shapes,
self._layers_out_shapes,
conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
linear_op_metatypes=LINEAR_LAYER_METATYPES)
self._nodes_flops, self._nodes_params_num = \
count_flops_and_weights_per_node(self._original_graph,
self._layers_in_shapes,
self._layers_out_shapes,
conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
linear_op_metatypes=LINEAR_LAYER_METATYPES)

def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_rate: float):
nncf_logger.debug('Setting new binary masks for pruned layers.')
@@ -268,6 +274,9 @@ def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_rate: float):
if nncf_node.data['output_mask'] is not None:
self._set_operation_masks([layer], nncf_node.data['output_mask'])

# Calculate the actual number of FLOPs and weights with the new masks
self._update_benchmark_statistics()

def _set_binary_masks_for_pruned_layers_globally(self, pruning_rate: float):
"""
Sets the binary mask values for layer groups according to the global pruning rate.
@@ -312,6 +321,9 @@ def _set_binary_masks_for_pruned_layers_globally(self, pruning_rate: float):
if nncf_node.data['output_mask'] is not None:
self._set_operation_masks([layer], nncf_node.data['output_mask'])

# Calculate the actual number of FLOPs with the new masks
self._update_benchmark_statistics()

def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self,
target_flops_pruning_rate: float):
"""
@@ -359,13 +371,13 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self,
for node_name in self._next_nodes[group_id]:
tmp_in_channels[node_name] -= 1

flops = sum(count_flops_for_nodes(self._original_graph,
self._layers_in_shapes,
self._layers_out_shapes,
input_channels=tmp_in_channels,
output_channels=tmp_out_channels,
conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
linear_op_metatypes=LINEAR_LAYER_METATYPES).values())
flops, params_num = count_flops_and_weights(self._original_graph,
self._layers_in_shapes,
self._layers_out_shapes,
input_channels=tmp_in_channels,
output_channels=tmp_out_channels,
conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
linear_op_metatypes=LINEAR_LAYER_METATYPES)
if flops <= target_flops:
# 3. Add masks to the graph and propagate them
for group in self._pruned_layer_groups_info.get_all_clusters():
@@ -378,6 +390,7 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self,

# 4. Set binary masks to the model
self.current_flops = flops
self.current_params_num = params_num
nncf_sorted_nodes = self._original_graph.topological_sort()
for layer in wrapped_layers:
nncf_node = [n for n in nncf_sorted_nodes
@@ -404,34 +417,34 @@ def _find_uniform_pruning_rate_for_target_flops(self, target_flops_pruning_rate)
left, right = 0.0, 1.0
while abs(right - left) > error:
middle = (left + right) / 2
flops = self._calculate_flops_in_uniformly_pruned_model(middle)
flops, params_num = self._calculate_flops_and_weights_in_uniformly_pruned_model(middle)
if flops < target_flops:
right = middle
else:
left = middle
flops = self._calculate_flops_in_uniformly_pruned_model(right)
flops, params_num = self._calculate_flops_and_weights_in_uniformly_pruned_model(right)
if flops < target_flops:
self.current_flops = flops
self.current_params_num = params_num
return right
raise RuntimeError(f'Unable to prune the model to get the required '
f'pruning rate in flops = {target_flops_pruning_rate}')
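For illustration only (not part of the commit), a self-contained sketch of the bisection used above: the pruning level is searched until the estimated FLOPs drop below the target; flops_at is a hypothetical stand-in for _calculate_flops_and_weights_in_uniformly_pruned_model:

# Bisection over the pruning level; assumes FLOPs decrease monotonically as the level grows.
def find_pruning_level(flops_at, target_flops, error=0.001):
    left, right = 0.0, 1.0
    while abs(right - left) > error:
        middle = (left + right) / 2
        if flops_at(middle) < target_flops:
            right = middle
        else:
            left = middle
    if flops_at(right) < target_flops:
        return right
    raise RuntimeError('Unable to reach the target FLOPs')

# Toy model: uniformly pruning a fraction p of channels scales conv FLOPs roughly by (1 - p) ** 2.
full_flops = 1_000_000
print(find_pruning_level(lambda p: full_flops * (1 - p) ** 2, target_flops=400_000))  # ~0.368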

def _calculate_flops_in_uniformly_pruned_model(self, pruning_rate):
def _calculate_flops_and_weights_in_uniformly_pruned_model(self, pruning_rate):
tmp_in_channels, tmp_out_channels = \
calculate_in_out_channels_in_uniformly_pruned_model(
pruning_groups=self._pruned_layer_groups_info.get_all_clusters(),
pruning_rate=pruning_rate,
full_input_channels=self._layers_in_channels,
full_output_channels=self._layers_out_channels,
pruning_groups_next_nodes=self._next_nodes)
flops = sum(count_flops_for_nodes(self._original_graph,
self._layers_in_shapes,
self._layers_out_shapes,
input_channels=tmp_in_channels,
output_channels=tmp_out_channels,
conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
linear_op_metatypes=LINEAR_LAYER_METATYPES).values())
return flops
return count_flops_and_weights(self._original_graph,
self._layers_in_shapes,
self._layers_out_shapes,
input_channels=tmp_in_channels,
output_channels=tmp_out_channels,
conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
linear_op_metatypes=LINEAR_LAYER_METATYPES)

def _calculate_filters_importance_in_group(self, group: Cluster[PrunedLayerInfo]):
"""
@@ -459,6 +472,33 @@ def _calculate_filters_importance_in_group(self, group: Cluster[PrunedLayerInfo]

return cumulative_filters_importance

def _calculate_flops_and_weights_pruned_model_by_masks(self):
tmp_in_channels = self._layers_in_channels.copy()
tmp_out_channels = self._layers_out_channels.copy()

for group in self._pruned_layer_groups_info.get_all_clusters():
assert all(tmp_out_channels[group.elements[0].node_name] == tmp_out_channels[node.node_name] for node in
group.elements)
mask = self._original_graph.get_node_by_id(group.elements[0].nncf_node_id).data['output_mask']
new_out_channels_num = int(sum(mask))
num_of_sparse_elems = len(mask) - new_out_channels_num
for node in group.elements:
tmp_out_channels[node.node_name] = new_out_channels_num
# Prune in_channels in all next nodes of cluster
for node_name in self._next_nodes[group.id]:
tmp_in_channels[node_name] -= num_of_sparse_elems

return count_flops_and_weights(self._original_graph,
self._layers_in_shapes,
self._layers_out_shapes,
input_channels=tmp_in_channels,
output_channels=tmp_out_channels,
conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
linear_op_metatypes=LINEAR_LAYER_METATYPES)
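A small hedged sketch (illustrative, not from the commit) of the mask bookkeeping performed above: the new output channel count of a pruning group is the number of ones in its binary mask, and every next node of the cluster loses the same number of input channels; the node names and mask below are hypothetical:

import numpy as np

output_mask = np.array([1, 0, 1, 1, 0, 0, 1, 1])  # 1 keeps a filter, 0 prunes it

new_out_channels_num = int(sum(output_mask))                    # 5 filters kept
num_of_sparse_elems = len(output_mask) - new_out_channels_num   # 3 filters pruned

tmp_out_channels = {'conv_1': 8}  # layer owning the mask
tmp_in_channels = {'conv_2': 8}   # next node consuming its output

tmp_out_channels['conv_1'] = new_out_channels_num
tmp_in_channels['conv_2'] -= num_of_sparse_elems

print(tmp_out_channels, tmp_in_channels)  # {'conv_1': 5} {'conv_2': 5}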

def _update_benchmark_statistics(self):
self.current_flops, self.current_params_num = self._calculate_flops_and_weights_pruned_model_by_masks()

def _layer_filter_importance(self, layer: NNCFWrapper):
layer_metatype = get_keras_layer_metatype(layer)
if len(layer_metatype.weight_definitions) != 1: