Skip to content

Commit

Permalink
Common collector of statistics for the sparse model (openvinotoolkit#756
Browse files Browse the repository at this point in the history
)

* [PT] Collector of statistics for the sparse model was added

* [TF] Collector of statistics for the sparse model was added

* Minor updates
  • Loading branch information
andrey-churkin committed Jun 7, 2021
1 parent 139c390 commit 222fd33
Show file tree
Hide file tree
Showing 14 changed files with 396 additions and 189 deletions.
22 changes: 22 additions & 0 deletions beta/nncf/tensorflow/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
limitations under the License.
"""

from typing import List, Tuple
import sys
import inspect

Expand All @@ -19,6 +20,7 @@
from beta.nncf.tensorflow.graph.metatypes.keras_layers import TFNNCFWrapperLayerMetatype
from beta.nncf.tensorflow.graph.metatypes.matcher import get_keras_layer_metatype
from beta.nncf.tensorflow.layers.wrapper import NNCFWrapper
from beta.nncf.tensorflow.layers.operation import NNCFOperation
from nncf.common.graph import NNCFNode

SHARED_OPERATION_MARK = '^'
Expand Down Expand Up @@ -137,8 +139,28 @@ def get_layer_identifier(node: NNCFNode):
layer_name, _ = get_original_name_and_instance_index(node.node_name)
return layer_name


def unwrap_layer(layer):
layer_metatype = get_keras_layer_metatype(layer, determine_subtype=False)
if layer_metatype == TFNNCFWrapperLayerMetatype:
return layer.layer
return layer


def get_nncf_operations(model: tf.keras.Model, operation_names: List[str]) -> Tuple[NNCFWrapper, str, NNCFOperation]:
"""
Yields the operations from the model which names in `operation_names`.
:param model: Wrapped model.
:param operation_names: List of operation names.
:return: A tuple (wrapped_layer, weight_attr, op) where
- wrapped_layer: A wrapped layer, which contains operation weights.
- weight_attr: A name of the attribute of the wrapped layer to which
the operation is applied.
- op: NNCF operation.
"""
for wrapped_layer in collect_wrapped_layers(model):
for weight_attr, ops in wrapped_layer.weights_attr_ops.items():
for op in ops.values():
if op.name in operation_names:
yield wrapped_layer, weight_attr, op
88 changes: 88 additions & 0 deletions beta/nncf/tensorflow/sparsity/collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
Copyright (c) 2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import List

import tensorflow as tf

from nncf.common.sparsity.collector import WeightDescription
from nncf.common.sparsity.collector import BaseSparseModelStatisticsCollector
from beta.nncf.tensorflow.sparsity.magnitude.functions import apply_mask
from beta.nncf.tensorflow.sparsity.magnitude.operation import BinaryMaskWithWeightsBackup
from beta.nncf.tensorflow.graph.utils import get_nncf_operations


def _get_standardized_weight_shape(shape):
return [0 if x is None else x for x in shape]


class TFSparseModelStatisticsCollector(BaseSparseModelStatisticsCollector):
"""
Collects statistics for the sparse tf.keras.Model.
"""

def __init__(self, model: tf.keras.Model, operation_names: List[str]):
"""
Initializes statistics collector of the sparse tf.keras.Model.
:param model: Sparse model.
:param operation_names: Names of operations.
"""
self._model = model
self._operation_names = operation_names
self._excluded_names = []
self._sw_name_to_num_nonzero_map = {}

def _collect_weights_descriptions(self) -> List[WeightDescription]:
weights_descriptions = []
excluded_names = []

# Collect description for sparse weights i.e. weights for which
# sparsity algorithm was applied.
for wrapped_layer, weight_attr, op in get_nncf_operations(self._model, self._operation_names):
weight = wrapped_layer.layer_weights[weight_attr]
operation_weights = wrapped_layer.get_operation_weights(op.name)
binary_mask = op.get_binary_mask(operation_weights)
sparse_weight = apply_mask(weight, binary_mask)

weights_descriptions.append(
WeightDescription(
weight.name,
_get_standardized_weight_shape(weight.shape.as_list()),
tf.math.count_nonzero(sparse_weight).numpy().item(),
is_sparse=True
)
)

# Exclude this name because it has been processed.
excluded_names.append(weight.name)

# Exclude these names because they were added to the model
# by the sparsity algorithm.
excluded_names.extend([w.name for w in operation_weights.values()])
if isinstance(op, BinaryMaskWithWeightsBackup):
excluded_names.append(op.bkup_var.name)

# Collect descriptions for rest weights.
unique_weights = {id(w): w for w in self._model.weights if w.name not in excluded_names}.values()
for weight in unique_weights:
weights_descriptions.append(
WeightDescription(
weight.name,
_get_standardized_weight_shape(weight.shape.as_list()),
tf.math.count_nonzero(weight).numpy().item(),
is_sparse=False
)
)

return weights_descriptions
62 changes: 10 additions & 52 deletions beta/nncf/tensorflow/sparsity/magnitude/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
"""

import tensorflow as tf
from tensorflow.python.keras.utils.layer_utils import count_params

from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS
from nncf.common.sparsity.statistics import SparsifiedLayerSummary
from nncf.common.sparsity.statistics import SparsifiedModelStatistics
from nncf.common.sparsity.statistics import LayerThreshold
from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics
from nncf.common.statistics import NNCFStatistics
Expand All @@ -42,6 +39,7 @@
from beta.nncf.tensorflow.sparsity.magnitude.functions import WEIGHT_IMPORTANCE_FUNCTIONS
from beta.nncf.tensorflow.sparsity.magnitude.operation import BinaryMask
from beta.nncf.tensorflow.sparsity.magnitude.operation import BinaryMaskWithWeightsBackup
from beta.nncf.tensorflow.sparsity.collector import TFSparseModelStatisticsCollector
from beta.nncf.tensorflow.utils.node import is_ignored


Expand Down Expand Up @@ -189,55 +187,15 @@ def _collect_all_weights(self):
return all_weights

def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
sparsity_levels = []
mask_names = []
weights_shapes = []
weights_numbers = []
total_weights_number = tf.constant(0)
total_sparsified_weights_number = tf.constant(0)
total_bkup_weights_number = tf.constant(0)
wrapped_layers = collect_wrapped_layers(self._model)
for wrapped_layer in wrapped_layers:
for ops in wrapped_layer.weights_attr_ops.values():
for op_name, op in ops.items():
if op_name in self._op_names:
if isinstance(op, BinaryMaskWithWeightsBackup):
total_bkup_weights_number += tf.size(op.bkup_var)
if isinstance(op, BinaryMask):
mask = wrapped_layer.ops_weights[op_name]['mask']
mask_names.append(mask.name)
weights_shapes.append(list(mask.shape))
weights_number = tf.size(mask)
weights_numbers.append(weights_number)
sparsified_weights_number = weights_number - tf.reduce_sum(tf.cast(mask, tf.int32))
sparsity_levels.append(sparsified_weights_number / weights_number)
total_weights_number += weights_number
total_sparsified_weights_number += sparsified_weights_number

sparsity_rate_for_sparsified_modules = (total_sparsified_weights_number / total_weights_number).numpy()
model_weights_number = count_params(self._model.weights) - total_weights_number - total_bkup_weights_number
sparsity_rate_for_model = (total_sparsified_weights_number / model_weights_number).numpy()

sparsity_levels = tf.keras.backend.batch_get_value(sparsity_levels)
weights_percentages = [weights_number / total_weights_number * 100
for weights_number in weights_numbers]
weights_percentages = tf.keras.backend.batch_get_value(weights_percentages)
mask_sparsity = list(zip(mask_names, weights_shapes, sparsity_levels, weights_percentages))

sparsified_layers_summary = []
threshold_statistics = []
for mask_name, weights_shape, sparsity_level, weights_percentage in mask_sparsity:
sparsified_layers_summary.append(
SparsifiedLayerSummary(mask_name, weights_shape, sparsity_level, weights_percentage)
)

threshold_statistics.append(LayerThreshold(mask_name, self._threshold))

model_statistics = SparsifiedModelStatistics(sparsity_rate_for_model,
sparsity_rate_for_sparsified_modules,
sparsified_layers_summary)

stats = MagnitudeSparsityStatistics(model_statistics, threshold_statistics)
collector = TFSparseModelStatisticsCollector(self.model, self._op_names)
model_stats = collector.collect()

threshold_stats = []
threshold = self._select_threshold(model_stats.sparsity_level)
for s in model_stats.sparsified_layers_summary:
threshold_stats.append(LayerThreshold(s.name, threshold))

stats = MagnitudeSparsityStatistics(model_stats, threshold_stats)

nncf_stats = NNCFStatistics()
nncf_stats.register('magnitude_sparsity', stats)
Expand Down
10 changes: 10 additions & 0 deletions beta/nncf/tensorflow/sparsity/magnitude/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def build(self, input_shape, input_type, name, layer):
def call(self, inputs, weights, _):
return apply_mask(inputs, weights['mask'])

@staticmethod
def get_binary_mask(op_weights):
"""
Returns binary mask from weights of the operation.
:param op_weights: Weights of the operaton.
:return: Binary mask.
"""
return op_weights['mask']


@NNCF_CUSTOM_OBJECTS.register()
class BinaryMaskWithWeightsBackup(BinaryMask):
Expand Down
65 changes: 15 additions & 50 deletions beta/nncf/tensorflow/sparsity/rb/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,28 @@

from typing import List

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.utils.layer_utils import count_params

from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS
from nncf.common.sparsity.schedulers import SparsityScheduler
from nncf.common.sparsity.statistics import SparsifiedLayerSummary
from nncf.common.sparsity.statistics import SparsifiedModelStatistics
from nncf.common.sparsity.statistics import RBSparsityStatistics
from nncf.common.statistics import NNCFStatistics
from beta.nncf.tensorflow.algorithm_selector import TF_COMPRESSION_ALGORITHMS
from beta.nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder
from beta.nncf.tensorflow.graph.transformations.commands import TFInsertionCommand
from beta.nncf.tensorflow.graph.transformations.commands import TFLayerWeight
from beta.nncf.tensorflow.graph.utils import collect_wrapped_layers
from beta.nncf.tensorflow.graph.utils import get_original_name_and_instance_index
from beta.nncf.tensorflow.graph.utils import get_nncf_operations
from beta.nncf.tensorflow.graph.converter import convert_keras_model_to_nncf_graph
from beta.nncf.tensorflow.sparsity.base_algorithm import BaseSparsityController
from beta.nncf.tensorflow.sparsity.base_algorithm import SPARSITY_LAYERS
from beta.nncf.tensorflow.sparsity.rb.loss import SparseLoss
from beta.nncf.tensorflow.sparsity.rb.operation import RBSparsifyingWeight
from beta.nncf.tensorflow.sparsity.rb.functions import binary_mask
from beta.nncf.tensorflow.sparsity.utils import apply_fn_to_op_weights
from beta.nncf.tensorflow.sparsity.collector import TFSparseModelStatisticsCollector
from beta.nncf.tensorflow.utils.node import is_ignored


Expand Down Expand Up @@ -124,56 +122,23 @@ def freeze(self):
self._loss.disable()

def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
sparsity_levels = []
mask_names = []
weights_shapes = []
weights_numbers = []
sparse_prob_sum = tf.constant(0.)
total_weights_number = tf.constant(0)
total_sparsified_weights_number = tf.constant(0)
wrapped_layers = collect_wrapped_layers(self._model)
for wrapped_layer in wrapped_layers:
for ops in wrapped_layer.weights_attr_ops.values():
for op_name in ops:
if op_name in self._op_names:
mask = wrapped_layer.ops_weights[op_name]['mask']
sw_loss = tf.reduce_sum(binary_mask(mask))
weights_number = tf.size(mask)
sparsified_weights_number = weights_number - tf.cast(sw_loss, tf.int32)
mask_names.append(wrapped_layer.name + '_rb_mask')
weights_shapes.append(list(mask.shape))
weights_numbers.append(weights_number)
sparsity_levels.append(sparsified_weights_number / weights_number)
sparse_prob_sum += tf.math.reduce_sum(tf.math.sigmoid(mask))
total_weights_number += weights_number
total_sparsified_weights_number += sparsified_weights_number

sparsity_rate_for_sparsified_modules = (total_sparsified_weights_number / total_weights_number).numpy()
model_weights_number = count_params(self._model.weights) - total_weights_number
sparsity_rate_for_model = (total_sparsified_weights_number / model_weights_number).numpy()
mean_sparse_prob = 1.0 - (sparse_prob_sum / tf.cast(total_weights_number, tf.float32)).numpy()

sparsity_levels = tf.keras.backend.batch_get_value(sparsity_levels)
weights_percentages = [weights_number / total_weights_number * 100
for weights_number in weights_numbers]
weights_percentages = tf.keras.backend.batch_get_value(weights_percentages)
mask_sparsity = list(zip(mask_names, weights_shapes, sparsity_levels, weights_percentages))

sparsified_layers_summary = []
for mask_name, weights_shape, sparsity_level, weights_percentage in mask_sparsity:
sparsified_layers_summary.append(
SparsifiedLayerSummary(mask_name, weights_shape, sparsity_level, weights_percentage)
)

model_statistics = SparsifiedModelStatistics(sparsity_rate_for_model,
sparsity_rate_for_sparsified_modules,
sparsified_layers_summary)
collector = TFSparseModelStatisticsCollector(self.model, self._op_names)
model_stats = collector.collect()

sparse_prob_sum = 0.0
num_weights = 0
for wrapped_layer, _, op in get_nncf_operations(self.model, self._op_names):
operation_weights = wrapped_layer.get_operation_weights(op.name)
mask = op.get_mask(operation_weights)
sparse_prob_sum += tf.math.reduce_sum(tf.math.sigmoid(mask)).numpy().item()
num_weights += np.prod(mask.shape.as_list()).item()
mean_sparse_prob = 1.0 - (sparse_prob_sum / num_weights)

target_level = self.loss.target_sparsity_rate
# TODO(andrey-churkin): Should be calculated when the distributed mode will be supported
masks_consistency = 1.0

stats = RBSparsityStatistics(model_statistics, masks_consistency, target_level, mean_sparse_prob)
stats = RBSparsityStatistics(model_stats, masks_consistency, target_level, mean_sparse_prob)

nncf_stats = NNCFStatistics()
nncf_stats.register('rb_sparsity', stats)
Expand Down
10 changes: 10 additions & 0 deletions beta/nncf/tensorflow/sparsity/rb/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ def get_mask(op_weights):
"""
return op_weights['mask']

@staticmethod
def get_binary_mask(op_weights):
"""
Returns binary mask from weights of the operation.
:param op_weights: Weights of the operaton.
:return: Binary mask.
"""
return binary_mask(op_weights['mask'])

@staticmethod
def get_trainable_weight(op_weights):
"""
Expand Down
Loading

0 comments on commit 222fd33

Please sign in to comment.