Refactor NNCF compression statistics reporting API #688

Merged: 14 commits (May 24, 2021)
Refactoring
andrey-churkin committed May 23, 2021
commit 38708d30c64043bbe7d673cbace26503460fcfe7
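
This commit replaces the per-sample as_dict() filtering with a single shared helper, nncf.common.utils.tensorboard.prepare_for_tensorboard, which flattens a Statistics object into scalar entries that TensorBoard can plot. The helper's implementation is not part of this diff, so what follows is only a minimal sketch of the idea, assuming it recursively collects numeric attributes under namespaced keys; the traversal strategy is an illustrative assumption, not NNCF's actual code.

from numbers import Number
from typing import Any, Dict


def prepare_for_tensorboard(statistics: Any) -> Dict[str, float]:
    """Flatten a statistics object into scalar TensorBoard entries (sketch)."""
    flat = {}
    for name, value in vars(statistics).items():
        if isinstance(value, Number):
            # Keep only scalar metrics; TensorBoard cannot plot nested objects.
            flat[name] = float(value)
        elif hasattr(value, '__dict__'):
            # Recurse into nested statistics objects and namespace their keys.
            for sub_name, sub_value in prepare_for_tensorboard(value).items():
                flat[f'{name}/{sub_name}'] = sub_value
    return flat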
7 changes: 4 additions & 3 deletions beta/examples/tensorflow/object_detection/main.py
@@ -18,6 +18,7 @@
 import tensorflow as tf
 import numpy as np

+from nncf.common.utils.tensorboard import prepare_for_tensorboard
 from beta.nncf import create_compressed_model
 from beta.nncf.tensorflow.helpers.model_manager import TFOriginalModelManager

@@ -210,9 +211,9 @@ def train(train_step, test_step, eval_metric, train_dist_dataset, test_dist_data

         statistics = compression_ctrl.statistics()
         logger.info(statistics.as_str())
-        statistics = {'compression/statistics/' + key: value
-                      for key, value in statistics.as_dict().items()
-                      if isinstance(value, (int, float))}
+        statistics = {
+            f'compression/statistics/{name}': value for name, value in prepare_for_tensorboard(statistics).items()
+        }
         compression_summary_writer(metrics=statistics,
                                    step=optimizer.iterations.numpy())

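For context, both TensorFlow samples in this commit wrap scalar publishing in a compression_summary_writer helper; a hypothetical standalone equivalent using a plain tf.summary writer (the log directory and step source are assumed) could look like this:

import tensorflow as tf

from nncf.common.utils.tensorboard import prepare_for_tensorboard

writer = tf.summary.create_file_writer('/tmp/compression_logs')  # log directory assumed


def log_compression_statistics(compression_ctrl, step):
    # Flatten the controller's statistics and publish each scalar metric.
    statistics = compression_ctrl.statistics()
    with writer.as_default():
        for name, value in prepare_for_tensorboard(statistics).items():
            tf.summary.scalar(f'compression/statistics/{name}', value, step=step)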
7 changes: 4 additions & 3 deletions beta/examples/tensorflow/segmentation/train.py
@@ -18,6 +18,7 @@
 import tensorflow as tf
 import numpy as np

+from nncf.common.utils.tensorboard import prepare_for_tensorboard
 from beta.nncf import create_compressed_model
 from beta.nncf.tensorflow.helpers.model_manager import TFOriginalModelManager

@@ -196,9 +197,9 @@ def train(train_step, train_dist_dataset, initial_epoch, initial_step,

         statistics = compression_ctrl.statistics()
         logger.info(statistics.as_str())
-        statistics = {'compression/statistics/' + key: value
-                      for key, value in statistics.as_dict().items()
-                      if isinstance(value, (int, float))}
+        statistics = {
+            f'compression/statistics/{name}': value for name, value in prepare_for_tensorboard(statistics).items()
+        }
         compression_summary_writer(metrics=statistics,
                                    step=optimizer.iterations.numpy())

2 changes: 1 addition & 1 deletion beta/nncf/tensorflow/pruning/callbacks.py
@@ -26,7 +26,7 @@ def _prepare_for_tensorboard(self, statistics: FilterPruningStatistics):

         ms = statistics.model_statistics  # type: SparsifiedModelStatistics
         tensorboard_statistics = {
-            f'{base_prefix}/pruning_level': ms.pruning_level,
+            f'{base_prefix}/pruning_level_for_model': ms.pruning_level,
         }

         for ls in ms.pruned_layers_summary:
4 changes: 2 additions & 2 deletions beta/nncf/tensorflow/sparsity/callbacks.py
@@ -42,8 +42,8 @@ def _prepare_for_tensorboard(self, statistics: Union[MagnitudeSparsityStatistics

         ms = statistics.model_statistics  # type: SparsifiedModelStatistics
         tensorboard_statistics = {
-            f'{base_prefix}/sparsity_level': ms.sparsity_level,
-            f'{base_prefix}/sparsity_level_for_layers': ms.sparsity_level_for_layers,
+            f'{base_prefix}/sparsity_level_for_model': ms.sparsity_level,
+            f'{base_prefix}/sparsity_level_for_sparsified_layers': ms.sparsity_level_for_layers,
         }

         for ls in ms.sparsified_layers_summary:
2 changes: 1 addition & 1 deletion beta/tests/tensorflow/sparsity/rb/test_integration.py
@@ -148,7 +148,7 @@ def test_rb_sparse_target_lenet(distributed, quantized):
     class SparsityRateTestCallback(tf.keras.callbacks.Callback):
         def on_epoch_end(self, epoch, logs=None):
             target = sparse_algo.loss.target_sparsity_rate
-            actual = compress_algo.statistics().model_statistics.sparsity_level_for_layers
+            actual = sparse_algo.statistics().model_statistics.sparsity_level_for_layers
             print(f'target {target}, actual {actual}')
             if epoch + 1 <= freeze_epoch:
                 assert abs(actual - target) < 0.05
7 changes: 4 additions & 3 deletions examples/classification/main.py
@@ -35,6 +35,7 @@
 from torchvision.datasets import CIFAR10, CIFAR100
 from torchvision.models import InceptionOutputs

+from nncf.common.utils.tensorboard import prepare_for_tensorboard
 from examples.common.argparser import get_common_argument_parser
 from examples.common.example_logger import logger
 from examples.common.execution import ExecutionMode, get_execution_mode, \
@@ -457,9 +458,9 @@ def train_epoch(train_loader, model, criterion, criterion_fn, optimizer, compres
         config.tb.add_scalar("train/top1", top1.avg, i + global_step)
         config.tb.add_scalar("train/top5", top5.avg, i + global_step)

-        for stat_name, stat_value in compression_ctrl.statistics(quickly_collected_only=True).as_dict().items():
-            if isinstance(stat_value, (int, float)):
-                config.tb.add_scalar('train/statistics/{}'.format(stat_name), stat_value, i + global_step)
+        statistics = compression_ctrl.statistics(quickly_collected_only=True)
+        for stat_name, stat_value in prepare_for_tensorboard(statistics).items():
+            config.tb.add_scalar('train/statistics/{}'.format(stat_name), stat_value, i + global_step)


 def validate(val_loader, model, criterion, config):
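The PyTorch samples route the same flattened statistics through config.tb; a hypothetical standalone equivalent with an explicit torch.utils.tensorboard.SummaryWriter (the log directory, compression_ctrl, and global_step are assumed) might read:

from torch.utils.tensorboard import SummaryWriter

from nncf.common.utils.tensorboard import prepare_for_tensorboard

tb = SummaryWriter('/tmp/compression_logs')  # log directory assumed

# quickly_collected_only requests only cheap-to-compute statistics (per its name).
statistics = compression_ctrl.statistics(quickly_collected_only=True)
for stat_name, stat_value in prepare_for_tensorboard(statistics).items():
    tb.add_scalar(f'train/statistics/{stat_name}', stat_value, global_step)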
7 changes: 4 additions & 3 deletions examples/classification/staged_quantization_worker.py
@@ -24,6 +24,7 @@
 import torch.utils.data.distributed
 from torchvision.models import InceptionOutputs

+from nncf.common.utils.tensorboard import prepare_for_tensorboard
 from examples.classification.main import create_data_loaders, validate, AverageMeter, accuracy, get_lr, \
     create_datasets, inception_criterion_fn
 from examples.common.example_logger import logger
@@ -363,9 +364,9 @@ def train_epoch_staged(train_loader, batch_multiplier, model, criterion, criteri
         config.tb.add_scalar("train/top1", top1.avg, i + global_step)
         config.tb.add_scalar("train/top5", top5.avg, i + global_step)

-        for stat_name, stat_value in compression_ctrl.statistics(quickly_collected_only=True).as_dict().items():
-            if isinstance(stat_value, (int, float)):
-                config.tb.add_scalar('train/statistics/{}'.format(stat_name), stat_value, i + global_step)
+        statistics = compression_ctrl.statistics(quickly_collected_only=True)
+        for stat_name, stat_value in prepare_for_tensorboard(statistics).items():
+            config.tb.add_scalar('train/statistics/{}'.format(stat_name), stat_value, i + global_step)


 def get_wd(optimizer):
7 changes: 4 additions & 3 deletions examples/semantic_segmentation/main.py
@@ -27,6 +27,7 @@
 from examples.common.sample_config import create_sample_config
 from torch.optim.lr_scheduler import ReduceLROnPlateau

+from nncf.common.utils.tensorboard import prepare_for_tensorboard
 import examples.semantic_segmentation.utils.data as data_utils
 import examples.semantic_segmentation.utils.loss_funcs as loss_funcs
 import examples.semantic_segmentation.utils.transforms as JT
@@ -344,9 +345,9 @@ def train(model, model_without_dp, compression_ctrl, train_loader, val_loader, c
         config.tb.add_scalar("train/learning_rate", optimizer.param_groups[0]['lr'], epoch)
         config.tb.add_scalar("train/compression_loss", compression_ctrl.loss(), epoch)

-        for key, value in compression_ctrl.statistics(quickly_collected_only=True).as_dict().items():
-            if isinstance(value, (int, float)):
-                config.tb.add_scalar("compression/statistics/{0}".format(key), value, epoch)
+        statistics = compression_ctrl.statistics(quickly_collected_only=True)
+        for key, value in prepare_for_tensorboard(statistics).items():
+            config.tb.add_scalar("compression/statistics/{0}".format(key), value, epoch)

         if (epoch + 1) % config.save_freq == 0 or epoch + 1 == config.epochs:
             logger.info(">>>> [Epoch: {0:d}] Validation".format(epoch))
9 changes: 0 additions & 9 deletions nncf/api/composite_compression.py
@@ -55,15 +55,6 @@ def as_str(self) -> str:
         pretty_string = '\n\n'.join([stats.as_str() for stats in self.child_statistics])
         return pretty_string

-    def as_dict(self) -> Dict[str, Any]:
-        """
-        Calls as_dict() method for all children and returns a sum-up dictionary.
-        """
-        stats = {}
-        for statistics in self.child_statistics:
-            stats.update(statistics.as_dict())
-        return stats


 class CompositeCompressionLoss(CompressionLoss):
     """
8 changes: 0 additions & 8 deletions nncf/api/compression.py
@@ -34,14 +34,6 @@ def as_str(self) -> str:
         :return: A representation of the statistics as a human-readable string.
         """

-    @abstractmethod
-    def as_dict(self) -> Dict[str, Any]:
-        """
-        Returns a representation of the statistics as built-in data types.
-
-        :return: A representation of the statistics as built-in data types.
-        """


 class CompressionLoss(ABC):
     """
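With as_dict() removed in the two API files above, the Statistics base class keeps only its human-readable accessor. A condensed view of what remains in nncf/api/compression.py, reconstructed from the hunks above (the class docstring is assumed):

from abc import ABC, abstractmethod


class Statistics(ABC):
    """Contains the statistics collected by a compression algorithm."""  # docstring assumed

    @abstractmethod
    def as_str(self) -> str:
        """
        Returns a representation of the statistics as a human-readable string.

        :return: A representation of the statistics as a human-readable string.
        """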
3 changes: 0 additions & 3 deletions nncf/common/compression.py
@@ -61,6 +61,3 @@ class is used when statistics are not calculated.

     def as_str(self) -> str:
         return ''
-
-    def as_dict(self) -> Dict[str, Any]:
-        return {}
37 changes: 1 addition & 36 deletions nncf/common/pruning/statistics.py
@@ -11,7 +11,7 @@
 limitations under the License.
 """

-from typing import List, Dict, Any
+from typing import List

 from nncf.api.compression import Statistics
 from nncf.common.utils.helpers import create_table
@@ -46,22 +46,6 @@ def __init__(self,
         self.mask_pruning_level = mask_pruning_level
         self.filter_pruning_level = filter_pruning_level

-    def as_dict(self) -> Dict[str, Any]:
-        """
-        Returns a representation of the pruned layer's summary as built-in data types.
-
-        :return: A representation of the sparsified layer's summary as built-in data types.
-        """
-        summary = {
-            'name': self.name,
-            'weight_shape': self.weight_shape,
-            'mask_shape': self.mask_shape,
-            'weight_pruning_level': self.weight_pruning_level,
-            'mask_pruning_level': self.mask_pruning_level,
-            'filter_pruning_level': self.filter_pruning_level,
-        }
-        return summary


 class PrunedModelStatistics(Statistics):
     """
@@ -107,13 +91,6 @@ def as_str(self) -> str:

         return pretty_string

-    def as_dict(self) -> Dict[str, Any]:
-        stats = {
-            'pruning_level': self.pruning_level,
-            'pruned_layers_summary': [s.as_dict() for s in self.pruned_layers_summary],
-        }
-        return stats


 class FilterPruningStatistics(Statistics):
     """
@@ -150,15 +127,3 @@ def as_str(self) -> str:
             f'Statistics of the filter pruning algorithm:\n{algorithm_string}'
         )
         return pretty_string

-    def as_dict(self) -> Dict[str, Any]:
-        algorithm = 'filter_pruning'
-        model_statistics = self.model_statistics.as_dict()
-        stats = {
-            f'{algorithm}/pruning_level_for_model': model_statistics['pruning_level'],
-            f'{algorithm}/pruning_statistic_by_layer': model_statistics['pruned_layers_summary'],
-            f'{algorithm}/flops_pruning_level': self.flops_pruning_level,
-            f'{algorithm}/full_flops': self.full_flops,
-            f'{algorithm}/current_flops': self.current_flops,
-        }
-        return stats
58 changes: 1 addition & 57 deletions nncf/common/quantization/statistics.py
@@ -11,7 +11,7 @@
 limitations under the License.
 """

-from typing import Dict, Any, Optional
+from typing import Dict, Optional

 from nncf.api.compression import Statistics
 from nncf.common.utils.helpers import create_table
@@ -71,16 +71,6 @@ def as_str(self) -> str:
         pretty_string = f'Statistics of the memory consumption:\n{memory_consumption_string}'
         return pretty_string

-    def as_dict(self) -> Dict[str, Any]:
-        stats = {
-            'fp32_weight_size': self.fp32_weight_size,
-            'quantized_weight_size': self.quantized_weight_size,
-            'max_fp32_activation_size': self.max_fp32_activation_size,
-            'max_compressed_activation_size': self.max_compressed_activation_size,
-            'weight_memory_consumption_decrease': self.weight_memory_consumption_decrease,
-        }
-        return stats


 class QuantizersCounter:
     def __init__(self,
@@ -107,16 +97,6 @@ def __init__(self,
         self.num_per_tensor = num_per_tensor
         self.num_per_channel = num_per_channel

-    def as_dict(self) -> Dict[str, Any]:
-        return {
-            'num_symmetric': self.num_symmetric,
-            'num_asymmetric': self.num_asymmetric,
-            'num_signed': self.num_signed,
-            'num_unsigned': self.num_unsigned,
-            'num_per_tensor': self.num_per_tensor,
-            'num_per_channel': self.num_per_channel,
-        }


 class QuantizationShareStatistics(Statistics):
     """
@@ -177,17 +157,6 @@ def as_str(self) -> str:
         pretty_string = f'Statistics of the quantization share:\n{qshare_string}'
         return pretty_string

-    def as_dict(self) -> Dict[str, Any]:
-        stats = {
-            'wq_total_num': self.wq_total_num,
-            'aq_total_num': self.aq_total_num,
-            'wq_potential_num': self.wq_potential_num,
-            'aq_potential_num': self.aq_potential_num,
-            'wq_counter': self.wq_counter.as_dict(),
-            'aq_counter': self.aq_counter.as_dict(),
-        }
-        return stats


 class BitwidthDistributionStatistics(Statistics):
     """
@@ -233,13 +202,6 @@ def as_str(self) -> str:
         pretty_string = f'Statistics of the bitwidth distribution:\n{distribution_string}'
         return pretty_string

-    def as_dict(self) -> Dict[str, Any]:
-        stats = {
-            'num_wq_per_bitwidth': self.num_wq_per_bitwidth,
-            'num_aq_per_bitwidth': self.num_aq_per_bitwidth,
-        }
-        return stats


 class QuantizationConfigurationStatistics(Statistics):
     """
@@ -268,13 +230,6 @@ def as_str(self) -> str:
         pretty_string = f'Statistics of the quantization configuration:\n{qc_string}'
         return pretty_string

-    def as_dict(self) -> Dict[str, Any]:
-        stats = {
-            'quantized_edges_in_cfg': self.quantized_edges_in_cfg,
-            'total_edges_in_cfg': self.total_edges_in_cfg,
-        }
-        return stats


 class QuantizationStatistics(Statistics):
     """
@@ -325,14 +280,3 @@ def as_str(self) -> str:
             '\n\n'.join(pretty_strings)
         )
         return pretty_string

-    def as_dict(self) -> Dict[str, Any]:
-        algorithm = 'quantization'
-        stats = {
-            f'{algorithm}/ratio_of_enabled_quantizations': self.ratio_of_enabled_quantizations,
-            f'{algorithm}/quantization_share_statistics': self.quantization_share_statistics.as_dict(),
-            f'{algorithm}/bitwidth_distribution_statistics': self.bitwidth_distribution_statistics.as_dict(),
-            f'{algorithm}/memory_consumption_statistics': self.memory_consumption_statistics.as_dict(),
-            f'{algorithm}/quantization_configuration_statistics': self.quantization_configuration_statistics.as_dict(),
-        }
-        return stats