Skip to content

Commit

Permalink
[Torch][PTQ] Examples are updated for the new PTQ TORCH backend (#2246)
Browse files Browse the repository at this point in the history
### Changes

- Do not filter constant nodes for torch backend in the inference graph
- Fix version in requarements.txt for examples of
post_training_quantization
- for ssd300_vgg16 is not available to use torch 2.1.0 (failed on export
to onnx Unsupported: ONNX export of operator get_pool_ceil_padding,
tracing is not supporting too)
    - Update metrics
- Add to PTEngine convert inputs to model's device to sync behavior with
`create_compress_model`
- Mobilenet_v2 example converting PyTorch model to IR by tracing
(without onnx).
- nncf.quantize for PyTorch works with copy of the target model

### Reason for changes

To make PTQ work properly with disconnected graphs (like in
[example](https://github.com/openvinotoolkit/nncf/blob/develop/examples/post_training_quantization/torch/ssd300_vgg16/main.py))

### Related tickets
124417

### Tests

test_examples build 128

---------

Co-authored-by: Alexander Dokuchaev <alexander.dokuchaev@intel.com>
  • Loading branch information
daniil-lyakhov and AlexanderDokuchaev committed Nov 10, 2023
1 parent d5e3942 commit 1493149
Show file tree
Hide file tree
Showing 15 changed files with 160 additions and 66 deletions.
39 changes: 11 additions & 28 deletions examples/post_training_quantization/torch/mobilenet_v2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import re
import subprocess
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple

import numpy as np
import openvino as ov
Expand All @@ -24,9 +24,9 @@
from torchvision import datasets
from torchvision import models
from torchvision import transforms
from tqdm import tqdm

import nncf
from nncf.common.logging.track_progress import track

ROOT = Path(__file__).parent.resolve()
CHECKPOINT_URL = "https://huggingface.co/alexsu52/mobilenet_v2_imagenette/resolve/main/pytorch_model.bin"
Expand All @@ -53,7 +53,7 @@ def validate(model: ov.Model, val_loader: torch.utils.data.DataLoader) -> float:
compiled_model = ov.compile_model(model)
output = compiled_model.outputs[0]

for images, target in tqdm(val_loader):
for images, target in track(val_loader, description="Validating"):
pred = compiled_model(images)[output]
predictions.append(np.argmax(pred, axis=1))
references.append(target)
Expand Down Expand Up @@ -84,9 +84,9 @@ def get_model_size(ir_path: str, m_type: str = "Mb", verbose: bool = True) -> fl
bin_size /= 1024
model_size = xml_size + bin_size
if verbose:
print(f"Model graph (xml): {xml_size:.3f} Mb")
print(f"Model weights (bin): {bin_size:.3f} Mb")
print(f"Model size: {model_size:.3f} Mb")
print(f"Model graph (xml): {xml_size:.3f} {m_type}")
print(f"Model weights (bin): {bin_size:.3f} {m_type}")
print(f"Model size: {model_size:.3f} {m_type}")
return model_size


Expand Down Expand Up @@ -123,7 +123,7 @@ def get_model_size(ir_path: str, m_type: str = "Mb", verbose: bool = True) -> fl
# >> model(transform_fn(data_item))


def transform_fn(data_item):
def transform_fn(data_item: Tuple[torch.Tensor, int]) -> torch.Tensor:
images, _ = data_item
return images

Expand All @@ -149,28 +149,11 @@ def transform_fn(data_item):
# Benchmark performance, calculate compression rate and validate accuracy

dummy_input = torch.randn(1, 3, 224, 224)

fp32_onnx_path = f"{ROOT}/mobilenet_v2_fp32.onnx"
torch.onnx.export(
torch_model.cpu(),
dummy_input,
fp32_onnx_path,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "-1"}},
)
ov_model = mo.convert_model(fp32_onnx_path)

int8_onnx_path = f"{ROOT}/mobilenet_v2_int8.onnx"
torch.onnx.export(
torch_quantized_model.cpu(),
dummy_input,
int8_onnx_path,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "-1"}},
ov_input_shape = (-1, 3, 224, 224)
ov_model = mo.convert_model(torch_model.cpu(), example_input=dummy_input, input_shape=ov_input_shape)
ov_quantized_model = mo.convert_model(
torch_quantized_model.cpu(), example_input=dummy_input, input_shape=ov_input_shape
)
ov_quantized_model = mo.convert_model(int8_onnx_path)

fp32_ir_path = f"{ROOT}/mobilenet_v2_fp32.xml"
ov.save_model(ov_model, fp32_ir_path, compress_to_fp16=False)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
torchvision>=0.10.1,<0.16
tqdm
scikit-learn
fastdownload
fastdownload==0.0.7
openvino-dev==2023.1
onnx
scikit-learn
torch==2.1.0
torchvision==0.16.0
21 changes: 11 additions & 10 deletions examples/post_training_quantization/torch/ssd300_vgg16/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import re
import subprocess
from pathlib import Path
from typing import Callable, Tuple, Dict

# nncf.torch must be imported before torchvision
import nncf
Expand All @@ -27,7 +28,7 @@
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models.detection.ssd import SSD
from torchvision.models.detection.ssd import GeneralizedRCNNTransform
from tqdm import tqdm
from nncf.common.logging.track_progress import track

ROOT = Path(__file__).parent.resolve()
DATASET_URL = "https://ultralytics.com/assets/coco128.zip"
Expand All @@ -49,9 +50,9 @@ def get_model_size(ir_path: str, m_type: str = "Mb", verbose: bool = True) -> fl
bin_size /= 1024
model_size = xml_size + bin_size
if verbose:
print(f"Model graph (xml): {xml_size:.3f} Mb")
print(f"Model weights (bin): {bin_size:.3f} Mb")
print(f"Model size: {model_size:.3f} Mb")
print(f"Model graph (xml): {xml_size:.3f} {m_type}")
print(f"Model weights (bin): {bin_size:.3f} {m_type}")
print(f"Model size: {model_size:.3f} {m_type}")
return model_size


Expand All @@ -73,15 +74,15 @@ class COCO128Dataset(torch.utils.data.Dataset):
61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90
] # fmt: skip

def __init__(self, data_path, transform):
def __init__(self, data_path: str, transform: Callable):
super().__init__()
self.transform = transform
self.data_path = Path(data_path)
self.images_path = self.data_path / "images" / "train2017"
self.labels_path = self.data_path / "labels" / "train2017"
self.image_ids = sorted(map(lambda p: int(p.stem), self.images_path.glob("*.jpg")))

def __getitem__(self, item):
def __getitem__(self, item: int) -> Tuple[torch.Tensor, Dict]:
image_id = self.image_ids[item]

img = Image.open(self.images_path / f"{image_id:012d}.jpg")
Expand All @@ -106,16 +107,16 @@ def __getitem__(self, item):
img, target = self.transform(img, target)
return img, target

def __len__(self):
def __len__(self) -> int:
return len(self.image_ids)


def validate(model, dataset, device):
def validate(model: torch.nn.Module, dataset: COCO128Dataset, device: torch.device):
model.to(device)
model.eval()
metric = MeanAveragePrecision()
with torch.no_grad():
for img, target in tqdm(dataset, desc="Validating"):
for img, target in track(dataset, description="Validating"):
prediction = model(img.to(device)[None])[0]
for k in prediction.keys():
prediction[k] = prediction[k].to(torch.device("cpu"))
Expand All @@ -124,7 +125,7 @@ def validate(model, dataset, device):
return computed_metrics["map_50"]


def transform_fn(data_item):
def transform_fn(data_item: Tuple[torch.Tensor, Dict]) -> torch.Tensor:
# Skip label and add a batch dimension to an image tensor
images, _ = data_item
return images[None]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
fastdownload
fastdownload==0.0.7
onnx==1.13.1
openvino-dev==2023.1
pycocotools==2.0.7
torch==2.0.1 # ssd300_vgg16 can not be exported with 2.1.0, reference: https://github.com/pytorch/pytorch/issues/113155
torchmetrics==1.0.1
pycocotools
torchvision~=0.15.1
tqdm
onnx
torchvision==0.15.2
1 change: 1 addition & 0 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def _get_quantization_target_points(
self._backend_entity.shapeof_metatypes,
self._backend_entity.dropout_metatypes,
self._backend_entity.read_variable_metatypes,
nncf_graph_contains_constants=backend != BackendType.TORCH,
)

quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
Expand Down
5 changes: 4 additions & 1 deletion nncf/quantization/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def transform_to_inference_graph(
shapeof_metatypes: List[OperatorMetatype],
dropout_metatypes: List[OperatorMetatype],
read_variable_metatypes: Optional[List[OperatorMetatype]] = None,
nncf_graph_contains_constants: bool = True,
) -> NNCFGraph:
"""
This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows.
Expand All @@ -32,11 +33,13 @@ def transform_to_inference_graph(
:param dropout_metatypes: List of backend-specific Dropout metatypes.
:param read_variable_metatypes: List of backend-specific metatypes
that also can be interpreted as inputs (ReadValue).
:param nncf_graph_contains_constants: Whether NNCFGraph contains constant nodes or not.
:return: NNCFGraph in the inference style.
"""
remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, read_variable_metatypes)
remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes)
filter_constant_nodes(nncf_graph, read_variable_metatypes)
if nncf_graph_contains_constants:
filter_constant_nodes(nncf_graph, read_variable_metatypes)
return nncf_graph


Expand Down
10 changes: 10 additions & 0 deletions nncf/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from torch import nn

from nncf.common.engine import Engine
from nncf.torch.nested_objects_traversal import objwalk
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_tensor


class PTEngine(Engine):
Expand All @@ -31,6 +34,7 @@ def __init__(self, model: nn.Module):

self._model = model
self._model.eval()
self._device = get_model_device(model)

def infer(
self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]]
Expand All @@ -41,6 +45,12 @@ def infer(
:param input_data: Inputs for the model.
:return: Model outputs.
"""

def send_to_device(tensor):
return tensor.to(self._device)

input_data = objwalk(input_data, is_tensor, send_to_device)

if isinstance(input_data, dict):
return self._model(**input_data)
if isinstance(input_data, tuple):
Expand Down
4 changes: 3 additions & 1 deletion nncf/torch/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Optional, Union

import torch
Expand Down Expand Up @@ -68,7 +69,8 @@ def quantize_impl(
if target_device == TargetDevice.CPU_SPR:
raise RuntimeError("target_device == CPU_SPR is not supported")

nncf_network = create_nncf_network_ptq(model.eval(), calibration_dataset)
copied_model = deepcopy(model)
nncf_network = create_nncf_network_ptq(copied_model.eval(), calibration_dataset)

quantization_algorithm = PostTrainingQuantization(
preset=preset,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
strict digraph {
"0 /Input_1_0" [id=0, type=Input_1];
"1 /ReadVariable_0" [id=1, type=ReadVariable];
"4 /Conv_0" [id=4, type=Conv];
"6 /Conv2_0" [id=6, type=Conv2];
"7 /Add_0" [id=7, type=Add];
"8 /Final_node_0" [id=8, type=Final_node];
"0 /Input_1_0" -> "4 /Conv_0";
"1 /ReadVariable_0" -> "7 /Add_0";
"6 /Conv2_0" -> "7 /Add_0";
"7 /Add_0" -> "8 /Final_node_0";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
strict digraph {
"0 /Input_1_0" [id=0, type=Input_1];
"1 /ReadVariable_0" [id=1, type=ReadVariable];
"2 /Weights_0" [id=2, type=Weights];
"3 /AnyNodeBetweenWeightAndConv_0" [id=3, type=AnyNodeBetweenWeightAndConv];
"4 /Conv_0" [id=4, type=Conv];
"5 /Weights2_0" [id=5, type=Weights2];
"6 /Conv2_0" [id=6, type=Conv2];
"7 /Add_0" [id=7, type=Add];
"8 /Final_node_0" [id=8, type=Final_node];
"0 /Input_1_0" -> "4 /Conv_0";
"1 /ReadVariable_0" -> "7 /Add_0";
"2 /Weights_0" -> "3 /AnyNodeBetweenWeightAndConv_0";
"3 /AnyNodeBetweenWeightAndConv_0" -> "4 /Conv_0";
"5 /Weights2_0" -> "6 /Conv2_0";
"6 /Conv2_0" -> "7 /Add_0";
"7 /Add_0" -> "8 /Final_node_0";
}
29 changes: 24 additions & 5 deletions tests/common/quantization/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

import pytest

from nncf.quantization.passes import filter_constant_nodes
from nncf.quantization.passes import remove_nodes_and_reconnect_graph
from tests.post_training.test_templates.models import NNCFGraphDropoutRemovingCase
from tests.post_training.test_templates.models import NNCFGraphToTestConstantFiltering
from tests.shared.nx_graph import compare_nx_graph_with_reference
from tests.shared.paths import TEST_ROOT

Expand All @@ -28,13 +30,14 @@ class TestModes(Enum):
WRONG_PARALLEL_EDGES = "wrong_parallel_edges"


def _check_graphs(dot_file_name, nncf_graph) -> None:
nx_graph = nncf_graph.get_graph_for_structure_analysis()
path_to_dot = DATA_ROOT / dot_file_name
compare_nx_graph_with_reference(nx_graph, path_to_dot, check_edge_attrs=True)


@pytest.mark.parametrize("mode", [TestModes.VALID, TestModes.WRONG_TENSOR_SHAPE, TestModes.WRONG_PARALLEL_EDGES])
def test_remove_nodes_and_reconnect_graph(mode: TestModes):
def _check_graphs(dot_file_name, nncf_graph) -> None:
nx_graph = nncf_graph.get_graph_for_structure_analysis()
path_to_dot = DATA_ROOT / dot_file_name
compare_nx_graph_with_reference(nx_graph, path_to_dot, check_edge_attrs=True)

dot_reference_path_before = Path("passes") / "dropout_synthetic_model_before.dot"
dot_reference_path_after = Path("passes") / "dropout_synthetic_model_after.dot"
dropout_metatype = "DROPOUT_METATYPE"
Expand All @@ -52,3 +55,19 @@ def _check_graphs(dot_file_name, nncf_graph) -> None:
_check_graphs(dot_reference_path_before, nncf_graph)
remove_nodes_and_reconnect_graph(nncf_graph, [dropout_metatype])
_check_graphs(dot_reference_path_after, nncf_graph)


@pytest.mark.xfail
def test_filter_constant_nodes():
dot_reference_path_before = Path("passes") / "test_constant_filtering_model_before.dot"
dot_reference_path_after = Path("passes") / "test_constant_filtering_model_after.dot"

constant_metatype = "CONSTANT_METATYPE"
read_variable_metatype = "READ_VARIABLE_METATYPE"

nncf_graph = NNCFGraphToTestConstantFiltering(constant_metatype, read_variable_metatype).nncf_graph
_check_graphs(dot_reference_path_before, nncf_graph)
filter_constant_nodes(
nncf_graph, read_variable_metatypes=[read_variable_metatype], constant_nodes_metatypes=[constant_metatype]
)
_check_graphs(dot_reference_path_after, nncf_graph)
8 changes: 4 additions & 4 deletions tests/cross_fw/examples/example_scope.json
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
"cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz",
"accuracy_metrics": {
"fp32_top1": 0.9864968152866243,
"int8_top1": 0.9829299363057324,
"int8_top1": 0.9836942675159236,
"accuracy_drop": 0.0035668789808918078
},
"performance_metrics": {
Expand All @@ -129,8 +129,8 @@
"requirements": "examples/post_training_quantization/torch/ssd300_vgg16/requirements.txt",
"cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz",
"accuracy_metrics": {
"fp32_mAP": 0.5232756733894348,
"int8_mAP": 0.5140125155448914,
"fp32_mAP": 0.5228869318962097,
"int8_mAP": 0.5148677825927734,
"accuracy_drop": 0.009263157844543457
},
"performance_metrics": {
Expand All @@ -144,4 +144,4 @@
"model_compression_rate": 3.8631822183889652
}
}
}
}
6 changes: 3 additions & 3 deletions tests/cross_fw/examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

ACCURACY_METRICS = "accuracy_metrics"
MODEL_SIZE_METRICS = "model_size_metrics"
PERFORMNACE_METRICS = "performance_metrics"
PERFORMANCE_METRICS = "performance_metrics"


def example_test_cases():
Expand Down Expand Up @@ -85,6 +85,6 @@ def test_examples(
for name, value in example_params[MODEL_SIZE_METRICS].items():
assert measured_metrics[name] == pytest.approx(value, rel=MODEL_SIZE_RELATIVE_TOLERANCE)

if is_check_performance and PERFORMNACE_METRICS in example_params:
for name, value in example_params[PERFORMNACE_METRICS].items():
if is_check_performance and PERFORMANCE_METRICS in example_params:
for name, value in example_params[PERFORMANCE_METRICS].items():
assert measured_metrics[name] == pytest.approx(value, rel=PERFORMANCE_RELATIVE_TOLERANCE)
Loading

0 comments on commit 1493149

Please sign in to comment.