Add dump_checkpoint_fn interface into AA to control dumping checkpoints (openvinotoolkit#924)

This PR extends the current Accuracy-Aware interface and introduces a new optional user-defined function, `dump_checkpoint_fn(model, compression_controller, accuracy_aware_runner, save_dir)`, which allows a user to control the model's checkpoint-saving process. The training loop calls this function instead of its own `dump_checkpoint` function.
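
For illustration, a minimal sketch of a user-supplied `dump_checkpoint_fn` under this interface. Only the four-argument signature comes from this PR; the checkpoint keys and file name below are assumptions for the sketch (`get_compression_state()` and `best_val_metric_value` are taken from the runner code in this diff).

```python
import os.path as osp

import torch


def dump_checkpoint_fn(model, compression_controller, accuracy_aware_runner, save_dir):
    # Save only what this user cares about; the key names here are
    # illustrative, not mandated by NNCF.
    checkpoint = {
        'state_dict': model.state_dict(),
        'compression_state': compression_controller.get_compression_state(),
        'best_metric_val': accuracy_aware_runner.best_val_metric_value,
    }
    torch.save(checkpoint, osp.join(save_dir, 'custom_aa_checkpoint.pth'))
```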

Co-authored-by: Andrey Churkin <andrey.churkin@intel.com>
kshpv and andrey-churkin committed Oct 6, 2021
1 parent 1bd5282 commit dc975f6
Showing 7 changed files with 136 additions and 52 deletions.
13 changes: 12 additions & 1 deletion docs/Usage.md
@@ -338,13 +338,24 @@ def configure_optimizers_fn():
of an optimizer instance and an LR scheduler instance (replace with None if the latter
is not applicable).
'''

def dump_checkpoint_fn(model, compression_controller, accuracy_aware_runner, save_dir):
'''
An (optional) function that allows a user to define how to save the model's checkpoint.
The training loop will call this function instead of its own dump_checkpoint function and pass
`model`, `compression_controller`, `accuracy_aware_runner` and `save_dir` to it as arguments.
The user can save the states of these objects according to their own needs.
`save_dir` is a directory that the Accuracy-Aware pipeline creates to store log information.
'''
```

Once the above functions are defined, you can pass them to the `run` method of the previously created training loop:
```python

model = training_loop.run(model,
train_epoch_fn=train_epoch_fn,
validate_fn=validate_fn,
configure_optimizers_fn=configure_optimizers_fn)
configure_optimizers_fn=configure_optimizers_fn,
dump_checkpoint_fn=dump_checkpoint_fn)
```
The above call executes the accuracy-aware training loop and returns the compressed model. For more details on how to use the accuracy-aware training loop functionality of NNCF, please refer to its [documentation](./accuracy_aware_model_training/AdaptiveCompressionTraining.md).
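
As a complementary, hedged sketch (not part of this commit): if you rely on the default `dump_checkpoint` behavior instead, the resulting `acc_aware_checkpoint_last.pth` can be restored as below. The key names mirror the default checkpoint dictionary in `nncf/torch/accuracy_aware_training/runner.py` shown later in this diff; `model` and `optimizer` are assumed to be re-created the same way as during training.

```python
import torch

# Load the checkpoint written by the default Torch accuracy-aware runner.
checkpoint = torch.load('acc_aware_checkpoint_last.pth', map_location='cpu')

# 'compression_state' can be passed to create_compressed_model() when
# re-building the compressed model, after which the weights can be restored.
compression_state = checkpoint['compression_state']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

start_epoch = checkpoint['epoch']
best_metric = checkpoint['best_metric_val']
```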
17 changes: 7 additions & 10 deletions nncf/common/accuracy_aware_training/runner.py
@@ -11,6 +11,7 @@
limitations under the License.
"""

from typing import Callable
from typing import Dict
from typing import TypeVar
from abc import ABC
@@ -104,17 +105,18 @@ def calculate_minimal_tolerable_accuracy(self, uncompressed_model_accuracy):
"""

@abstractmethod
def initialize_training_loop_fns(self, train_epoch_fn, validate_fn, configure_optimizers_fn,
def initialize_training_loop_fns(self, train_epoch_fn, validate_fn, configure_optimizers_fn, dump_checkpoint_fn,
tensorboard_writer=None, log_dir=None):
"""
Register the user-supplied functions to be used to control the training process.
:param train_epoch_fn: a method to fine-tune the model for a single epoch
(to be called inside the `train_epoch` of the TrainingRunner).
:param validate: a method to evaluate the model on the validation dataset
:param validate_fn: a method to evaluate the model on the validation dataset
(to be called inside the `train_epoch` of the TrainingRunner).
:param configure_optimizers_fn: a method to instantiate an optimizer and a learning
rate scheduler (to be called inside the `configure_optimizers` of the TrainingRunner).
:param dump_checkpoint_fn: a method to dump a checkpoint.
:param tensorboard_writer: The tensorboard object to be used for logging.
:param log_dir: The path to be used for logging and checkpoint saving.
"""
@@ -236,10 +238,13 @@ def __init__(self, accuracy_aware_params: Dict[str, object], verbose=True,
self.best_val_metric_value = 0

def initialize_training_loop_fns(self, train_epoch_fn, validate_fn, configure_optimizers_fn,
dump_checkpoint_fn: Callable[
[ModelType, CompressionAlgorithmController, TrainingRunner, str], None],
tensorboard_writer=None, log_dir=None):
self._train_epoch_fn = train_epoch_fn
self._validate_fn = validate_fn
self._configure_optimizers_fn = configure_optimizers_fn
self._dump_checkpoint_fn = dump_checkpoint_fn
self._tensorboard_writer = tensorboard_writer
self._log_dir = log_dir

@@ -274,11 +279,3 @@ def __init__(self, accuracy_aware_params: Dict[str, object], verbose=True,
self._best_checkpoints = {}
self.compression_rate_target = None
self.was_compression_increased_on_prev_step = None

def initialize_training_loop_fns(self, train_epoch_fn, validate_fn, configure_optimizers_fn,
tensorboard_writer=None, log_dir=None):
self._train_epoch_fn = train_epoch_fn
self._validate_fn = validate_fn
self._configure_optimizers_fn = configure_optimizers_fn
self._tensorboard_writer = tensorboard_writer
self._log_dir = log_dir
19 changes: 10 additions & 9 deletions nncf/common/accuracy_aware_training/training_loop.py
@@ -41,8 +41,8 @@ class TrainingLoop(ABC):
"""

@abstractmethod
def run(self, model: ModelType, train_epoch_fn, validate_fn,
configure_optimizers_fn=None, tensorboard_writer=None, log_dir=None):
def run(self, model: ModelType, train_epoch_fn, validate_fn, configure_optimizers_fn=None,
dump_checkpoint_fn=None, tensorboard_writer=None, log_dir=None):
"""
Implements the custom logic to run a training loop for model fine-tuning
by using the provided `train_epoch_fn`, `validate_fn` and `configure_optimizers_fn` methods.
@@ -52,8 +52,9 @@ def run(self, model: ModelType, train_epoch_fn, validate_fn,
:param model: The model instance before fine-tuning
:param train_epoch_fn: a method to fine-tune the model for a single epoch
(to be called inside the `train_epoch` of the TrainingRunner)
:param validate: a method to evaluate the model on the validation dataset
:param validate_fn: a method to evaluate the model on the validation dataset
(to be called inside the `train_epoch` of the TrainingRunner)
:param dump_checkpoint_fn: a method to dump a checkpoint
:param configure_optimizers_fn: a method to instantiate an optimizer and a learning
rate scheduler (to be called inside the `configure_optimizers` of the TrainingRunner)
:return: The fine-tuned model
@@ -89,10 +90,10 @@ def __init__(self,
dump_checkpoints))
self.compression_controller = compression_controller

def run(self, model, train_epoch_fn, validate_fn,
configure_optimizers_fn=None, tensorboard_writer=None, log_dir=None):
def run(self, model, train_epoch_fn, validate_fn, configure_optimizers_fn=None,
dump_checkpoint_fn=None, tensorboard_writer=None, log_dir=None):
self.runner.initialize_training_loop_fns(train_epoch_fn, validate_fn, configure_optimizers_fn,
tensorboard_writer, log_dir)
dump_checkpoint_fn, tensorboard_writer, log_dir)
self.runner.retrieve_uncompressed_model_accuracy(model)
uncompressed_model_accuracy = self.runner.uncompressed_model_accuracy
self.runner.calculate_minimal_tolerable_accuracy(uncompressed_model_accuracy)
@@ -193,10 +194,10 @@ def remove_registry_prefix(algo_name):
raise RuntimeError('No compression algorithm that supports adaptive compression '
'accuracy-aware training was specified')

def run(self, model, train_epoch_fn, validate_fn,
configure_optimizers_fn=None, tensorboard_writer=None, log_dir=None):
def run(self, model, train_epoch_fn, validate_fn, configure_optimizers_fn=None,
dump_checkpoint_fn=None, tensorboard_writer=None, log_dir=None):
self.runner.initialize_training_loop_fns(train_epoch_fn, validate_fn, configure_optimizers_fn,
tensorboard_writer, log_dir)
dump_checkpoint_fn, tensorboard_writer, log_dir)
self.runner.retrieve_uncompressed_model_accuracy(model)
uncompressed_model_accuracy = self.runner.uncompressed_model_accuracy
self.runner.calculate_minimal_tolerable_accuracy(uncompressed_model_accuracy)
4 changes: 2 additions & 2 deletions nncf/tensorflow/accuracy_aware_training/runner.py
@@ -25,8 +25,8 @@ class TFAccuracyAwareTrainingRunner(BaseAccuracyAwareTrainingRunner):
"""

def initialize_training_loop_fns(self, train_epoch_fn, validate_fn, configure_optimizers_fn=None,
tensorboard_writer=None, log_dir=None):
super().initialize_training_loop_fns(train_epoch_fn, validate_fn, configure_optimizers_fn,
dump_checkpoint_fn=None, tensorboard_writer=None, log_dir=None):
super().initialize_training_loop_fns(train_epoch_fn, validate_fn, configure_optimizers_fn, dump_checkpoint_fn,
tensorboard_writer=tensorboard_writer, log_dir=log_dir)
self._log_dir = self._log_dir if self._log_dir is not None \
else 'runs'
52 changes: 24 additions & 28 deletions nncf/torch/accuracy_aware_training/runner.py
@@ -20,6 +20,7 @@

try:
from torch.utils.tensorboard import SummaryWriter

TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
@@ -28,6 +29,7 @@
import matplotlib.pyplot as plt
import PIL.Image
from torchvision.transforms import ToTensor

IMG_PACKAGES_AVAILABLE = True
except ImportError:
IMG_PACKAGES_AVAILABLE = False
@@ -59,8 +61,8 @@ def __init__(self, accuracy_aware_training_params,
self.lr_updates_needed = lr_updates_needed

def initialize_training_loop_fns(self, train_epoch_fn, validate_fn, configure_optimizers_fn,
tensorboard_writer=None, log_dir=None):
super().initialize_training_loop_fns(train_epoch_fn, validate_fn, configure_optimizers_fn,
dump_checkpoint_fn, tensorboard_writer=None, log_dir=None):
super().initialize_training_loop_fns(train_epoch_fn, validate_fn, configure_optimizers_fn, dump_checkpoint_fn,
tensorboard_writer=tensorboard_writer, log_dir=log_dir)
self._log_dir = self._log_dir if self._log_dir is not None \
else 'runs'
@@ -96,7 +98,7 @@ def validate(self, model):
with torch.no_grad():
self.current_val_metric_value = self._validate_fn(model, epoch=self.cumulative_epoch_count)
is_better_by_accuracy = (not self.is_higher_metric_better) != (
self.current_val_metric_value > self.best_val_metric_value)
self.current_val_metric_value > self.best_val_metric_value)
if is_better_by_accuracy:
self.best_val_metric_value = self.current_val_metric_value

@@ -133,25 +135,30 @@ def dump_statistics(self, model, compression_controller):
self.add_tensorboard_scalar('compression/statistics/{0}'.format(key),
value, self.cumulative_epoch_count)


def dump_checkpoint(self, model, compression_controller):
checkpoint_path = osp.join(self._checkpoint_save_dir, 'acc_aware_checkpoint_last.pth')
checkpoint = {
'epoch': self.cumulative_epoch_count + 1,
'state_dict': model.state_dict(),
'compression_state': compression_controller.get_compression_state(),
'best_metric_val': self.best_val_metric_value,
'current_val_metric_value': self.current_val_metric_value,
'optimizer': self.optimizer.state_dict(),
'scheduler': compression_controller.scheduler.get_state()
}
torch.save(checkpoint, checkpoint_path)
def _save_best_checkpoint(self, checkpoint_path):
if self.best_val_metric_value == self.current_val_metric_value:
best_checkpoint_filename = 'acc_aware_checkpoint_best.pth'
best_path = osp.join(self._checkpoint_save_dir, best_checkpoint_filename)
self._best_checkpoint = best_path
copyfile(checkpoint_path, best_path)

def dump_checkpoint(self, model, compression_controller):
if self._dump_checkpoint_fn is not None:
self._dump_checkpoint_fn(model, compression_controller, self, self._log_dir)
else:
checkpoint = {
'epoch': self.cumulative_epoch_count + 1,
'state_dict': model.state_dict(),
'compression_state': compression_controller.get_compression_state(),
'best_metric_val': self.best_val_metric_value,
'current_val_metric_value': self.current_val_metric_value,
'optimizer': self.optimizer.state_dict(),
'scheduler': compression_controller.scheduler.get_state()
}
checkpoint_path = osp.join(self._checkpoint_save_dir, 'acc_aware_checkpoint_last.pth')
torch.save(checkpoint, checkpoint_path)
self._save_best_checkpoint(checkpoint_path)

def add_tensorboard_scalar(self, key, data, step):
if self.verbose and self._tensorboard_writer is not None:
self._tensorboard_writer.add_scalar(key, data, step)
@@ -225,18 +232,7 @@ def update_training_history(self, compression_rate, best_metric_value):
image,
global_step=len(self.compressed_training_history))

def dump_checkpoint(self, model, compression_controller):
checkpoint_path = osp.join(self._checkpoint_save_dir, 'acc_aware_checkpoint_last.pth')
checkpoint = {
'epoch': self.cumulative_epoch_count + 1,
'state_dict': model.state_dict(),
'compression_state': compression_controller.get_compression_state(),
'best_metric_val': self.best_val_metric_value,
'current_val_metric_value': self.current_val_metric_value,
'optimizer': self.optimizer.state_dict(),
'scheduler': compression_controller.scheduler.get_state()
}
torch.save(checkpoint, checkpoint_path)
def _save_best_checkpoint(self, checkpoint_path):
if self.best_val_metric_value == self.current_val_metric_value:
best_checkpoint_filename = 'acc_aware_checkpoint_best_compression_rate_' \
'{comp_rate:.3f}.pth'.format(comp_rate=self.compression_rate_target)
2 changes: 1 addition & 1 deletion tests/torch/accuracy_aware_training/test_runner.py
@@ -81,7 +81,7 @@ def configure_optimizers_fn():
optimizer = SGD(model.parameters(), lr=learning_rate)
return optimizer, None

runner.initialize_training_loop_fns(train_fn, validate_fn, configure_optimizers_fn)
runner.initialize_training_loop_fns(train_fn, validate_fn, configure_optimizers_fn, None)
runner.reset_training()
runner.train_epoch(model, compression_ctrl)
metric_value = runner.validate(model)
81 changes: 80 additions & 1 deletion tests/torch/accuracy_aware_training/test_training_loop.py
@@ -220,7 +220,7 @@ def mock_validate_fn(model, init_step=False, epoch=0):
epoch_counter = epoch
if "maximal_relative_accuracy_degradation" in max_accuracy_degradation:
return original_metric * (1 - 0.01 * max_accuracy_degradation['maximal_relative_accuracy_degradation']) * (
epoch / exit_epoch_number)
epoch / exit_epoch_number)
return (original_metric - max_accuracy_degradation['maximal_absolute_accuracy_degradation']) * \
epoch / exit_epoch_number

@@ -263,3 +263,82 @@ def configure_optimizers_fn():
configure_optimizers_fn=configure_optimizers_fn)
# Epoch number starts from 0
assert epoch_counter == exit_epoch_number


@pytest.mark.parametrize('aa_config', (
{
"accuracy_aware_training": {
"mode": "early_exit",
"params": {
"maximal_relative_accuracy_degradation": 1,
"maximal_total_epochs": 1
}
},
"compression": [
{
"algorithm": "filter_pruning",
},
{
"algorithm": "rb_sparsity",
}
]
},
{
"accuracy_aware_training": {
"mode": "adaptive_compression_level",
"params": {
"maximal_relative_accuracy_degradation": 1,
"maximal_total_epochs": 1
}
},
"compression": [
{
"algorithm": "filter_pruning",
}
]
}
)
)
def test_mock_dump_checkpoint(aa_config):
is_called_dump_checkpoint_fn = False

def mock_dump_checkpoint_fn(model, compression_controller, accuracy_aware_runner, aa_log_dir):
from nncf.api.compression import CompressionAlgorithmController
from nncf.common.accuracy_aware_training.runner import TrainingRunner
assert isinstance(model, torch.nn.Module)
assert isinstance(compression_controller, CompressionAlgorithmController)
assert isinstance(accuracy_aware_runner, TrainingRunner)
assert isinstance(aa_log_dir, str)
nonlocal is_called_dump_checkpoint_fn
is_called_dump_checkpoint_fn = True

config = get_quantization_config_without_range_init(LeNet.INPUT_SIZE[-1])
train_loader = create_ones_mock_dataloader(aa_config, num_samples=10)
model = LeNet()
config.update(aa_config)

def train_fn(compression_ctrl, model, epoch, optimizer, lr_scheduler,
train_loader=train_loader):
pass

def mock_validate_fn(model, init_step=False, epoch=0):
return 80

def configure_optimizers_fn():
optimizer = SGD(model.parameters(), lr=0.001)
return optimizer, None

config = register_default_init_args(config,
train_loader=train_loader,
model_eval_fn=partial(mock_validate_fn, init_step=True))

model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)

early_stopping_training_loop = EarlyExitCompressionTrainingLoop(config, compression_ctrl,
dump_checkpoints=True)
model = early_stopping_training_loop.run(model,
train_epoch_fn=train_fn,
validate_fn=partial(mock_validate_fn),
configure_optimizers_fn=configure_optimizers_fn,
dump_checkpoint_fn=mock_dump_checkpoint_fn)
assert is_called_dump_checkpoint_fn
