# Add prepare_for_inference function for PyTorch models (openvinotoolkit#1526)

### Changes

1. Add `prepare_for_inference`, which converts a compressed model into a PyTorch-native representation for inference, with no NNCF-specific operations left in the graph (see the usage sketch after this list).
    - Convert `AsymmetricQuantizer` and `SymmetricQuantizer` to `FakeQuantize`.
    - Apply filter pruning masks to weights by filling the pruned channels with zeroes.
    - Apply sparsity binary masks to weights.
2. The `ModelPruner` class has also been fixed.
3. Fixed `get_scale_zp_from_input_low_input_high`: the zero_point value was incorrectly shifted by -1.
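
A minimal usage sketch of the new API, based on the calls visible in the sample-code changes below; the toy model and config are illustrative placeholders, not part of this PR:

```python
from torch import nn

from nncf import NNCFConfig
from nncf.torch import create_compressed_model

# Illustrative toy model and config; the samples load JSON config files instead.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
nncf_config = NNCFConfig.from_dict({
    "input_info": {"sample_size": [1, 3, 32, 32]},
    "compression": {"algorithm": "quantization"},
})

compression_ctrl, compressed_model = create_compressed_model(model, nncf_config)
# ... fine-tune `compressed_model` as usual ...

# Validation: convert a copy, so the training model keeps its NNCF operations.
val_model = compression_ctrl.prepare_for_inference(make_model_copy=True)

# Export: convert in place, then export. The converted graph contains FakeQuantize
# ops instead of NNCF quantizers, with pruning/sparsity masks applied to the weights.
compression_ctrl.prepare_for_inference(make_model_copy=False)
compression_ctrl.export_model("model.onnx")
```

For item 3, the formula in question is the standard asymmetric-quantization relation `scale = (input_high - input_low) / (levels - 1)`, `zero_point = round(-input_low / scale)`; the fix removes the spurious -1 shift in the computed zero point.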

### Reason for changes

This makes it possible to convert the model to the OpenVINO format directly.

### Related tickets

CVS-92247

### Tests

- test_converting_symmetric_quantizer
- test_converting_asymmetric_quantizer
- test_prepare_for_inference_quantization
- test_prepare_for_inference_pruning
- test_prepare_for_inference_quantization_and_pruning
- test_save_original_model

Added `data_generators.py` to create input tensors with values in the middle of the quantization points, in order to catch errors at exactly those points, where the NNCF and Torch quantization formulas can round differently.
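
A sketch of the idea, assuming per-tensor affine quantization; the function name and signature are illustrative, not the actual helpers in `data_generators.py`:

```python
import torch

def mid_quant_point_inputs(input_low: float, input_high: float, levels: int = 256) -> torch.Tensor:
    """Illustrative only: values halfway between adjacent quantization levels,
    which is where rounding differences between quantization formulas surface first."""
    scale = (input_high - input_low) / (levels - 1)    # step between adjacent levels
    points = input_low + scale * torch.arange(levels)  # the quantization points themselves
    return (points[:-1] + points[1:]) / 2              # midpoints between neighbors
```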



<img src="https://user-images.githubusercontent.com/48012821/218319569-32339bc2-2790-4a3b-9b23-9bb480ffed16.png" height="400">
AlexanderDokuchaev committed Feb 24, 2023
1 parent 7bad8e9 commit f79e9f7
Showing 21 changed files with 1,838 additions and 304 deletions.
19 changes: 12 additions & 7 deletions examples/torch/classification/README.md
@@ -13,7 +13,7 @@ This sample demonstrates a DL model compression in case of an image-classificati

## Installation

At this point it is assumed that you have already installed nncf. You can find information on downloading nncf [here](https://github.com/openvinotoolkit/nncf#user-content-installation).

To work with the sample you should install the corresponding Python package dependencies:

@@ -36,32 +36,36 @@ To prepare the ImageNet dataset, refer to the following [tutorial](https://githu

#### Test Pretrained Model

Before compressing a model, it is highly recommended to check the accuracy of the pretrained model. All models supported in the sample have pretrained weights for ImageNet.

To load pretrained weights into a model and then evaluate the accuracy of that model, make sure that the pretrained=True option is set in the configuration file and use the following command:

```bash
python main.py \
--mode=test \
--config=configs/quantization/mobilenet_v2_imagenet_int8.json \
--data=<path_to_imagenet_dataset> \
--disable-compression
```

#### Compress Pretrained Model

- Run the following command to start compression with fine-tuning on GPUs:
```
python main.py -m train --config configs/quantization/mobilenet_v2_imagenet_int8.json --data /data/imagenet/ --log-dir=../../results/quantization/mobilenet_v2_int8/
```

It may take a few epochs to get the baseline accuracy results.
- Use the `--multiprocessing-distributed` flag to run in the distributed mode.
- Use the `--resume` flag with the path to a previously saved model to resume training.
- For Torchvision-supported image classification models, set `"pretrained": true` inside the NNCF config JSON file supplied via `--config` to initialize the model to be compressed with Torchvision-supplied pretrained weights, or, alternatively:
- Use the `--weights` flag with the path to a compatible PyTorch checkpoint in order to load all matching weights from the checkpoint into the model - useful if you need to start compression-aware training from a previously trained uncompressed (FP32) checkpoint instead of performing compression-aware training from scratch.
- Use the `--prepare-for-inference` argument to convert the model to torch-native format before the `test` and `export` steps, as in the example below.
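
For example, to export with the new flag (the config and checkpoint paths follow the export example in this README; the combination is illustrative):

```
python main.py -m export --config=configs/quantization/mobilenet_v2_imagenet_int8.json --resume=../../results/quantization/mobilenet_v2_int8/6/checkpoints/epoch_1.pth --to-onnx=../../results/mobilenet_v2_int8.onnx --prepare-for-inference
```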

#### Validate Your Model Checkpoint

To estimate the test scores of your trained model checkpoint, use the following command:

```
python main.py -m test --config=configs/quantization/mobilenet_v2_imagenet_int8.json --resume <path_to_trained_model_checkpoint>
```
@@ -71,6 +75,7 @@ python main.py -m test --config=configs/quantization/mobilenet_v2_imagenet_int8.
#### Export Compressed Model

To export trained model to the ONNX format, use the following command:

```
python main.py -m export --config=configs/quantization/mobilenet_v2_imagenet_int8.json --resume=../../results/quantization/mobilenet_v2_int8/6/checkpoints/epoch_1.pth --to-onnx=../../results/mobilenet_v2_int8.onnx
```
26 changes: 17 additions & 9 deletions examples/torch/classification/main.py
@@ -21,24 +21,24 @@
from typing import Any

import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch import nn
from torch.backends import cudnn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets
from torchvision import models
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
from torchvision.models import InceptionOutputs

from examples.torch.common.argparser import get_common_argument_parser
from examples.torch.common.argparser import parse_args
from examples.torch.common.example_logger import logger
from examples.torch.common.execution import ExecutionMode
from examples.torch.common.execution import get_execution_mode
@@ -69,9 +69,9 @@
from examples.torch.common.utils import print_args
from examples.torch.common.utils import write_metrics
from nncf.api.compression import CompressionStage
from nncf.common.accuracy_aware_training import create_accuracy_aware_training_loop
from nncf.common.utils.tensorboard import prepare_for_tensorboard
from nncf.config.utils import is_accuracy_aware_training
from nncf.torch import create_compressed_model
from nncf.torch.checkpoint_loading import load_state
from nncf.torch.dynamic_graph.graph_tracer import create_input_infos
@@ -132,7 +132,8 @@ def main(argv):
if not is_staged_quantization(config):
start_worker(main_worker, config)
else:
from examples.torch.classification.staged_quantization_worker import \
    staged_quantization_main_worker  # pylint: disable=cyclic-import
start_worker(staged_quantization_main_worker, config)


@@ -222,6 +223,8 @@ def model_eval_fn(model):
load_state(model, model_state_dict, is_resume=True)

if is_export_only:
if config.prepare_for_inference:
compression_ctrl.prepare_for_inference(make_model_copy=False)
compression_ctrl.export_model(config.to_onnx)
logger.info("Saved to {}".format(config.to_onnx))
return
@@ -289,11 +292,16 @@ def configure_optimizers_fn():
train_loader, train_sampler, val_loader, best_acc1)

if 'test' in config.mode:
val_model = model
if config.prepare_for_inference:
val_model = compression_ctrl.prepare_for_inference(make_model_copy=True)
validate(val_loader, val_model, criterion, config)

config.mlflow.end_run()

if 'export' in config.mode:
if config.prepare_for_inference:
compression_ctrl.prepare_for_inference(make_model_copy=False)
compression_ctrl.export_model(config.to_onnx)
logger.info("Saved to {}".format(config.to_onnx))

5 changes: 5 additions & 0 deletions examples/torch/common/argparser.py
@@ -174,6 +174,11 @@ def get_common_argument_parser():
parser.add_argument('--to-onnx', type=str, metavar='PATH', default=None,
help='Export to ONNX model by given path')

parser.add_argument(
"--prepare-for-inference",
action='store_true',
help="Convert model to torch native format for export and test steps.")

# Display
parser.add_argument('-p', '--print-freq', default=10, type=int,
metavar='N', help='Print frequency (batch iterations). '
20 changes: 13 additions & 7 deletions examples/torch/object_detection/README.md
@@ -1,7 +1,9 @@
# Object Detection sample

This sample demonstrates DL model compression capabilities for object detection task.

## Features:

- Vanilla SSD300 / SSD512 (+ Batch Normalization), MobileNetSSD-300
- VOC2007 / VOC2012, COCO datasets
- Configuration file examples for sparsity, quantization, filter pruning and quantization with sparsity
@@ -11,7 +13,7 @@ This sample demonstrates DL model compression capabailites for object detection

## Installation

At this point it is assumed that you have already installed nncf. You can find information on downloading nncf [here](https://github.com/openvinotoolkit/nncf#user-content-installation).

To work with the sample you should install the corresponding Python package dependencies:

@@ -20,39 +22,43 @@ pip install -r examples/torch/requirements.txt
```

## Quantize FP32 pretrained model

This scenario demonstrates quantization with fine-tuning of SSD300 on VOC dataset.

#### Dataset preparation

- Download and extract in one folder train/val+test VOC2007 and train/val VOC2012 data from [here](https://pjreddie.com/projects/pascal-voc-dataset-mirror/)
- In the future, `<path_to_dataset>` means the path to this folder.

#### Run object detection sample

- If you did not install the package then add the repository root folder to the `PYTHONPATH` environment variable
- Navigate to the `examples/torch/object_detection` folder
- (Optional) Before compressing a model, it is highly recommended checking the accuracy of the pretrained model, use the following command:
- (Optional) Before compressing a model, it is highly recommended to check the accuracy of the pretrained model using the following command:
python main.py \
--mode=test \
--config=configs/ssd300_vgg_voc_int8.json \
--data=<path_to_dataset> \
--disable-compression
```
- Run the following command to start compression with fine-tuning on GPUs:
`python main.py -m train --config configs/ssd300_vgg_voc_int8.json --data <path_to_dataset> --log-dir=../../results/quantization/ssd300_int8 --weights=<path_to_checkpoint>`
It may take a few epochs to get the baseline accuracy results.
- Use `--weights` flag with the path to a compatible PyTorch checkpoint in order to load all matching weights from the checkpoint into the model - useful if you need to start compression-aware training from a previously trained uncompressed (FP32) checkpoint instead of performing compression-aware training from scratch. This flag is optional, but highly recommended to use.
- Use `--multiprocessing-distributed` flag to run in the distributed mode.
- Use `--resume` flag with the path to a previously saved model to resume training.

- Use the `--prepare-for-inference` argument to convert the model to torch-native format before the `test` and `export` steps.

#### Validate your model checkpoint

To estimate the test scores of your trained model checkpoint use the following command:
`python main.py -m test --config=configs/ssd300_vgg_voc_int8.json --data <path_to_dataset> --resume <path_to_trained_model_checkpoint>`
If you want to validate an FP32 model checkpoint, make sure the compression algorithm settings are empty in the configuration file or `pretrained=True` is set.

**WARNING**: The samples use `torch.load` functionality for checkpoint loading which, in turn, uses pickle facilities by default which are known to be vulnerable to arbitrary code execution attacks. **Only load the data you trust**

#### Export compressed model

To export trained model to ONNX format use the following command:
`python main.py -m export --config configs/ssd300_vgg_voc_int8.json --data <path_to_dataset> --resume <path_to_compressed_model_checkpoint> --to-onnx=../../results/ssd300_int8.onnx`

18 changes: 12 additions & 6 deletions examples/torch/object_detection/main.py
@@ -17,13 +17,12 @@
from pathlib import Path

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils import data

from examples.torch.common import restricted_pickle_module
from examples.torch.common.argparser import get_common_argument_parser
from examples.torch.common.argparser import parse_args
from examples.torch.common.distributed import DistributedSampler
from examples.torch.common.example_logger import logger
from examples.torch.common.execution import get_execution_mode
@@ -57,8 +56,8 @@
from examples.torch.object_detection.layers.modules import MultiBoxLoss
from examples.torch.object_detection.model import build_ssd
from nncf.api.compression import CompressionStage
from nncf.common.accuracy_aware_training import create_accuracy_aware_training_loop
from nncf.config.utils import is_accuracy_aware_training
from nncf.torch import create_compressed_model
from nncf.torch import load_state
from nncf.torch.dynamic_graph.graph_tracer import create_input_infos
@@ -210,6 +209,8 @@ def model_eval_fn(model):
log_common_mlflow_params(config)

if is_export_only:
if config.prepare_for_inference:
compression_ctrl.prepare_for_inference(make_model_copy=False)
compression_ctrl.export_model(config.to_onnx)
logger.info("Saved to {}".format(config.to_onnx))
return
@@ -257,17 +258,22 @@ def configure_optimizers_fn():

if 'test' in config.mode:
with torch.no_grad():
val_net = net
if config.prepare_for_inference:
val_net = compression_ctrl.prepare_for_inference(make_model_copy=True)
net.eval()
if config['ssd_params'].get('loss_inference', False):
model_loss = test_net(val_net, config.device, test_data_loader, distributed=config.distributed,
loss_inference=True, criterion=criterion)
logger.info("Final model loss: {:.3f}".format(model_loss))
else:
mAp = test_net(val_net, config.device, test_data_loader, distributed=config.distributed)
if config.metrics_dump is not None:
write_metrics(mAp, config.metrics_dump)

if 'export' in config.mode:
if config.prepare_for_inference:
compression_ctrl.prepare_for_inference(make_model_copy=False)
compression_ctrl.export_model(config.to_onnx)
logger.info("Saved to {}".format(config.to_onnx))

18 changes: 13 additions & 5 deletions examples/torch/semantic_segmentation/README.md
@@ -1,7 +1,9 @@
# Semantic segmentation sample

This sample demonstrates DL model compression capabilities for semantic segmentation problem

## Features:

- UNet and ICNet with implementations as close as possible to the original papers
- Loaders for CamVid, Cityscapes (20-class), Mapillary Vistas(20-class), Pascal VOC (reuses the loader integrated into torchvision)
- Configuration file examples for sparsity, quantization, filter pruning and quantization with sparsity
@@ -11,7 +13,7 @@ This sample demonstrates DL model compression capabilities for semantic segmenta

## Installation

At this point it is assumed that you have already installed nncf. You can find information on downloading nncf [here](https://github.com/openvinotoolkit/nncf#user-content-installation).

To work with the sample you should install the corresponding Python package dependencies:

@@ -20,12 +22,15 @@ pip install -r examples/torch/requirements.txt
```

## Quantize FP32 pretrained model

This scenario demonstrates quantization with fine-tuning of UNet on Mapillary Vistas dataset.

#### Dataset preparation

- Obtain a copy of Mapillary Vistas train/val data [here](https://www.mapillary.com/dataset/vistas/)

#### Run semantic segmentation sample

- If you did not install the package then add the repository root folder to the `PYTHONPATH` environment variable
- Navigate to the `examples/torch/segmentation` folder
- (Optional) Before compressing a model, it is highly recommended checking the accuracy of the pretrained model, use the following command:
@@ -39,25 +44,28 @@ This scenario demonstrates quantization with fine-tuning of UNet on Mapillary Vi
--disable-compression
```
- Run the following command to start compression with fine-tuning on GPUs:
`python main.py -m train --config configs/unet_mapillary_int8.json --data <path_to_dataset> --weights <path_to_fp32_model_checkpoint>`
- Use the `--prepare-for-inference` argument to convert the model to torch-native format before the `test` and `export` steps.

It may take a few epochs to get the baseline accuracy results.

- Use `--multiprocessing-distributed` flag to run in the distributed mode.
- Use `--resume` flag with the path to a model from the previous experiment to resume training.
- Use `-b <number>` option to specify the total batch size across GPUs
- Use the `--weights` flag with the path to a compatible PyTorch checkpoint in order to load all matching weights from the checkpoint into the model - useful if you need to start compression-aware training from a previously trained uncompressed (FP32) checkpoint instead of performing compression-aware training from scratch.

#### Validate your model checkpoint

To estimate the test scores of your trained model checkpoint use the following command:
`python main.py -m test --config=configs/unet_mapillary_int8.json --resume <path_to_trained_model_checkpoint>`
If you want to validate an FP32 model checkpoint, make sure the compression algorithm settings are empty in the configuration file or `pretrained=True` is set.

**WARNING**: The samples use `torch.load` functionality for checkpoint loading which, in turn, uses pickle facilities by default which are known to be vulnerable to arbitrary code execution attacks. **Only load the data you trust**

#### Export compressed model

To export trained model to ONNX format use the following command:
`python main.py --mode export --config configs/unet_mapillary_int8.json --data <path_to_dataset> --resume <path_to_compressed_model_checkpoint> --to-onnx unet_int8.onnx`

