Add maximum pruning rate statistics (openvinotoolkit#931)
Add maximum pruning rate statistics / count of prunable layers that were actually pruned
daniil-lyakhov committed Sep 30, 2021
1 parent e83fa5f commit 4bf9920
Showing 3 changed files with 85 additions and 4 deletions.
62 changes: 58 additions & 4 deletions nncf/common/pruning/statistics.py
@@ -98,10 +98,10 @@ def __init__(self,
Initializes statistics of the filter pruning algorithm.
:param model_statistics: Statistics of the pruned model.
:param full_flops: Full FLOPS.
:param current_flops: Current FLOPS.
:param full_params_num: Full number of weights.
:param current_params_num: Current number of weights.
:param full_flops: The total number of FLOPs in the model.
:param current_flops: Current number of FLOPs in the model.
:param full_params_num: The total number of weights in the model.
:param current_params_num: Current number of weights in the model.
:param target_pruning_level: A target level of the pruning
for the algorithm for the current epoch.
"""
@@ -133,3 +133,57 @@ def to_str(self) -> str:
f'Statistics of the filter pruning algorithm:\n{algorithm_string}'
)
return pretty_string


class PrunedModelTheoreticalBorderline(Statistics):
"""
Contains theoretical borderline statistics of the filter pruning algorithm.
"""

def __init__(self,
num_pruned_layers: int,
num_prunable_layers: int,
max_prunable_flops: float,
max_prunable_params: float,
total_flops: int,
total_params: int):
"""
Initializes statistics of the filter pruning theoretical borderline.
:param num_pruned_layers: Number of layers that were actually
pruned.
:param num_prunable_layers: Number of layers of a prunable
type.
:param max_prunable_flops: Number of FLOPs in the model
pruned at a pruning rate of 1.
:param max_prunable_params: Number of weights in the model
pruned at a pruning rate of 1.
:param total_flops: The total number of FLOPs in the model.
:param total_params: The total number of weights in the model.
"""
self._giga = 1e9
self._mega = 1e6
self.pruned_layers_num = num_pruned_layers
self.prunable_layers_num = num_prunable_layers
self.minimum_possible_flops = max_prunable_flops
self.minimum_possible_params = max_prunable_params
self.total_flops = total_flops
self.total_params = total_params

def to_str(self) -> str:
algorithm_string = create_table(
header=['Statistic\'s name', 'Value'],
rows=[
['Pruned layers count / prunable layers count', f'{self.pruned_layers_num} /'
f' {self.prunable_layers_num}'],
['GFLOPS minimum possible after pruning / total', f'{self.minimum_possible_flops / self._giga:.3f} /'
f' {self.total_flops / self._giga:.3f}'],
['MParams minimum possible after pruning / total', f'{self.minimum_possible_params / self._mega:.3f} /'
f' {self.total_params / self._mega:.3f}'],
]
)

pretty_string = (
f'Theoretical borderline of the filter pruning algorithm\nfor current model:\n{algorithm_string}'
)
return pretty_string
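
For reference, the new statistics object can be rendered on its own. Below is a minimal sketch with purely illustrative numbers (a hypothetical model in which 20 of 25 prunable layers were pruned), assuming an NNCF build that already contains this class:

from nncf.common.pruning.statistics import PrunedModelTheoreticalBorderline

# Illustrative values only: 20 of 25 prunable layers pruned; at a pruning rate
# of 1.0 the model would keep 1.2 GFLOPs / 3.5 MParams out of 8.4 GFLOPs /
# 25.6 MParams in total.
stats = PrunedModelTheoreticalBorderline(
    num_pruned_layers=20,
    num_prunable_layers=25,
    max_prunable_flops=1.2e9,
    max_prunable_params=3.5e6,
    total_flops=int(8.4e9),
    total_params=int(25.6e6))
print(stats.to_str())
# Prints the table built by create_table(): pruned/prunable layer counts and the
# minimum possible GFLOPs and MParams after pruning versus the totals.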
13 changes: 13 additions & 0 deletions nncf/tensorflow/pruning/filter_pruning/algorithm.py
@@ -34,6 +34,8 @@
from nncf.common.pruning.utils import get_conv_in_out_channels
from nncf.common.pruning.utils import get_rounded_pruned_element_number
from nncf.common.statistics import NNCFStatistics
from nncf.common.pruning.statistics import PrunedModelTheoreticalBorderline
from nncf.common.utils.debug import is_debug
from nncf.common.utils.logger import logger as nncf_logger
from nncf.common.schedulers import StubCompressionScheduler
from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
@@ -126,6 +128,10 @@ def __init__(self,
self.current_flops = self.full_flops
self.full_params_num = sum(self._nodes_params_num.values())
self.current_params_num = self.full_params_num
self._pruned_layers_num = len(self._pruned_layer_groups_info.get_all_nodes())
self._prunable_layers_num = len(self._original_graph.get_nodes_by_types(self._prunable_types))
self._max_prunable_flops, self._max_prunable_params = \
self._calculate_flops_and_weights_in_uniformly_pruned_model(1.)

self._weights_normalizer = tensor_l2_normalizer # for all weights in common case
self._filter_importance = FILTER_IMPORTANCE_FUNCTIONS.get(params.get('filter_importance', 'L2'))
@@ -162,6 +168,13 @@ def disable_scheduler(self):
self._scheduler.current_pruning_level = 0.0

def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
if not quickly_collected_only and is_debug():
stats = PrunedModelTheoreticalBorderline(
self._pruned_layers_num, self._prunable_layers_num, self._max_prunable_flops,
self._max_prunable_params, self.full_flops, self.full_params_num)

nncf_logger.debug(stats.to_str())

model_statistics = self._calculate_pruned_model_stats()
self._update_benchmark_statistics()
target_pruning_level = self.scheduler.current_pruning_level
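
The cached _max_prunable_flops / _max_prunable_params values come from evaluating the model at a uniform pruning rate of 1.0, i.e. the tightest borderline the algorithm could theoretically reach. The snippet below is a simplified, self-contained sketch of that idea for a plain chain of convolutions; it is not NNCF's _calculate_flops_and_weights_in_uniformly_pruned_model, which additionally handles pruning groups, shape propagation and channel rounding, and the conv_flops / flops_at_uniform_rate helpers and the toy layer list are hypothetical:

def conv_flops(c_in, c_out, k, h_out, w_out):
    # Rough FLOP count for a k x k convolution producing an h_out x w_out map.
    return 2 * k * k * c_in * c_out * h_out * w_out

def flops_at_uniform_rate(layers, rate):
    # layers: (c_in, c_out, kernel, h_out, w_out, prunable) tuples describing a
    # simple chain of convolutions; pruning a layer's output filters also
    # shrinks the next layer's input channels.
    total = 0
    prev_out = None
    for c_in, c_out, k, h_out, w_out, prunable in layers:
        in_ch = prev_out if prev_out is not None else c_in
        # Keep at least one filter per layer in this toy version; the real
        # computation uses channel rounding and pruning-group constraints.
        out_ch = max(int(round(c_out * (1.0 - rate))), 1) if prunable else c_out
        total += conv_flops(in_ch, out_ch, k, h_out, w_out)
        prev_out = out_ch
    return total

toy_model = [
    (3, 64, 3, 112, 112, True),   # prunable conv
    (64, 128, 3, 56, 56, True),   # prunable conv
    (128, 10, 1, 1, 1, False),    # head, not prunable
]
print(flops_at_uniform_rate(toy_model, 0.0))  # full model FLOPs
print(flops_at_uniform_rate(toy_model, 1.0))  # theoretical minimum at rate 1.0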
14 changes: 14 additions & 0 deletions nncf/torch/pruning/filter_pruning/algo.py
@@ -31,6 +31,7 @@
from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm
from nncf.common.pruning.schedulers import PRUNING_SCHEDULERS
from nncf.common.pruning.statistics import FilterPruningStatistics
from nncf.common.pruning.statistics import PrunedModelTheoreticalBorderline
from nncf.common.pruning.statistics import PrunedLayerSummary
from nncf.common.pruning.statistics import PrunedModelStatistics
from nncf.common.pruning.utils import calculate_in_out_channels_in_uniformly_pruned_model
@@ -41,6 +42,7 @@
from nncf.common.pruning.utils import get_rounded_pruned_element_number
from nncf.common.schedulers import StubCompressionScheduler
from nncf.common.statistics import NNCFStatistics
from nncf.common.utils.debug import is_debug
from nncf.common.utils.logger import logger as nncf_logger
from nncf.config.extractors import extract_bn_adaptation_init_params
from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS
@@ -118,6 +120,7 @@ def __init__(self, target_model: NNCFNetwork,
prunable_types: List[str],
pruned_module_groups: Clusterization[PrunedModuleInfo],
config: NNCFConfig):
#pylint:disable=too-many-statements
super().__init__(target_model, prunable_types, pruned_module_groups, config)
params = self.pruning_config.get('params', {})
self.frozen = False
@@ -137,6 +140,10 @@ def __init__(self, target_model: NNCFNetwork,
self.current_flops = self.full_flops
self.full_params_num = sum(self.nodes_params_num.values())
self.current_params_num = self.full_params_num
self._pruned_layers_num = len(self.pruned_module_groups_info.get_all_nodes())
self._prunable_layers_num = len(self._model.get_graph().get_nodes_by_types(self._prunable_types))
self._max_prunable_flops, self._max_prunable_params =\
self._calculate_flops_and_weights_in_uniformly_pruned_model(1.)

self.weights_normalizer = tensor_l2_normalizer # for all weights in common case
self.filter_importance = FILTER_IMPORTANCE_FUNCTIONS.get(params.get('filter_importance', 'L2'))
@@ -206,6 +213,13 @@ def set_mask(minfo: PrunedModuleInfo, mask: torch.Tensor) -> None:
minfo.operand.binary_filter_pruning_mask = mask

def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
if not quickly_collected_only and is_debug():
stats = PrunedModelTheoreticalBorderline(
self._pruned_layers_num, self._prunable_layers_num, self._max_prunable_flops,
self._max_prunable_params, self.full_flops, self.full_params_num)

nncf_logger.debug(stats.to_str())

pruned_layers_summary = {}
for minfo in self.pruned_module_groups_info.get_all_nodes():
layer_name = str(minfo.module_scope)
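
In both backends the new report is gated behind is_debug() and emitted through nncf_logger.debug, so it does not appear in the regular statistics table. The usage sketch below follows the usual NNCF entry points (create_compressed_model, compression_ctrl.statistics()); the config values are illustrative, the exact keys may differ between NNCF versions, and the theoretical-borderline table itself only shows up in the log when NNCF debug mode is active:

import torchvision
from nncf import NNCFConfig
from nncf.torch import create_compressed_model

# Minimal filter-pruning config; the numbers here are illustrative only.
nncf_config = NNCFConfig.from_dict({
    "input_info": {"sample_size": [1, 3, 224, 224]},
    "compression": {
        "algorithm": "filter_pruning",
        "params": {"schedule": "exponential", "pruning_target": 0.5},
    },
})

model = torchvision.models.resnet18()
compression_ctrl, compressed_model = create_compressed_model(model, nncf_config)

# Regular statistics: pruned-model summary, FLOPs/params, target pruning level.
print(compression_ctrl.statistics().to_str())
# The PrunedModelTheoreticalBorderline table goes to the NNCF debug log instead,
# and only when debug mode is enabled.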
