diff --git a/nncf/common/pruning/statistics.py b/nncf/common/pruning/statistics.py index 41671eb1cd4..2ec6a9973b0 100644 --- a/nncf/common/pruning/statistics.py +++ b/nncf/common/pruning/statistics.py @@ -101,6 +101,8 @@ def __init__(self, model_statistics: PrunedModelStatistics, full_flops: int, current_flops: int, + full_params_num: int, + current_params_num: int, target_pruning_level: float): """ Initializes statistics of the filter pruning algorithm. @@ -108,13 +110,19 @@ def __init__(self, :param model_statistics: Statistics of the pruned model. :param full_flops: Full FLOPS. :param current_flops: Current FLOPS. + :param full_params_num: Full number of weights. + :param current_params_num: Current number of weights. :param target_pruning_level: A target level of the pruning for the algorithm for the current epoch. """ + self._giga = 1e9 + self._mega = 1e6 self.model_statistics = model_statistics self.full_flops = full_flops self.current_flops = current_flops self.flops_pruning_level = 1 - self.current_flops / self.full_flops + self.full_params_num = full_params_num + self.current_params_num = current_params_num self.target_pruning_level = target_pruning_level def to_str(self) -> str: @@ -122,7 +130,10 @@ def to_str(self) -> str: header=['Statistic\'s name', 'Value'], rows=[ ['FLOPS pruning level', self.flops_pruning_level], - ['FLOPS current / full', f'{self.current_flops} / {self.full_flops}'], + ['GFLOPS current / full', f'{self.current_flops / self._giga:.3f} /' + f' {self.full_flops / self._giga:.3f}'], + ['MParams current / full', f'{self.current_params_num / self._mega:.3f} /' + f' {self.full_params_num / self._mega:.3f}'], ['A target level of the pruning for the algorithm for the current epoch', self.target_pruning_level], ] ) diff --git a/nncf/common/pruning/utils.py b/nncf/common/pruning/utils.py index 35d97409264..2fb5cfef70a 100644 --- a/nncf/common/pruning/utils.py +++ b/nncf/common/pruning/utils.py @@ -242,44 +242,84 @@ def 
get_cluster_next_nodes(graph: NNCFGraph, pruned_groups_info: Clusterization[ return next_nodes -def count_flops_for_nodes(graph: NNCFGraph, - input_shapes: Dict[NNCFNodeName, List[int]], - output_shapes: Dict[NNCFNodeName, List[int]], - conv_op_metatypes: List[Type[OperatorMetatype]], - linear_op_metatypes: List[Type[OperatorMetatype]], - input_channels: Dict[NNCFNodeName, int] = None, - output_channels: Dict[NNCFNodeName, int] = None) -> Dict[NNCFNodeName, int]: +def count_flops_and_weights(graph: NNCFGraph, + input_shapes: Dict[NNCFNodeName, List[int]], + output_shapes: Dict[NNCFNodeName, List[int]], + conv_op_metatypes: List[Type[OperatorMetatype]], + linear_op_metatypes: List[Type[OperatorMetatype]], + input_channels: Dict[NNCFNodeName, int] = None, + output_channels: Dict[NNCFNodeName, int] = None) -> Tuple[int, int]: """ - Counts the number FLOPs in the model for convolution and fully connected layers. + Counts the number of weights and FLOPs in the model for convolution and fully connected layers. :param graph: NNCFGraph. :param input_shapes: Dictionary of input dimension shapes for convolutions and - fully connected layers. E.g {node_name: (height, width)} + fully connected layers. E.g {node_name: (height, width)} :param output_shapes: Dictionary of output dimension shapes for convolutions and - fully connected layers. E.g {node_name: (height, width)} + fully connected layers. E.g {node_name: (height, width)} :param conv_op_metatypes: List of metatypes defining convolution operations. :param linear_op_metatypes: List of metatypes defining linear/fully connected operations. :param input_channels: Dictionary of input channels number in convolutions. - If not specified, taken from the graph. {node_name: channels_num} + If not specified, taken from the graph. {node_name: channels_num} :param output_channels: Dictionary of output channels number in convolutions. - If not specified, taken from the graph. 
{node_name: channels_num} + If not specified, taken from the graph. {node_name: channels_num} + :return number of FLOPs for the model + number of weights (params) in the model + """ + flops_pers_node, weights_per_node = count_flops_and_weights_per_node(graph, + input_shapes, output_shapes, + conv_op_metatypes, linear_op_metatypes, + input_channels, output_channels) + return sum(flops_pers_node.values()), sum(weights_per_node.values()) + + +def count_flops_and_weights_per_node(graph: NNCFGraph, + input_shapes: Dict[NNCFNodeName, List[int]], + output_shapes: Dict[NNCFNodeName, List[int]], + conv_op_metatypes: List[Type[OperatorMetatype]], + linear_op_metatypes: List[Type[OperatorMetatype]], + input_channels: Dict[NNCFNodeName, int] = None, + output_channels: Dict[NNCFNodeName, int] = None) -> \ + Tuple[Dict[NNCFNodeName, int], Dict[NNCFNodeName, int]]: + """ + Counts the number of weights and FLOPs per node in the model for convolution and fully connected layers. + + :param graph: NNCFGraph. + :param input_shapes: Dictionary of input dimension shapes for convolutions and + fully connected layers. E.g {node_name: (height, width)} + :param output_shapes: Dictionary of output dimension shapes for convolutions and + fully connected layers. E.g {node_name: (height, width)} + :param conv_op_metatypes: List of metatypes defining convolution operations. + :param linear_op_metatypes: List of metatypes defining linear/fully connected operations. + :param input_channels: Dictionary of input channels number in convolutions. + If not specified, taken from the graph. {node_name: channels_num} + :param output_channels: Dictionary of output channels number in convolutions. + If not specified, taken from the graph. 
{node_name: channels_num} :return Dictionary of FLOPs number {node_name: flops_num} + Dictionary of weights number {node_name: weights_num} """ flops = {} + weights = {} input_channels = input_channels or {} output_channels = output_channels or {} for node in graph.get_nodes_by_metatypes(conv_op_metatypes): name = node.node_name num_in_channels = input_channels.get(name, node.layer_attributes.in_channels) num_out_channels = output_channels.get(name, node.layer_attributes.out_channels) - flops[name] = 2 * np.prod(node.layer_attributes.kernel_size) * \ + flops_numpy = 2 * np.prod(node.layer_attributes.kernel_size) * \ num_in_channels * num_out_channels * np.prod(output_shapes[name]) + weights_numpy = np.prod(node.layer_attributes.kernel_size) * num_in_channels * num_out_channels + flops[name] = flops_numpy.astype(int).item() + weights[name] = weights_numpy.astype(int).item() for node in graph.get_nodes_by_metatypes(linear_op_metatypes): name = node.node_name - flops[name] = 2 * np.prod(input_shapes[name]) * np.prod(output_shapes[name]) + flops_numpy = 2 * np.prod(input_shapes[name]) * np.prod(output_shapes[name]) + weights_numpy = np.prod(input_shapes[name]) * np.prod(output_shapes[name]) + flops[name] = flops_numpy.astype(int).item() + weights[name] = weights_numpy.astype(int).item() - return flops + return flops, weights def calculate_in_out_channels_in_uniformly_pruned_model(pruning_groups: List[Cluster[PrunedLayerInfoBase]], diff --git a/nncf/tensorflow/pruning/filter_pruning/algorithm.py b/nncf/tensorflow/pruning/filter_pruning/algorithm.py index ad3fbe32fd2..a157dbe0d39 100644 --- a/nncf/tensorflow/pruning/filter_pruning/algorithm.py +++ b/nncf/tensorflow/pruning/filter_pruning/algorithm.py @@ -27,7 +27,8 @@ from nncf.common.pruning.schedulers import PRUNING_SCHEDULERS from nncf.common.pruning.statistics import FilterPruningStatistics from nncf.common.pruning.utils import calculate_in_out_channels_in_uniformly_pruned_model -from nncf.common.pruning.utils import 
count_flops_for_nodes +from nncf.common.pruning.utils import count_flops_and_weights +from nncf.common.pruning.utils import count_flops_and_weights_per_node from nncf.common.pruning.utils import get_cluster_next_nodes from nncf.common.pruning.utils import get_conv_in_out_channels from nncf.common.pruning.utils import get_rounded_pruned_element_number @@ -110,6 +111,7 @@ def __init__(self, self.pruning_quota = 0.9 self._nodes_flops = {} # type: Dict[NNCFNodeName, int] + self._nodes_params_num = {} # type: Dict[NNCFNodeName, int] self._layers_in_channels = {} self._layers_out_channels = {} self._layers_in_shapes = {} @@ -120,6 +122,8 @@ def __init__(self, self._flops_count_init() self.full_flops = sum(self._nodes_flops.values()) self.current_flops = self.full_flops + self.full_params_num = sum(self._nodes_params_num.values()) + self.current_params_num = self.full_params_num self._weights_normalizer = tensor_l2_normalizer # for all weights in common case self._filter_importance = FILTER_IMPORTANCE_FUNCTIONS.get(params.get('weight_importance', 'L2')) @@ -140,10 +144,11 @@ def loss(self) -> CompressionLoss: def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics: model_statistics = self._calculate_pruned_model_stats() - + self._update_benchmark_statistics() target_pruning_level = self.scheduler.current_pruning_level - stats = FilterPruningStatistics(model_statistics, self.full_flops, self.current_flops, target_pruning_level) + stats = FilterPruningStatistics(model_statistics, self.full_flops, self.current_flops, + self.full_params_num, self.current_params_num, target_pruning_level) nncf_stats = NNCFStatistics() nncf_stats.register('filter_pruning', stats) @@ -225,11 +230,12 @@ def _flops_count_init(self): self._layers_in_shapes[node.node_name] = in_shape self._layers_out_shapes[node.node_name] = out_shape - self._nodes_flops = count_flops_for_nodes(self._original_graph, - self._layers_in_shapes, - self._layers_out_shapes, - 
conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, - linear_op_metatypes=LINEAR_LAYER_METATYPES) + self._nodes_flops, self._nodes_params_num = \ + count_flops_and_weights_per_node(self._original_graph, + self._layers_in_shapes, + self._layers_out_shapes, + conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, + linear_op_metatypes=LINEAR_LAYER_METATYPES) def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_rate: float): nncf_logger.debug('Setting new binary masks for pruned layers.') @@ -268,6 +274,9 @@ def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_rate: float): if nncf_node.data['output_mask'] is not None: self._set_operation_masks([layer], nncf_node.data['output_mask']) + # Calculate actual flops and weights number with new masks + self._update_benchmark_statistics() + def _set_binary_masks_for_pruned_layers_globally(self, pruning_rate: float): """ Sets the binary mask values for layer groups according to the global pruning rate. @@ -312,6 +321,9 @@ def _set_binary_masks_for_pruned_layers_globally(self, pruning_rate: float): if nncf_node.data['output_mask'] is not None: self._set_operation_masks([layer], nncf_node.data['output_mask']) + # Calculate actual flops and weights number with new masks + self._update_benchmark_statistics() + def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self, target_flops_pruning_rate: float): """ @@ -359,13 +371,13 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self, for node_name in self._next_nodes[group_id]: tmp_in_channels[node_name] -= 1 - flops = sum(count_flops_for_nodes(self._original_graph, - self._layers_in_shapes, - self._layers_out_shapes, - input_channels=tmp_in_channels, - output_channels=tmp_out_channels, - conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, - linear_op_metatypes=LINEAR_LAYER_METATYPES).values()) + flops, params_num = count_flops_and_weights(self._original_graph, + self._layers_in_shapes, + self._layers_out_shapes, + input_channels=tmp_in_channels, + 
output_channels=tmp_out_channels, + conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, + linear_op_metatypes=LINEAR_LAYER_METATYPES) if flops <= target_flops: # 3. Add masks to the graph and propagate them for group in self._pruned_layer_groups_info.get_all_clusters(): @@ -378,6 +390,7 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self, # 4. Set binary masks to the model self.current_flops = flops + self.current_params_num = params_num nncf_sorted_nodes = self._original_graph.topological_sort() for layer in wrapped_layers: nncf_node = [n for n in nncf_sorted_nodes @@ -404,19 +417,20 @@ def _find_uniform_pruning_rate_for_target_flops(self, target_flops_pruning_rate) left, right = 0.0, 1.0 while abs(right - left) > error: middle = (left + right) / 2 - flops = self._calculate_flops_in_uniformly_pruned_model(middle) + flops, params_num = self._calculate_flops_and_weights_in_uniformly_pruned_model(middle) if flops < target_flops: right = middle else: left = middle - flops = self._calculate_flops_in_uniformly_pruned_model(right) + flops, params_num = self._calculate_flops_and_weights_in_uniformly_pruned_model(right) if flops < target_flops: self.current_flops = flops + self.current_params_num = params_num return right raise RuntimeError(f'Unable to prune the model to get the required ' f'pruning rate in flops = {target_flops_pruning_rate}') - def _calculate_flops_in_uniformly_pruned_model(self, pruning_rate): + def _calculate_flops_and_weights_in_uniformly_pruned_model(self, pruning_rate): tmp_in_channels, tmp_out_channels = \ calculate_in_out_channels_in_uniformly_pruned_model( pruning_groups=self._pruned_layer_groups_info.get_all_clusters(), @@ -424,14 +438,13 @@ def _calculate_flops_in_uniformly_pruned_model(self, pruning_rate): full_input_channels=self._layers_in_channels, full_output_channels=self._layers_out_channels, pruning_groups_next_nodes=self._next_nodes) - flops = sum(count_flops_for_nodes(self._original_graph, - self._layers_in_shapes, - 
self._layers_out_shapes, - input_channels=tmp_in_channels, - output_channels=tmp_out_channels, - conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, - linear_op_metatypes=LINEAR_LAYER_METATYPES).values()) - return flops + return count_flops_and_weights(self._original_graph, + self._layers_in_shapes, + self._layers_out_shapes, + input_channels=tmp_in_channels, + output_channels=tmp_out_channels, + conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, + linear_op_metatypes=LINEAR_LAYER_METATYPES) def _calculate_filters_importance_in_group(self, group: Cluster[PrunedLayerInfo]): """ @@ -459,6 +472,33 @@ def _calculate_filters_importance_in_group(self, group: Cluster[PrunedLayerInfo] return cumulative_filters_importance + def _calculate_flops_and_weights_pruned_model_by_masks(self): + tmp_in_channels = self._layers_in_channels.copy() + tmp_out_channels = self._layers_out_channels.copy() + + for group in self._pruned_layer_groups_info.get_all_clusters(): + assert all(tmp_out_channels[group.elements[0].node_name] == tmp_out_channels[node.node_name] for node in + group.elements) + mask = self._original_graph.get_node_by_id(group.elements[0].nncf_node_id).data['output_mask'] + new_out_channels_num = int(sum(mask)) + num_of_sparse_elems = len(mask) - new_out_channels_num + for node in group.elements: + tmp_out_channels[node.node_name] = new_out_channels_num + # Prune in_channels in all next nodes of cluster + for node_name in self._next_nodes[group.id]: + tmp_in_channels[node_name] -= num_of_sparse_elems + + return count_flops_and_weights(self._original_graph, + self._layers_in_shapes, + self._layers_out_shapes, + input_channels=tmp_in_channels, + output_channels=tmp_out_channels, + conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, + linear_op_metatypes=LINEAR_LAYER_METATYPES) + + def _update_benchmark_statistics(self): + self.current_flops, self.current_params_num = self._calculate_flops_and_weights_pruned_model_by_masks() + def _layer_filter_importance(self, layer: NNCFWrapper): 
layer_metatype = get_keras_layer_metatype(layer) if len(layer_metatype.weight_definitions) != 1: diff --git a/nncf/torch/pruning/filter_pruning/algo.py b/nncf/torch/pruning/filter_pruning/algo.py index 74f5688e987..70c00010c5c 100644 --- a/nncf/torch/pruning/filter_pruning/algo.py +++ b/nncf/torch/pruning/filter_pruning/algo.py @@ -11,9 +11,7 @@ limitations under the License. """ -from typing import Dict -from typing import List -from typing import Union +from typing import Dict, List, Tuple, Union import torch from texttable import Texttable @@ -41,7 +39,8 @@ from nncf.common.pruning.statistics import PrunedModelStatistics from nncf.common.pruning.statistics import FilterPruningStatistics from nncf.common.pruning.utils import calculate_in_out_channels_in_uniformly_pruned_model -from nncf.common.pruning.utils import count_flops_for_nodes +from nncf.common.pruning.utils import count_flops_and_weights +from nncf.common.pruning.utils import count_flops_and_weights_per_node from nncf.common.pruning.utils import get_cluster_next_nodes from nncf.common.pruning.utils import get_conv_in_out_channels from nncf.common.pruning.utils import get_rounded_pruned_element_number @@ -124,11 +123,14 @@ def __init__(self, target_model: NNCFNetwork, self._modules_out_shapes = {} # type: Dict[NNCFNodeName, List[int]] self.pruning_quotas = {} self.nodes_flops = {} # type: Dict[NNCFNodeName, int] + self.nodes_params_num = {} # type: Dict[NNCFNodeName, int] self.next_nodes = {} # type: Dict[int, List[NNCFNodeName]] self._init_pruned_modules_params() self.flops_count_init() self.full_flops = sum(self.nodes_flops.values()) self.current_flops = self.full_flops + self.full_params_num = sum(self.nodes_params_num.values()) + self.current_params_num = self.full_params_num self.weights_normalizer = tensor_l2_normalizer # for all weights in common case self.filter_importance = FILTER_IMPORTANCE_FUNCTIONS.get(params.get('weight_importance', 'L2')) @@ -165,10 +167,12 @@ def statistics(self, 
quickly_collected_only: bool = False) -> NNCFStatistics: self.pruning_rate_for_filters(minfo)) model_statistics = PrunedModelStatistics(self._pruning_rate, list(pruned_layers_summary.values())) - + self._update_benchmark_statistics() target_pruning_level = self.scheduler.current_pruning_level - stats = FilterPruningStatistics(model_statistics, self.full_flops, self.current_flops, target_pruning_level) + stats = FilterPruningStatistics(model_statistics, self.full_flops, self.current_flops, + self.full_params_num, self.current_params_num, target_pruning_level) + nncf_stats = NNCFStatistics() nncf_stats.register('filter_pruning', stats) @@ -217,13 +221,14 @@ def flops_count_init(self) -> None: else: self._modules_in_shapes[node.node_name] = in_shape[1:] - self.nodes_flops = count_flops_for_nodes(graph, self._modules_in_shapes, self._modules_out_shapes, - conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, - linear_op_metatypes=LINEAR_LAYER_METATYPES) + self.nodes_flops, self.nodes_params_num = \ + count_flops_and_weights_per_node(graph, self._modules_in_shapes, self._modules_out_shapes, + conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, + linear_op_metatypes=LINEAR_LAYER_METATYPES) - def _calculate_flops_pruned_model_by_masks(self) -> float: + def _calculate_flops_and_weights_pruned_model_by_masks(self) -> Tuple[int, int]: """ - Calculates number of flops for pruned model by using binary_filter_pruning_mask. + Calculates number of weights and flops for pruned model by using binary_filter_pruning_mask. 
:return: number of flops in model """ tmp_in_channels = self._modules_in_channels.copy() @@ -241,18 +246,18 @@ def _calculate_flops_pruned_model_by_masks(self) -> float: for node_name in next_nodes: tmp_in_channels[node_name] -= num_of_sparse_elems - flops = sum(count_flops_for_nodes(self._model.get_original_graph(), - self._modules_in_shapes, - self._modules_out_shapes, - input_channels=tmp_in_channels, - output_channels=tmp_out_channels, - conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, - linear_op_metatypes=LINEAR_LAYER_METATYPES).values()) - return flops + return count_flops_and_weights(self._model.get_original_graph(), + self._modules_in_shapes, + self._modules_out_shapes, + input_channels=tmp_in_channels, + output_channels=tmp_out_channels, + conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, + linear_op_metatypes=LINEAR_LAYER_METATYPES) - def _calculate_flops_in_uniformly_pruned_model(self, pruning_rate: float) -> float: + def _calculate_flops_and_weights_in_uniformly_pruned_model(self, pruning_rate: float) -> Tuple[int, int]: """ - Prune all prunable modules in model with pruning_rate rate and returns flops of pruned model. + Prune all prunable modules in model with pruning_rate rate and returns number of weights and + flops of the pruned model. 
:param pruning_rate: proportion of zero filters in all modules :return: flops number in pruned model """ @@ -264,14 +269,13 @@ def _calculate_flops_in_uniformly_pruned_model(self, pruning_rate: float) -> flo full_output_channels=self._modules_out_channels, pruning_groups_next_nodes=self.next_nodes) - flops = sum(count_flops_for_nodes(self._model.get_original_graph(), - self._modules_in_shapes, - self._modules_out_shapes, - input_channels=tmp_in_channels, - output_channels=tmp_out_channels, - conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, - linear_op_metatypes=LINEAR_LAYER_METATYPES).values()) - return flops + return count_flops_and_weights(self._model.get_original_graph(), + self._modules_in_shapes, + self._modules_out_shapes, + input_channels=tmp_in_channels, + output_channels=tmp_out_channels, + conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, + linear_op_metatypes=LINEAR_LAYER_METATYPES) def _find_uniform_pruning_rate_for_target_flops(self, target_flops_pruning_rate: float) -> float: """ @@ -285,14 +289,15 @@ def _find_uniform_pruning_rate_for_target_flops(self, target_flops_pruning_rate: left, right = 0.0, 1.0 while abs(right - left) > error: middle = (left + right) / 2 - flops = self._calculate_flops_in_uniformly_pruned_model(middle) + flops, params_num = self._calculate_flops_and_weights_in_uniformly_pruned_model(middle) if flops < target_flops: right = middle else: left = middle - flops = self._calculate_flops_in_uniformly_pruned_model(right) + flops, params_num = self._calculate_flops_and_weights_in_uniformly_pruned_model(right) if flops < target_flops: self.current_flops = flops + self.current_params_num = params_num return right raise RuntimeError("Can't prune the model to get the required " "pruning rate in flops = {}".format(target_flops_pruning_rate)) @@ -407,8 +412,8 @@ def _set_binary_masks_for_pruned_modules_groupwise(self, pruning_module = minfo.operand pruning_module.binary_filter_pruning_mask = mask - # Calculate actual flops with new masks - 
self.current_flops = self._calculate_flops_pruned_model_by_masks() + # Calculate actual flops and weights number with new masks + self._update_benchmark_statistics() def _set_binary_masks_for_pruned_modules_globally(self, pruning_rate: float) -> None: """ @@ -446,8 +451,8 @@ def _set_binary_masks_for_pruned_modules_globally(self, pruning_rate: float) -> pruning_module = minfo.operand pruning_module.binary_filter_pruning_mask = mask - # Calculate actual flops with new masks - self.current_flops = self._calculate_flops_pruned_model_by_masks() + # Calculate actual flops and weights number with new masks + self._update_benchmark_statistics() def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self, target_flops_pruning_rate: float) -> None: @@ -518,15 +523,16 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self, for node_id in next_nodes: tmp_in_channels[node_id] -= 1 - flops = sum(count_flops_for_nodes(self._model.get_original_graph(), - self._modules_in_shapes, - self._modules_out_shapes, - input_channels=tmp_in_channels, - output_channels=tmp_out_channels, - conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, - linear_op_metatypes=LINEAR_LAYER_METATYPES).values()) + flops, params_num = count_flops_and_weights(self._model.get_original_graph(), + self._modules_in_shapes, + self._modules_out_shapes, + input_channels=tmp_in_channels, + output_channels=tmp_out_channels, + conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES, + linear_op_metatypes=LINEAR_LAYER_METATYPES) if flops < target_flops: self.current_flops = flops + self.current_params_num = params_num return cur_num += 1 raise RuntimeError("Can't prune model to asked flops pruning rate") @@ -634,6 +640,9 @@ def compression_stage(self) -> CompressionStage: return CompressionStage.FULLY_COMPRESSED return CompressionStage.PARTIALLY_COMPRESSED + def _update_benchmark_statistics(self): + self.current_flops, self.current_params_num = self._calculate_flops_and_weights_pruned_model_by_masks() + 
def _run_batchnorm_adaptation(self): if self._bn_adaptation is None: self._bn_adaptation = BatchnormAdaptationAlgorithm(**extract_bn_adaptation_init_params(self.config)) diff --git a/tests/torch/pruning/filter_pruning/test_algo.py b/tests/torch/pruning/filter_pruning/test_algo.py index 87ec314a518..06d427e817a 100644 --- a/tests/torch/pruning/filter_pruning/test_algo.py +++ b/tests/torch/pruning/filter_pruning/test_algo.py @@ -313,23 +313,21 @@ def test_zeroing_gradients(zero_grad): assert torch.allclose(masked_grad, grad) -@pytest.mark.parametrize(('all_weights', 'pruning_flops_target', 'ref_flops'), +@pytest.mark.parametrize(('all_weights', 'pruning_flops_target', 'ref_flops', 'ref_params_num'), [ - (False, None, 1315008), - (True, None, 1492400), - (False, 0.5, 2367952), - (True, 0.5, 2380268), + (False, None, 1315008, 7776), + (True, None, 1492400, 9304), + (False, 0.5, 2367952, 13160), + (True, 0.5, 2380268, 13678), ] ) -def test_calculation_of_flops(all_weights, pruning_flops_target, ref_flops): +def test_calculation_of_flops(all_weights, pruning_flops_target, ref_flops, ref_params_num): """ Test for pruning masks check (_set_binary_masks_for_filters, _set_binary_masks_for_all_filters_together). 
:param all_weights: whether mask will be calculated for all weights in common or not :param pruning_flops_target: prune model by flops, if None then by number of channels :param ref_flops: reference size of model """ - - config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8]) config['compression']['params']['all_weights'] = all_weights config['compression']['pruning_init'] = 0.5 @@ -339,8 +337,9 @@ def test_calculation_of_flops(all_weights, pruning_flops_target, ref_flops): _, pruning_algo, _ = create_pruning_algo_with_config(config) assert pruning_algo.current_flops == ref_flops + assert pruning_algo.current_params_num == ref_params_num # pylint:disable=protected-access - assert pruning_algo._calculate_flops_pruned_model_by_masks() == ref_flops + assert pruning_algo._calculate_flops_and_weights_pruned_model_by_masks() == (ref_flops, ref_params_num) def test_clusters_for_multiple_forward():