Skip to content

Commit

Permalink
Filter pruning by flops (openvinotoolkit#657)
Browse files Browse the repository at this point in the history
* take into account N+ dimensional input to convolutions

* take into account N+ dimensional input to convolutions

* pruning by flops

* correct work with shared nodes. Initialize input shape in retinanet

* apply comments

* correct work with sequential models

* pylint

* add docstrings, apply comments

* fix merge
  • Loading branch information
evgeniya-egupova committed Apr 16, 2021
1 parent e2e3cda commit d26d961
Show file tree
Hide file tree
Showing 14 changed files with 556 additions and 182 deletions.
16 changes: 11 additions & 5 deletions beta/nncf/tensorflow/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.module_attributes import ConvolutionModuleAttributes
from beta.nncf.tensorflow.layers.common import GENERAL_CONV_LAYERS
from beta.nncf.tensorflow.layers.wrapper import NNCFWrapper
from beta.nncf.tensorflow.graph.utils import get_expanded_node_name
from beta.nncf.tensorflow.graph.utils import is_functional_model
from beta.nncf.tensorflow.graph.utils import is_sequential_model
from beta.nncf.tensorflow.layers.common import GENERAL_CONV_LAYERS
from beta.nncf.tensorflow.layers.data_layout import get_input_channel_axis
from beta.nncf.tensorflow.layers.wrapper import NNCFWrapper
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.module_attributes import ConvolutionModuleAttributes


def convert_keras_model_to_nxmodel(model):
Expand Down Expand Up @@ -176,11 +177,13 @@ def _get_nncf_graph_from_functional(model: tf.keras.Model) -> NNCFGraph:
raw_nodes = _prepare_raw_nodes(model)
return _get_nncf_graph_from_raw_nodes(model_config, raw_nodes)


def _prepare_shape(shape):
if not isinstance(shape, list):
return [shape]
return shape


def _prepare_raw_nodes(model: tf.keras.Model) -> Dict:
model_config = model.get_config()
raw_nodes = Dict()
Expand Down Expand Up @@ -299,16 +302,19 @@ def _get_nncf_graph_from_sequential(model: tf.keras.Model) -> NNCFGraph:


def _get_module_attributes(layer: tf.keras.layers.Layer, attrs: dict) -> ConvolutionModuleAttributes:
channel_axis = -1 if attrs['data_format'] == 'channels_last' else 1
channel_axis = get_input_channel_axis(layer)
if isinstance(layer, NNCFWrapper):
strides = layer.layer.strides[0]
groups = layer.layer.groups
kernel_size = layer.layer.kernel_size
else:
strides = layer.strides[0]
groups = layer.groups
kernel_size = layer.kernel_size

return ConvolutionModuleAttributes(layer.trainable,
layer.get_input_shape_at(0)[channel_axis],
layer.get_output_shape_at(0)[channel_axis],
kernel_size,
strides,
groups)
4 changes: 4 additions & 0 deletions beta/nncf/tensorflow/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def get_original_name_and_instance_index(node_name):
return original_name, instance_index


def get_original_name(node_name):
return get_original_name_and_instance_index(node_name)[0]


def get_layer_to_graph_nodes_map(model, node_names):
layer_to_nodes_map = {layer.name: {'type': layer.__class__.__name__,
'nodes': []}
Expand Down
4 changes: 4 additions & 0 deletions beta/nncf/tensorflow/layers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@
'Conv3DTranspose'
]

LINEAR_LAYERS = [
'Dense'
]

KERAS_LAYERS_AGNOSTIC_TO_DATA_PRECISION_WITH_ONE_INPUT = [
'Cropping1D',
'Cropping2D',
Expand Down
17 changes: 10 additions & 7 deletions beta/nncf/tensorflow/pruning/base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@


class PrunedLayerInfo:
def __init__(self, layer_name: str):
def __init__(self, layer_name: str, node_id: int):
self.layer_name = layer_name
self.nncf_node_id = node_id
self.key = self.layer_name


class BasePruningAlgoBuilder(TFCompressionAlgorithmBuilder):
Expand Down Expand Up @@ -85,6 +87,7 @@ def __init__(self, config):
self._prune_downsample_convs)

self._pruned_layer_groups_info = None
self._graph = None
self._op_names = []

def apply_to(self, model: tf.keras.Model) -> tf.keras.Model:
Expand All @@ -105,8 +108,8 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
:return: The instance of the `TransformationLayout` class containing
a list of pruning mask insertions.
"""
graph = convert_keras_model_to_nncf_graph(model)
groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(graph)
self._graph = convert_keras_model_to_nncf_graph(model)
groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(self._graph)

transformations = TFTransformationLayout()
shared_layers = set()
Expand Down Expand Up @@ -136,13 +139,13 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
transformations.register(
self._get_insertion_command_binary_mask(layer_name, attr_name)
)
group_minfos.append(PrunedLayerInfo(layer_name))
group_minfos.append(PrunedLayerInfo(layer_name, node.node_id))

cluster = NodesCluster(i, group_minfos, [n.node_id for n in group.nodes])
self._pruned_layer_groups_info.add_cluster(cluster)

# Propagating masks across the graph to detect spec_nodes that will be pruned
mask_propagator = MaskPropagationAlgorithm(graph, TF_PRUNING_OPERATOR_METATYPES)
mask_propagator = MaskPropagationAlgorithm(self._graph, TF_PRUNING_OPERATOR_METATYPES)
mask_propagator.mask_propagation()

# Add masks for all spec modules, because prunable batchnorm layers can be determines
Expand All @@ -151,7 +154,7 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
if not self._prune_batch_norms:
types_spec_layers.remove('BatchNormalization')

spec_nodes = graph.get_nodes_by_types(types_spec_layers)
spec_nodes = self._graph.get_nodes_by_types(types_spec_layers)
for spec_node in spec_nodes:
layer_name = get_layer_identifier(spec_node)
if spec_node.data['output_mask'] is None:
Expand Down Expand Up @@ -269,7 +272,7 @@ def _check_pruning_rate(self, params):
if pruning_target and pruning_flops_target:
raise ValueError('Only one parameter from \'pruning_target\' and \'pruning_flops_target\' can be set.')
if pruning_flops_target:
raise Exception('Pruning by flops is not supported in NNCF TensorFlow yet.')
self.prune_flops = True

def statistics(self, quickly_collected_only=False) -> Dict[str, object]:
raw_pruning_statistics = self.raw_statistics()
Expand Down
Loading

0 comments on commit d26d961

Please sign in to comment.