Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
andrey-churkin committed May 20, 2021
1 parent ac19dc6 commit b5398b0
Show file tree
Hide file tree
Showing 54 changed files with 380 additions and 264 deletions.
2 changes: 1 addition & 1 deletion beta/examples/tensorflow/classification/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def run(config):

logger.info('evaluation...')
statistics = compression_ctrl.statistics()
logger.info(statistics.as_str())
logger.info(statistics.to_str())
results = compress_model.evaluate(
validation_dataset,
steps=validation_steps,
Expand Down
4 changes: 2 additions & 2 deletions beta/examples/tensorflow/object_detection/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def train(train_step, test_step, eval_metric, train_dist_dataset, test_dist_data
logger.info('Validation metric = {}'.format(test_metric_result))

statistics = compression_ctrl.statistics()
logger.info(statistics.as_str())
logger.info(statistics.to_str())
statistics = {
f'compression/statistics/{name}': value for name, value in prepare_for_tensorboard(statistics).items()
}
Expand Down Expand Up @@ -311,7 +311,7 @@ def run(config):
config.print_freq)

statistics = compression_ctrl.statistics()
logger.info(statistics.as_str())
logger.info(statistics.to_str())
metric_result = evaluate(test_step, eval_metric, test_dist_dataset, num_test_batches, config.print_freq)
logger.info('Validation metric = {}'.format(metric_result))

Expand Down
2 changes: 1 addition & 1 deletion beta/examples/tensorflow/segmentation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def run_evaluation(config, eval_timeout=None):
load_checkpoint(checkpoint, config.ckpt_path)

statistics = compression_ctrl.statistics()
logger.info(statistics.as_str())
logger.info(statistics.to_str())
metric_result = evaluate(test_step, eval_metric, test_dist_dataset, num_batches, config.print_freq)
eval_metric.reset_states()
logger.info('Test metric = {}'.format(metric_result))
Expand Down
4 changes: 2 additions & 2 deletions beta/examples/tensorflow/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def train(train_step, train_dist_dataset, initial_epoch, initial_step,
timer.tic()

statistics = compression_ctrl.statistics()
logger.info(statistics.as_str())
logger.info(statistics.to_str())
statistics = {
f'compression/statistics/{name}': value for name, value in prepare_for_tensorboard(statistics).items()
}
Expand Down Expand Up @@ -267,7 +267,7 @@ def run_train(config):

logger.info('Compression statistics')
statistics = compression_ctrl.statistics()
logger.info(statistics.as_str())
logger.info(statistics.to_str())


def main(argv):
Expand Down
6 changes: 3 additions & 3 deletions beta/nncf/tensorflow/algorithm_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from nncf.common.schedulers import StubCompressionScheduler
from nncf.common.utils.logger import logger
from nncf.common.utils.registry import Registry
from nncf.common.compression import StubStatistics
from nncf.common.statistics import NNCFStatistics
from beta.nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder
from beta.nncf.tensorflow.api.compression import TFCompressionAlgorithmController
from beta.nncf.tensorflow.loss import TFZeroCompressionLoss
Expand Down Expand Up @@ -47,8 +47,8 @@ def loss(self) -> TFZeroCompressionLoss:
def scheduler(self) -> StubCompressionScheduler:
return self._scheduler

def statistics(self, quickly_collected_only: bool = False) -> StubStatistics:
return StubStatistics()
def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
return NNCFStatistics()


def get_compression_algorithm_builder(config):
Expand Down
4 changes: 2 additions & 2 deletions beta/nncf/tensorflow/api/composite_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from typing import List, Optional, TypeVar

from nncf import NNCFConfig
from nncf.api.composite_compression import CompositeCompressionAlgorithmBuilder
from nncf.api.composite_compression import CompositeCompressionAlgorithmController
from nncf.common.composite_compression import CompositeCompressionAlgorithmBuilder
from nncf.common.composite_compression import CompositeCompressionAlgorithmController
from beta.nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder
from beta.nncf.tensorflow.api.compression import TFCompressionAlgorithmController
from beta.nncf.tensorflow.graph.transformations.layout import TFTransformationLayout
Expand Down
14 changes: 7 additions & 7 deletions beta/nncf/tensorflow/callbacks/statistics_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import tensorflow as tf

from nncf.api.compression import Statistics
from nncf.common.statistics import NNCFStatistics
from nncf.common.utils.logger import logger as nncf_logger


Expand All @@ -25,14 +25,14 @@ class StatisticsCallback(tf.keras.callbacks.Callback):
"""

def __init__(self,
statistics_fn: Callable[[], Statistics],
statistics_fn: Callable[[], NNCFStatistics],
log_tensorboard: bool = True,
log_text: bool = True,
log_dir: str = None):
"""
Initializes compression statistics callback.
:param statistics_fn: A callable object that provides compression statistics.
:param statistics_fn: A callable object that provides NNCF statistics.
:param log_tensorboard: Whether to log statistics to tensorboard or not.
:param log_text: Whether to log statistics to stdout.
:param log_dir: The directory for tensorbard logging.
Expand All @@ -54,17 +54,17 @@ def _dump_to_tensorboard(self, logs: dict, step: int):
tf.summary.scalar(name, value, step=step)

def on_epoch_end(self, epoch: int, logs: dict = None):
statistics = self._statistics_fn()
nncf_stats = self._statistics_fn()
if self._log_tensorboard:
self._dump_to_tensorboard(self._prepare_for_tensorboard(statistics),
self._dump_to_tensorboard(self._prepare_for_tensorboard(nncf_stats),
self.model.optimizer.iterations.numpy())
if self._log_text:
nncf_logger.info(statistics.as_str())
nncf_logger.info(nncf_stats.to_str())

def on_train_end(self, logs: dict = None):
if self._file_writer:
self._file_writer.close()

def _prepare_for_tensorboard(self, statistics: Statistics):
def _prepare_for_tensorboard(self, stats: NNCFStatistics):
raise NotImplementedError(
'StatisticsCallback class implementation must override the _prepare_for_tensorboard method.')
2 changes: 1 addition & 1 deletion beta/nncf/tensorflow/helpers/callback_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
limitations under the License.
"""

from nncf.api.composite_compression import CompositeCompressionAlgorithmController
from nncf.common.composite_compression import CompositeCompressionAlgorithmController
from beta.nncf.tensorflow.pruning.base_algorithm import BasePruningAlgoController
from beta.nncf.tensorflow.pruning.callbacks import PruningStatisticsCallback
from beta.nncf.tensorflow.sparsity.callbacks import SparsityStatisticsCallback
Expand Down
3 changes: 0 additions & 3 deletions beta/nncf/tensorflow/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ class TFZeroCompressionLoss(CompressionLoss):
def calculate(self, *args, **kwargs) -> Any:
return tf.constant(0.)

def statistics(self, quickly_collected_only: bool = False) -> Dict[str, object]:
return {}

def load_state(self, state: Dict[str, object]) -> None:
pass

Expand Down
2 changes: 1 addition & 1 deletion beta/nncf/tensorflow/pruning/base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def _check_pruning_rate(self, params):
if pruning_flops_target:
self.prune_flops = True

def statistics(self, quickly_collected_only: bool = False) -> PrunedModelStatistics:
def _calculate_pruned_model_stats(self) -> PrunedModelStatistics:
pruning_rates = []
mask_names = []
weights_shapes = []
Expand Down
12 changes: 6 additions & 6 deletions beta/nncf/tensorflow/pruning/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
limitations under the License.
"""

from nncf.common.pruning.statistics import FilterPruningStatistics
from nncf.common.statistics import NNCFStatistics
from beta.nncf.tensorflow.callbacks.statistics_callback import StatisticsCallback


Expand All @@ -20,17 +20,17 @@ class PruningStatisticsCallback(StatisticsCallback):
Callback for logging pruning compression statistics to tensorboard and stdout.
"""

def _prepare_for_tensorboard(self, statistics: FilterPruningStatistics):
def _prepare_for_tensorboard(self, nncf_stats: NNCFStatistics):
base_prefix = '2.compression/statistics'
detailed_prefix = '3.compression_details/statistics'

ms = statistics.model_statistics # type: SparsifiedModelStatistics
tensorboard_statistics = {
ms = nncf_stats.filter_pruning.model_statistics
tensorboard_stats = {
f'{base_prefix}/pruning_level_for_model': ms.pruning_level,
}

for ls in ms.pruned_layers_summary:
layer_name, pruning_level = ls.name, ls.filter_pruning_level
tensorboard_statistics[f'{detailed_prefix}/{layer_name}/pruning_level'] = pruning_level
tensorboard_stats[f'{detailed_prefix}/{layer_name}/pruning_level'] = pruning_level

return tensorboard_statistics
return tensorboard_stats
11 changes: 8 additions & 3 deletions beta/nncf/tensorflow/pruning/filter_pruning/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from nncf.common.pruning.utils import get_rounded_pruned_element_number
from nncf.common.utils.logger import logger as nncf_logger
from nncf.common.pruning.statistics import FilterPruningStatistics
from nncf.common.statistics import NNCFStatistics


@TF_COMPRESSION_ALGORITHMS.register('filter_pruning')
Expand Down Expand Up @@ -132,9 +133,13 @@ def scheduler(self) -> CompressionScheduler:
def loss(self) -> CompressionLoss:
return self._loss

def statistics(self, quickly_collected_only: bool = False) -> FilterPruningStatistics:
model_statistics = super().statistics(quickly_collected_only)
return FilterPruningStatistics(model_statistics, self.full_flops, self.current_flops)
def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
model_statistics = self._calculate_pruned_model_stats(quickly_collected_only)
stats = FilterPruningStatistics(model_statistics, self.full_flops, self.current_flops)

nncf_stats = NNCFStatistics()
nncf_stats.register('filter_pruning', stats)
return nncf_stats

def freeze(self):
self.frozen = True
Expand Down
6 changes: 3 additions & 3 deletions beta/nncf/tensorflow/quantization/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.quantization.structs import QuantizationMode
from nncf.common.quantization.structs import QuantizationConstraints
from nncf.common.compression import StubStatistics
from nncf.common.statistics import NNCFStatistics
from nncf.api.compression import CompressionScheduler
from nncf.api.compression import CompressionLoss

Expand Down Expand Up @@ -248,5 +248,5 @@ def loss(self) -> CompressionLoss:
def initialize(self, dataset=None, loss=None):
self._initializer(self._model, dataset, loss)

def statistics(self, quickly_collected_only: bool = False) -> StubStatistics:
return StubStatistics()
def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
return NNCFStatistics()
19 changes: 11 additions & 8 deletions beta/nncf/tensorflow/sparsity/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
limitations under the License.
"""

from typing import Union

import tensorflow as tf

from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics, RBSparsityStatistics
from nncf.common.statistics import NNCFStatistics
from beta.nncf.tensorflow.callbacks.statistics_callback import StatisticsCallback


Expand All @@ -36,18 +34,23 @@ class SparsityStatisticsCallback(StatisticsCallback):
Callback for logging sparsity compression statistics to tensorboard and stdout.
"""

def _prepare_for_tensorboard(self, statistics: Union[MagnitudeSparsityStatistics, RBSparsityStatistics]):
def _prepare_for_tensorboard(self, nncf_stats: NNCFStatistics):
base_prefix = '2.compression/statistics'
detailed_prefix = '3.compression_details/statistics'

ms = statistics.model_statistics # type: SparsifiedModelStatistics
tensorboard_statistics = {
if nncf_stats.magnitude_sparsity:
stats = nncf_stats.magnitude_sparsity
else:
stats = nncf_stats.rb_sparsity

ms = stats.model_statistics
tensorboard_stats = {
f'{base_prefix}/sparsity_level_for_model': ms.sparsity_level,
f'{base_prefix}/sparsity_level_for_sparsified_layers': ms.sparsity_level_for_layers,
}

for ls in ms.sparsified_layers_summary:
layer_name, sparsity_level = ls.name, ls.sparsity_level
tensorboard_statistics[f'{detailed_prefix}/{layer_name}/sparsity_level'] = sparsity_level
tensorboard_stats[f'{detailed_prefix}/{layer_name}/sparsity_level'] = sparsity_level

return tensorboard_statistics
return tensorboard_stats
9 changes: 7 additions & 2 deletions beta/nncf/tensorflow/sparsity/magnitude/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf.common.sparsity.statistics import SparsifiedModelStatistics
from nncf.common.sparsity.statistics import LayerThreshold
from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics
from nncf.common.statistics import NNCFStatistics
from nncf.api.compression import CompressionScheduler
from nncf.api.compression import CompressionLoss
from beta.nncf.tensorflow.algorithm_selector import TF_COMPRESSION_ALGORITHMS
Expand Down Expand Up @@ -187,7 +188,7 @@ def _collect_all_weights(self):
[-1]))
return all_weights

def statistics(self, quickly_collected_only: bool = False) -> MagnitudeSparsityStatistics:
def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
sparsity_levels = []
mask_names = []
weights_shapes = []
Expand Down Expand Up @@ -236,4 +237,8 @@ def statistics(self, quickly_collected_only: bool = False) -> MagnitudeSparsityS
sparsity_rate_for_sparsified_modules,
sparsified_layers_summary)

return MagnitudeSparsityStatistics(model_statistics, threshold_statistics)
stats = MagnitudeSparsityStatistics(model_statistics, threshold_statistics)

nncf_stats = NNCFStatistics()
nncf_stats.register('magnitude_sparsity', stats)
return nncf_stats
9 changes: 7 additions & 2 deletions beta/nncf/tensorflow/sparsity/rb/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nncf.common.sparsity.statistics import SparsifiedLayerSummary
from nncf.common.sparsity.statistics import SparsifiedModelStatistics
from nncf.common.sparsity.statistics import RBSparsityStatistics
from nncf.common.statistics import NNCFStatistics
from beta.nncf.tensorflow.algorithm_selector import TF_COMPRESSION_ALGORITHMS
from beta.nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder
from beta.nncf.tensorflow.graph.transformations.commands import TFInsertionCommand
Expand Down Expand Up @@ -121,7 +122,7 @@ def set_sparsity_level(self, sparsity_level):
def freeze(self):
self._loss.disable()

def statistics(self, quickly_collected_only: bool = False) -> RBSparsityStatistics:
def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
sparsity_levels = []
mask_names = []
weights_shapes = []
Expand Down Expand Up @@ -171,4 +172,8 @@ def statistics(self, quickly_collected_only: bool = False) -> RBSparsityStatisti
# TODO(andrey-churkin): Should be calculated when the distributed mode will be supported
masks_consistency = 1.0

return RBSparsityStatistics(model_statistics, masks_consistency, target_level, mean_sparse_prob)
stats = RBSparsityStatistics(model_statistics, masks_consistency, target_level, mean_sparse_prob)

nncf_stats = NNCFStatistics()
nncf_stats.register('rb_sparsity', stats)
return nncf_stats
4 changes: 0 additions & 4 deletions beta/nncf/tensorflow/sparsity/rb/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import tensorflow as tf

from nncf.api.compression import CompressionLoss
from nncf.common.compression import StubStatistics


class SparseLoss(CompressionLoss):
Expand Down Expand Up @@ -79,6 +78,3 @@ def get_state(self) -> Dict[str, object]:
'disabled': bool(tf.keras.backend.eval(tf.cast(self.disabled, tf.bool))),
'p': self.p
}

def statistics(self, quickly_collected_only: bool = False) -> StubStatistics:
return StubStatistics()
12 changes: 6 additions & 6 deletions beta/tests/tensorflow/sparsity/magnitude/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def test_magnitude_scheduler_can_do_epoch_step__with_norm():
for expected_level, expected_threshold in zip(expected_levels, expected_thresholds):
scheduler.epoch_step()
assert compression_ctrl.scheduler.current_sparsity_level == expected_level
stats = compression_ctrl.statistics()
for layer_info in stats.thresholds:
nncf_stats = compression_ctrl.statistics()
for layer_info in nncf_stats.magnitude_sparsity.thresholds:
assert layer_info.threshold == pytest.approx(expected_threshold, 0.01)


Expand All @@ -59,14 +59,14 @@ def test_magnitude_scheduler_can_do_epoch_step__with_last():

scheduler.epoch_step(3)
assert compression_ctrl.scheduler.current_sparsity_level == 0.9
stats = compression_ctrl.statistics()
for layer_info in stats.thresholds:
nncf_stats = compression_ctrl.statistics()
for layer_info in nncf_stats.magnitude_sparsity.thresholds:
assert layer_info.threshold == pytest.approx(0.371, 0.01)

scheduler.epoch_step()
assert compression_ctrl.scheduler.current_sparsity_level == 0.9
stats = compression_ctrl.statistics()
for layer_info in stats.thresholds:
nncf_stats = compression_ctrl.statistics()
for layer_info in nncf_stats.magnitude_sparsity.thresholds:
assert layer_info.threshold == pytest.approx(0.371, 0.01)


Expand Down
5 changes: 3 additions & 2 deletions beta/tests/tensorflow/sparsity/rb/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from beta.nncf import NNCFConfig
from beta.nncf import create_compressed_model
from nncf.api.composite_compression import CompositeCompressionAlgorithmController
from nncf.common.composite_compression import CompositeCompressionAlgorithmController
from beta.examples.tensorflow.common.callbacks import get_callbacks, get_progress_bar
from beta.nncf.helpers.callback_creation import create_compression_callbacks

Expand Down Expand Up @@ -148,7 +148,8 @@ def test_rb_sparse_target_lenet(distributed, quantized):
class SparsityRateTestCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
target = sparse_algo.loss.target_sparsity_rate
actual = sparse_algo.statistics().model_statistics.sparsity_level_for_layers
nncf_stats = sparse_algo.statistics()
actual = nncf_stats.rb_sparsity.model_statistics.sparsity_level_for_layers
print(f'target {target}, actual {actual}')
if epoch + 1 <= freeze_epoch:
assert abs(actual - target) < 0.05
Expand Down
Loading

0 comments on commit b5398b0

Please sign in to comment.