Skip to content

Commit

Permalink
Correct statistics (filter pruning and sparsity) (#1098)
Browse files Browse the repository at this point in the history
Changes

Statistics of the pruned model are now collected from the model (using get_nncf_operations) but not from the graph (affected when loading from a checkpoint, the graph is not getting updated, but the model is).

When loading from a checkpoint the current_level of Scheduler is recalculated.

current_step and current_epoch attributes of BaseCompressionScheduler become private (just getters are available - not possible to update from outside). Tests are updated in accordance (no direct set of current_step and current_epoch).
Reason for changes

Printed statistics of the model and algorithm are not correct while loading from the checkpoint.
Tests

Updated.
  • Loading branch information
negvet committed Mar 15, 2022
1 parent 5cdcb3b commit d2c0473
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 80 deletions.
4 changes: 1 addition & 3 deletions nncf/common/pruning/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def __init__(self, controller, params: dict):
self.num_warmup_epochs = params.get('num_init_steps', 0)
self.num_pruning_epochs = params.get('pruning_steps', 100)
self.freeze_epoch = self.num_warmup_epochs + self.num_pruning_epochs
self._current_level = self.initial_level

def _calculate_pruning_level(self) -> float:
"""
Expand All @@ -81,7 +80,6 @@ def epoch_step(self, next_epoch: Optional[int] = None) -> None:
will update the state of the pruning method.
"""
super().epoch_step(next_epoch)
self._current_level = self._calculate_pruning_level()
self._controller.set_pruning_level(self.current_pruning_level)
if self.current_epoch >= self.freeze_epoch:
self._controller.freeze()
Expand All @@ -105,7 +103,7 @@ def current_pruning_level(self) -> float:
:return: Current sparsity level.
"""
if self.current_epoch >= self.num_warmup_epochs:
return self._current_level
return self._calculate_pruning_level()
return 0


Expand Down
10 changes: 3 additions & 7 deletions nncf/common/pruning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.collectors import NNCFCollectorTensorProcessor
from nncf.common.pruning.clusterization import Cluster
from nncf.common.pruning.clusterization import Clusterization
from nncf.common.pruning.structs import PrunedLayerInfoBase
Expand Down Expand Up @@ -386,8 +385,7 @@ def get_num_of_sparse_elements_by_node(node_name: str) -> int:


def calculate_in_out_channels_by_masks(pruning_groups: List[Cluster[PrunedLayerInfoBase]],
masks: Dict[str, NNCFTensor],
tensor_processor: Type[NNCFCollectorTensorProcessor],
num_of_sparse_elements_by_node: Dict[NNCFNodeName, int],
full_input_channels: Dict[str, int],
full_output_channels: Dict[str, int],
pruning_groups_next_nodes: Dict[int, List[str]]) -> Tuple[Dict[str, int],
Expand All @@ -397,16 +395,14 @@ def calculate_in_out_channels_by_masks(pruning_groups: List[Cluster[PrunedLayerI
and updating corresponding input channels number in `pruning_groups_next_nodes` nodes.
:param pruning_groups: A list of pruning groups.
:param masks: A dictionary of masks of each pruning node.
:param tensor_processor: NNCF Tensor processor to operate on NNCFTensors.
:param num_of_sparse_elements_by_node: A dictionary of num_of_sparse_elements of each pruning node.
:param full_input_channels: A dictionary of input channels number in original model.
:param full_output_channels: A dictionary of output channels number in original model.
:param pruning_groups_next_nodes: A dictionary of next nodes of each pruning group.
:return Dictionary of new input channels number {node_name: channels_num}
"""
def get_num_of_sparse_elements_by_node(node_name: str) -> int:
mask = masks[node_name]
return mask.shape[0] - int(tensor_processor.sum(mask))
return num_of_sparse_elements_by_node[node_name]

return _calculate_in_out_channels(pruning_groups,
get_num_of_sparse_elements_by_node,
Expand Down
38 changes: 28 additions & 10 deletions nncf/common/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,26 @@ def __init__(self):
update the compression method state taking into account the training step,
there is the same for current_epoch is -1.
"""
self.current_step = -1
self.current_epoch = -1
self._current_step = -1
self._current_epoch = -1

@property
def current_step(self) -> int:
"""
Return current step.
:return: Current step.
"""
return self._current_step

@property
def current_epoch(self) -> int:
"""
Return current epoch.
:return: Current epoch.
"""
return self._current_epoch

def step(self, next_step: Optional[int] = None) -> None:
"""
Expand All @@ -194,8 +212,8 @@ def step(self, next_step: Optional[int] = None) -> None:
will update the state of the compression method.
"""
if next_step is None:
next_step = self.current_step + 1
self.current_step = next_step
next_step = self._current_step + 1
self._current_step = next_step

def epoch_step(self, next_epoch: Optional[int] = None) -> None:
"""
Expand All @@ -206,8 +224,8 @@ def epoch_step(self, next_epoch: Optional[int] = None) -> None:
will update the state of the compression method.
"""
if next_epoch is None:
next_epoch = self.current_epoch + 1
self.current_epoch = next_epoch
next_epoch = self._current_epoch + 1
self._current_epoch = next_epoch

def load_state(self, state: Dict[str, Any]) -> None:
"""
Expand All @@ -216,8 +234,8 @@ def load_state(self, state: Dict[str, Any]) -> None:
:param state: Output of `get_state()` method.
"""
self.current_step = state['current_step']
self.current_epoch = state['current_epoch']
self._current_step = state['current_step']
self._current_epoch = state['current_epoch']

def get_state(self) -> Dict[str, Any]:
"""
Expand All @@ -226,8 +244,8 @@ def get_state(self) -> Dict[str, Any]:
:return: The compression scheduler state.
"""
return {
'current_step': self.current_step,
'current_epoch': self.current_epoch
'current_step': self._current_step,
'current_epoch': self._current_epoch
}


Expand Down
36 changes: 30 additions & 6 deletions nncf/common/sparsity/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def __init__(self, controller: SparsityController, params: dict):
self.target_level = params.get('sparsity_target', 0.5)
self.target_epoch = params.get('sparsity_target_epoch', 90)
self.freeze_epoch = params.get('sparsity_freeze_epoch', 100)
self._current_level = self.initial_level

def _calculate_sparsity_level(self) -> float:
"""
Expand All @@ -74,10 +73,9 @@ def _update_sparsity_level(self) -> None:
Calculates the current sparsity level and updates the internal
state of the `controller`.
"""
self._current_level = self._calculate_sparsity_level()
if self.current_epoch >= self.freeze_epoch:
self._controller.freeze()
self._controller.set_sparsity_level(self._current_level)
self._controller.set_sparsity_level(self._calculate_sparsity_level())

@property
def current_sparsity_level(self) -> float:
Expand All @@ -87,7 +85,9 @@ def current_sparsity_level(self) -> float:
:return: Current sparsity level.
"""
return self._current_level
if self._current_epoch == -1:
return self.initial_level
return self._calculate_sparsity_level()


@SPARSITY_SCHEDULERS.register('polynomial')
Expand Down Expand Up @@ -232,6 +232,17 @@ def __init__(self, controller: SparsityController, params: dict):
self.eps = params.get('eps', 0.03)
self.patience = params.get('patience', 1)
self.num_bad_epochs = 0
self._current_level = self.initial_level

@property
def current_sparsity_level(self) -> float:
"""
Returns sparsity level for the `current_epoch` or for step
in the `current_epoch`.
:return: Current sparsity level.
"""
return self._current_level

def epoch_step(self, next_epoch: Optional[int] = None) -> None:
super().epoch_step(next_epoch)
Expand All @@ -246,7 +257,9 @@ def _calculate_sparsity_level(self) -> float:
self.num_bad_epochs = 0
current_level = current_level + self.decay_step

return min(current_level, self.target_level)
self._current_level = min(current_level, self.target_level)

return self._current_level

def load_state(self, state: Dict[str, Any]) -> None:
super().load_state(state)
Expand Down Expand Up @@ -277,7 +290,18 @@ def __init__(self, controller: SparsityController, params: dict):
self.schedule = MultiStepSchedule(
sorted(params.get('multistep_steps', [90])), params.get('multistep_sparsity_levels', [0.1, 0.5]))
self.target_level = self.schedule.values[-1]
self._current_level = self.schedule.values[0]

@property
def current_sparsity_level(self) -> float:
"""
Returns sparsity level for the `current_epoch` or for step
in the `current_epoch`.
:return: Current sparsity level.
"""
if self._current_epoch == -1:
return self.schedule.values[0]
return self._calculate_sparsity_level()

def epoch_step(self, next_epoch: Optional[int] = None) -> None:
super().epoch_step(next_epoch)
Expand Down
50 changes: 31 additions & 19 deletions nncf/tensorflow/pruning/base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
limitations under the License.
"""

from typing import Dict
from typing import List
from typing import Tuple

Expand Down Expand Up @@ -40,7 +41,7 @@
from nncf.tensorflow.graph.transformations.commands import TFInsertionCommand
from nncf.tensorflow.graph.transformations.layout import TFTransformationLayout
from nncf.tensorflow.graph.utils import get_layer_identifier
from nncf.tensorflow.graph.utils import collect_wrapped_layers
from nncf.tensorflow.graph.utils import get_nncf_operations
from nncf.tensorflow.tensor import TFNNCFTensor
from nncf.tensorflow.pruning.tensor_processor import TFNNCFPruningTensorProcessor
from nncf.tensorflow.pruning.operations import TFElementwisePruningOp
Expand Down Expand Up @@ -280,6 +281,7 @@ def __init__(self,
self._pruned_layer_groups_info = pruned_layer_groups_info
self.prune_flops = False
self._check_pruning_level(params)
self._num_of_sparse_elements_by_node = None

def freeze(self):
raise NotImplementedError
Expand All @@ -301,29 +303,39 @@ def _check_pruning_level(self, params):
if pruning_flops_target:
self.prune_flops = True

def _calculate_num_of_sparse_elements_by_node(self) -> Dict[NNCFNodeName, int]:
"""Returns the number of sparse elements per node. Take into account names ('^') for the shared ops."""
if self._num_of_sparse_elements_by_node is None:
self._calculate_pruned_layers_summary()

retval = {}
for group in self._pruned_layer_groups_info.get_all_clusters():
for node in group.elements:
retval[node.node_name] = self._num_of_sparse_elements_by_node[node.layer_name]
return retval

def _calculate_pruned_layers_summary(self) -> List[PrunedLayerSummary]:
pruning_levels = []
mask_names = []
weights_shapes = []
mask_shapes = []
wrapped_layers = collect_wrapped_layers(self._model)
for wrapped_layer in wrapped_layers:
for weight_attr, ops in wrapped_layer.weights_attr_ops.items():
for op_name in ops:
if op_name in self._op_names:
mask = wrapped_layer.ops_weights[op_name]['mask']
mask_names.append(mask.name)
weights_shapes.append(list(mask.shape))
reduce_axes = list(range(len(mask.shape)))
filter_axis = get_filter_axis(wrapped_layer, weight_attr)
if filter_axis == -1:
filter_axis = reduce_axes[filter_axis]
reduce_axes.remove(filter_axis)
filter_mask = tf.reduce_max(tf.cast(mask, tf.int32), axis=reduce_axes, keepdims=True)
mask_shapes.append(list(filter_mask.shape))
filters_number = get_filters_num(wrapped_layer)
pruned_filters_number = filters_number - tf.reduce_sum(filter_mask)
pruning_levels.append(pruned_filters_number / filters_number)
self._num_of_sparse_elements_by_node = {}
for wrapped_layer, weight_attr, op_name in get_nncf_operations(self._model, self._op_names):
mask = wrapped_layer.ops_weights[op_name.name]['mask']
mask_names.append(mask.name)
weights_shapes.append(list(mask.shape))
reduce_axes = list(range(len(mask.shape)))
filter_axis = get_filter_axis(wrapped_layer, weight_attr)
if filter_axis == -1:
filter_axis = reduce_axes[filter_axis]
reduce_axes.remove(filter_axis)
filter_mask = tf.reduce_max(tf.cast(mask, tf.int32), axis=reduce_axes, keepdims=True)
mask_shapes.append(list(filter_mask.shape))
filters_number = get_filters_num(wrapped_layer)
pruned_filters_number = filters_number - tf.reduce_sum(filter_mask)
pruning_levels.append(pruned_filters_number / filters_number)
pruned_filter_number = filters_number - tf.reduce_sum(filter_mask)
self._num_of_sparse_elements_by_node[wrapped_layer.name] = pruned_filter_number.numpy()

pruning_levels = tf.keras.backend.batch_get_value(pruning_levels)
mask_pruning = list(zip(mask_names, weights_shapes, mask_shapes, pruning_levels))
Expand Down
11 changes: 1 addition & 10 deletions nncf/tensorflow/pruning/filter_pruning/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
from nncf.tensorflow.layers.data_layout import get_input_channel_axis
from nncf.tensorflow.layers.wrapper import NNCFWrapper
from nncf.tensorflow.loss import TFZeroCompressionLoss
from nncf.tensorflow.tensor_statistics.collectors import TFNNCFCollectorTensorProcessor
from nncf.tensorflow.pruning.base_algorithm import BasePruningAlgoBuilder
from nncf.tensorflow.pruning.base_algorithm import BasePruningAlgoController
from nncf.tensorflow.pruning.base_algorithm import PrunedLayerInfo
Expand Down Expand Up @@ -526,18 +525,10 @@ def _calculate_filters_importance_in_group(self, group: Cluster[PrunedLayerInfo]

return cumulative_filters_importance

def _collect_pruning_masks(self) -> Dict[str, TFNNCFTensor]:
retval = {}
for group in self._pruned_layer_groups_info.get_all_clusters():
for node in group.elements:
retval[node.node_name] = self._original_graph.get_node_by_name(node.node_name).data['output_mask']
return retval

def _update_benchmark_statistics(self):
tmp_in_channels, tmp_out_channels = calculate_in_out_channels_by_masks(
pruning_groups=self._pruned_layer_groups_info.get_all_clusters(),
masks=self._collect_pruning_masks(),
tensor_processor=TFNNCFCollectorTensorProcessor,
num_of_sparse_elements_by_node=self._calculate_num_of_sparse_elements_by_node(),
full_input_channels=self._layers_in_channels,
full_output_channels=self._layers_out_channels,
pruning_groups_next_nodes=self._next_nodes)
Expand Down
17 changes: 7 additions & 10 deletions nncf/torch/pruning/filter_pruning/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,8 @@
from nncf.common.utils.debug import is_debug
from nncf.common.utils.logger import logger as nncf_logger
from nncf.config.extractors import extract_bn_adaptation_init_params
from nncf.torch.tensor import PTNNCFTensor
from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor
from nncf.torch.graph.operator_metatypes import PTConv1dMetatype
from nncf.torch.graph.operator_metatypes import PTConv2dMetatype
from nncf.torch.graph.operator_metatypes import PTConv3dMetatype
Expand Down Expand Up @@ -682,18 +680,17 @@ def disable_scheduler(self):
self._scheduler = StubCompressionScheduler()
self._scheduler.current_pruning_level = 0.0

def _collect_pruning_masks(self) -> Dict[str, PTNNCFTensor]:
retval = {}
for group in self.pruned_module_groups_info.get_all_clusters():
for node in group.elements:
retval[node.node_name] = PTNNCFTensor(node.operand.binary_filter_pruning_mask)
return retval
def _calculate_num_of_sparse_elements_by_node(self) -> Dict[str, int]:
num_of_sparse_elements_by_node = {}
for minfo in self.pruned_module_groups_info.get_all_nodes():
mask = self.get_mask(minfo)
num_of_sparse_elements_by_node[minfo.node_name] = mask.view(-1).size(0) - mask.nonzero().size(0)
return num_of_sparse_elements_by_node

def _update_benchmark_statistics(self):
tmp_in_channels, tmp_out_channels = calculate_in_out_channels_by_masks(
pruning_groups=self.pruned_module_groups_info.get_all_clusters(),
masks=self._collect_pruning_masks(),
tensor_processor=PTNNCFCollectorTensorProcessor,
num_of_sparse_elements_by_node=self._calculate_num_of_sparse_elements_by_node(),
full_input_channels=self._modules_in_channels,
full_output_channels=self._modules_out_channels,
pruning_groups_next_nodes=self.next_nodes)
Expand Down
4 changes: 1 addition & 3 deletions tests/tensorflow/pruning/test_flops_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from nncf.common.pruning.utils import calculate_in_out_channels_by_masks
from nncf.common.pruning.utils import count_flops_and_weights
from nncf.tensorflow.tensor_statistics.collectors import TFNNCFCollectorTensorProcessor
from nncf.tensorflow.graph.metatypes.common import GENERAL_CONV_LAYER_METATYPES
from nncf.tensorflow.graph.metatypes.common import LINEAR_LAYER_METATYPES
from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test
Expand Down Expand Up @@ -56,8 +55,7 @@ def test_flops_calulation_for_spec_layers(model_fn, all_weights, pruning_flops_t
# pylint:disable=protected-access
tmp_in_channels, tmp_out_channels = calculate_in_out_channels_by_masks(
pruning_groups=compression_ctrl._pruned_layer_groups_info.get_all_clusters(),
masks=compression_ctrl._collect_pruning_masks(),
tensor_processor=TFNNCFCollectorTensorProcessor,
num_of_sparse_elements_by_node=compression_ctrl._calculate_num_of_sparse_elements_by_node(),
full_input_channels=compression_ctrl._layers_in_channels,
full_output_channels=compression_ctrl._layers_out_channels,
pruning_groups_next_nodes=compression_ctrl._next_nodes)
Expand Down
3 changes: 1 addition & 2 deletions tests/tensorflow/sparsity/magnitude/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def test_compression_controller_state():
_, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)

# Test get state
compression_ctrl.scheduler.current_step = 100
compression_ctrl.scheduler.current_epoch = 5
compression_ctrl.scheduler.load_state({'current_step': 100, 'current_epoch': 5})
state_content = compression_ctrl.get_state()[algo_name]
assert state_content[CtrlStateNames.SCHEDULER] == {'current_step': 100, 'current_epoch': 5}

Expand Down
Loading

0 comments on commit d2c0473

Please sign in to comment.