diff --git a/examples/classification/main.py b/examples/classification/main.py
index 2385596a5d3..5cf5ea51ed6 100644
--- a/examples/classification/main.py
+++ b/examples/classification/main.py
@@ -13,6 +13,7 @@
 import os.path as osp
 import sys
 import time
+from copy import deepcopy
 from pathlib import Path
 from typing import Any
 
@@ -347,9 +348,12 @@ def create_train_data_loader(batch_size_):
 
     train_loader = create_train_data_loader(batch_size)
 
-    init_loader = train_loader
-    if config.batch_size_init and config.batch_size_init != config.batch_size:
+    if config.batch_size_init:
         init_loader = create_train_data_loader(config.batch_size_init)
+    else:
+        init_loader = deepcopy(train_loader)
+    if config.distributed:
+        init_loader.num_workers = 0  # PyTorch multiprocessing dataloader issue WA
 
     return train_loader, train_sampler, val_loader, init_loader
 
diff --git a/examples/object_detection/main.py b/examples/object_detection/main.py
index 5e809f6cba8..20e58dc8603 100644
--- a/examples/object_detection/main.py
+++ b/examples/object_detection/main.py
@@ -14,6 +14,7 @@
 import os.path as osp
 import sys
 import time
+from copy import deepcopy
 from pathlib import Path
 
 import torch
@@ -224,9 +225,12 @@ def create_train_data_loader(batch_size):
         )
 
     train_data_loader = create_train_data_loader(config.batch_size)
-    init_data_loader = train_data_loader
     if config.batch_size_init:
         init_data_loader = create_train_data_loader(config.batch_size_init)
+    else:
+        init_data_loader = deepcopy(train_data_loader)
+    if config.distributed:
+        init_data_loader.num_workers = 0  # PyTorch multiprocessing dataloader issue WA
 
     test_dataset = get_testing_dataset(config.dataset, config.test_anno, config.test_imgs, config)
     logger.info("Loaded {} testing images".format(len(test_dataset)))
diff --git a/examples/semantic_segmentation/configs/unet_mapillary_magnitude_sparsity_int8.json b/examples/semantic_segmentation/configs/unet_mapillary_magnitude_sparsity_int8.json
index b3d7433ab7c..8cf83893c06 100644
--- a/examples/semantic_segmentation/configs/unet_mapillary_magnitude_sparsity_int8.json
+++ b/examples/semantic_segmentation/configs/unet_mapillary_magnitude_sparsity_int8.json
@@ -1,6 +1,7 @@
 {
     "model": "unet",
     "dataset" : "mapillary",
+    "batch_size_init": 2,
     "preprocessing": {
         "resize": {
             "height": 512,
diff --git a/examples/semantic_segmentation/main.py b/examples/semantic_segmentation/main.py
index 25aa53fd040..01abb6300f1 100644
--- a/examples/semantic_segmentation/main.py
+++ b/examples/semantic_segmentation/main.py
@@ -16,6 +16,7 @@
 # https://github.com/pytorch/vision/tree/master/references/segmentation
 
 import sys
+from copy import deepcopy
 from os import path as osp
 
 import functools
@@ -183,9 +184,12 @@ def create_train_data_loader(batch_size_):
             collate_fn=data_utils.collate_fn, drop_last=True)
     # Loaders
     train_loader = create_train_data_loader(batch_size)
-    init_loader = train_loader
     if config.batch_size_init:
         init_loader = create_train_data_loader(config.batch_size_init)
+    else:
+        init_loader = deepcopy(train_loader)
+    if config.distributed:
+        init_loader.num_workers = 0  # PyTorch multiprocessing dataloader issue WA
 
     val_sampler = torch.utils.data.SequentialSampler(val_set)
     val_loader = torch.utils.data.DataLoader(
diff --git a/nncf/pruning/filter_pruning/algo.py b/nncf/pruning/filter_pruning/algo.py
index ca20656555b..eb7e0cf4d43 100644
--- a/nncf/pruning/filter_pruning/algo.py
+++ b/nncf/pruning/filter_pruning/algo.py
@@ -238,7 +238,7 @@ def _find_layerwise_pruning_rate(self, target_flops_pruning_rate):
                 return right
         raise RuntimeError("Can't prune model to asked flops pruning rate = {}".format(target_flops_pruning_rate))
 
-    def set_pruning_rate(self, pruning_rate):
+    def set_pruning_rate(self, pruning_rate, run_batchnorm_adaptation=False):
         # Pruning rate from scheduler can be flops pruning rate or percentage of params that should be pruned
         self.pruning_rate = pruning_rate
         if not self.frozen:
@@ -257,7 +257,8 @@ def set_pruning_rate(self, pruning_rate):
             if self.zero_grad:
                 self.zero_grads_for_pruned_modules()
         self._apply_masks()
-        self.run_batchnorm_adaptation(self.config)
+        if run_batchnorm_adaptation:
+            self.run_batchnorm_adaptation(self.config)
 
     def _set_binary_masks_for_filters(self, pruning_rate):
         nncf_logger.debug("Setting new binary masks for pruned modules.")
diff --git a/nncf/sparsity/magnitude/algo.py b/nncf/sparsity/magnitude/algo.py
index a3cc5e142d2..4c454973a98 100644
--- a/nncf/sparsity/magnitude/algo.py
+++ b/nncf/sparsity/magnitude/algo.py
@@ -81,7 +81,9 @@ def freeze(self):
         for layer in self.sparsified_module_info:
             layer.operand.frozen = True
 
-    def set_sparsity_level(self, sparsity_level, target_sparsified_module_info: SparseModuleInfo = None):
+    def set_sparsity_level(self, sparsity_level,
+                           target_sparsified_module_info: SparseModuleInfo = None,
+                           run_batchnorm_adaptation: bool = False):
         if sparsity_level >= 1 or sparsity_level < 0:
             raise AttributeError(
                 'Sparsity level should be within interval [0,1), actual value to set is: {}'.format(sparsity_level))
@@ -91,7 +93,8 @@ def set_sparsity_level(self, sparsity_level, target_sparsified_module_info: Spar
             target_sparsified_module_info_list = [target_sparsified_module_info]
         threshold = self._select_threshold(sparsity_level, target_sparsified_module_info_list)
         self._set_masks_for_threshold(threshold, target_sparsified_module_info_list)
-        self.run_batchnorm_adaptation(self.config)
+        if run_batchnorm_adaptation:
+            self.run_batchnorm_adaptation(self.config)
 
     def _select_threshold(self, sparsity_level, target_sparsified_module_info_list):
         all_weights = self._collect_all_weights(target_sparsified_module_info_list)
diff --git a/tests/test_sota_checkpoints.py b/tests/test_sota_checkpoints.py
index 660b4378f22..bbe043787c6 100644
--- a/tests/test_sota_checkpoints.py
+++ b/tests/test_sota_checkpoints.py
@@ -201,7 +201,7 @@ def write_results_table(self, init_table_string):
         f.close()
 
     @staticmethod
-    def threshold_check(err, diff_target, diff_fp32_min_=None, diff_fp32_max_=None, fp32_metric=None,
+    def threshold_check(is_ok, diff_target, diff_fp32_min_=None, diff_fp32_max_=None, fp32_metric=None,
                         diff_fp32=None, diff_target_min=None, diff_target_max=None):
         color = BG_COLOR_RED_HEX
         within_thresholds = False
@@ -213,7 +213,7 @@ def threshold_check(err, diff_target, diff_fp32_min_=None, diff_fp32_max_=None,
             diff_fp32_min_ = DIFF_FP32_MIN_GLOBAL
         if not diff_fp32_max_:
             diff_fp32_max_ = DIFF_FP32_MAX_GLOBAL
-        if err is None:
+        if is_ok:
             if fp32_metric is not None:
                 if diff_fp32_min_ < diff_fp32 < diff_fp32_max_ and diff_target_min < diff_target < diff_target_max:
                     color = BG_COLOR_GREEN_HEX
@@ -335,7 +335,7 @@ def test_eval(self, sota_checkpoints_dir, sota_data_dir, eval_test_struct: EvalR
             cmd += " -b {}".format(eval_test_struct.batch_)
         exit_code, err_str = self.run_cmd(cmd, cwd=PROJECT_ROOT)
 
-        is_ok = (exit_code == 0 and err_str is None)
+        is_ok = (exit_code == 0 and metrics_dump_file_path.exists())
         if is_ok:
             metric_value = self.read_metric(str(metrics_dump_file_path))
         else:
@@ -366,7 +366,7 @@ def test_eval(self, sota_checkpoints_dir, sota_data_dir, eval_test_struct: EvalR
                                                          diff_target, fp32_metric, diff_fp32)
-        retval = self.threshold_check(err_str,
+        retval = self.threshold_check(is_ok,
                                       diff_target,
                                       eval_test_struct.diff_fp32_min_,
                                       eval_test_struct.diff_fp32_max_,
@@ -444,7 +444,8 @@ def test_train(self, sota_data_dir, config_name_, expected_, metric_type_, datas
         diff_target = round((metric_value - expected_), 2)
         self.row_dict[model_name_] = self.make_table_row(test, expected_, metric_type_, model_name_, err_str,
                                                          metric_value, diff_target)
-        self.color_dict[model_name_], is_accuracy_within_thresholds = self.threshold_check(err_str, diff_target)
+        self.color_dict[model_name_], is_accuracy_within_thresholds = self.threshold_check(err_str is None,
+                                                                                           diff_target)
         assert is_accuracy_within_thresholds