[PTQ][Tests] Add model reshape to calibrate.py (openvinotoolkit#1763)
### Changes

The `allow_reshape_input` Accuracy Checker parameter is now emulated in `calibrate.py`.
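
For reference, this is roughly the shape of the Accuracy Checker config dictionary that the new `get_allow_reshape_input` helper in the diff below walks; every key except `allow_reshape_input` is an illustrative placeholder:

```python
# Hypothetical sketch of an Accuracy Checker config as seen by calibrate.py.
# Only the "models" -> "launchers" -> "allow_reshape_input" path matters here;
# the remaining fields are placeholders.
accuracy_checker_config = {
    "models": [
        {
            "launchers": [
                {"framework": "openvino", "allow_reshape_input": True},
            ],
            "datasets": [{"name": "some_dataset"}],  # placeholder
        }
    ]
}
```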

### Reason for changes

To enable models that require a reshape or dynamic input shapes before inference on the calibration dataset.

### Related tickets

108977

### Tests
Checked locally on bert-base-cased, fbcnn, and human-pose-estimation.
TODO

- [x] Check all models with the `allow_reshape_input` param in ACConfig
daniil-lyakhov committed May 11, 2023
1 parent b7b34a9 commit 9cf724f
Showing 2 changed files with 105 additions and 9 deletions.
2 changes: 1 addition & 1 deletion nncf/openvino/graph/node_utils.py
@@ -281,7 +281,7 @@ def get_partial_shape_safe(node, port_id) -> int:
partial_shape = node.get_output_partial_shape(port_id)
if partial_shape.rank.is_dynamic or not partial_shape.all_non_negative:
raise RuntimeError(
f"Could not collect statistics for the node {node}" f"because its output shape rank is dynamic or negative"
f"Could not collect statistics for the node {node} because its output shape rank is dynamic or negative"
)
return partial_shape

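The one-line change above is cosmetic: the original message was built from two adjacent f-string literals with no separating space, so Python's implicit literal concatenation ran the node name straight into "because". A minimal sketch of the effect (the node name is hypothetical):

```python
node = "MatMul_123"  # hypothetical node name
# Adjacent string literals are concatenated with no space in between.
old = f"Could not collect statistics for the node {node}" f"because its output shape rank is dynamic or negative"
new = f"Could not collect statistics for the node {node} because its output shape rank is dynamic or negative"
print(old)  # ... the node MatMul_123because its output shape rank ...
print(new)  # ... the node MatMul_123 because its output shape rank ...
```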
112 changes: 104 additions & 8 deletions tests/openvino/tools/calibrate.py
@@ -14,12 +14,16 @@
import os
from argparse import ArgumentParser
from collections import OrderedDict
from collections import defaultdict
from dataclasses import asdict
from enum import Enum
from itertools import islice
from typing import Iterable, Optional, TypeVar

import numpy as np
import openvino.runtime as ov
from openvino.runtime import Dimension
from openvino.runtime import PartialShape
from openvino.tools.accuracy_checker.evaluators.quantization_model_evaluator import ModelEvaluator
from openvino.tools.accuracy_checker.evaluators.quantization_model_evaluator import create_model_evaluator
from openvino.tools.pot.configs.config import Config
@@ -587,10 +591,83 @@ def get_nncf_algorithms_config(compression_config):
return nncf_algorithms


def get_allow_reshape_input(accuracy_checker_config) -> bool:
for model_config in accuracy_checker_config["models"]:
for launcher_config in model_config["launchers"]:
if "allow_reshape_input" in launcher_config:
return launcher_config["allow_reshape_input"]
return False


# pylint:disable=too-many-branches
def maybe_reshape_model(model, dataset, subset_size, input_to_tensor_name):
dataset_inputs_shapes = defaultdict(set)
for input_dict in islice(dataset.get_inference_data(), subset_size):
for name, tensor in input_dict.items():
dataset_inputs_shapes[name].add(tuple(tensor.shape))

model_inputs_shapes = {}
for input_output in model.inputs:
input_node = input_output.get_node()
model_inputs_shapes[input_to_tensor_name[input_node.friendly_name]] = tuple(input_node.partial_shape)

if len(dataset_inputs_shapes) != len(model_inputs_shapes):
raise RuntimeError(
f"Model inputs: {list(model_inputs_shapes.keys())}"
f" and dataset inputs {list(dataset_inputs_shapes.keys())} are not compatible"
)

for name in model_inputs_shapes:
if name not in dataset_inputs_shapes:
raise RuntimeError(
f"Model input {name} is not present in dataset inputs: {list(dataset_inputs_shapes.keys())}"
)

dynamic_dims = defaultdict(list)
reshaped_static_dims = defaultdict(list)
for name, shapes in dataset_inputs_shapes.items():
shapes = list(shapes)
if len(set(len(shape) for shape in shapes)) != 1 or len(model_inputs_shapes[name]) != len(shapes[0]):
raise RuntimeError("calibrate.py does not support dataset with dynamic ranks")

for idx in range(len(shapes[0])):
if len(shapes) == 1:
model_dim = model_inputs_shapes[name][idx]
if model_dim.is_static and model_dim.get_length() != shapes[0][idx]:
reshaped_static_dims[name].append(idx)

elif any(shapes[0][idx] != shape[idx] for shape in shapes[1:]):
dynamic_dims[name].append(idx)

if not any(any(dict_.values()) for dict_ in [dynamic_dims, reshaped_static_dims]):
return model

partial_shapes = {}
for name, shape in model_inputs_shapes.items():
dataset_first_shape = dataset_inputs_shapes[name].pop()
dims = []
for idx, d in enumerate(shape):
if idx in dynamic_dims[name]:
dim = Dimension(-1)
elif idx in reshaped_static_dims[name]:
dim = Dimension(dataset_first_shape[idx])
else:
if isinstance(d, Dimension):
dim = d
elif isinstance(d, tuple):
dim = Dimension(d[0], d[1])
else:
dim = Dimension(d)
dims.append(dim)
partial_shapes[name] = PartialShape(dims)
model.reshape(partial_shapes)
return model


# pylint: disable=protected-access
def quantize_model(xml_path, bin_path, accuracy_checcker_config, quantization_impl, quantization_parameters):
def quantize_model(xml_path, bin_path, accuracy_checker_config, quantization_impl, quantization_parameters):
ov_model = ov.Core().read_model(model=xml_path, weights=bin_path)
model_evaluator = create_model_evaluator(accuracy_checcker_config)
model_evaluator = create_model_evaluator(accuracy_checker_config)
model_evaluator.load_network([{"model": ov_model}])
model_evaluator.select_dataset("")

@@ -612,16 +689,26 @@ def transform_fn(data_item):
return input_data

calibration_dataset = nncf.Dataset(model_evaluator.dataset, transform_fn)

if get_allow_reshape_input(accuracy_checker_config):
ov_model = maybe_reshape_model(
ov_model,
calibration_dataset,
quantization_parameters.get("subset_size", 300),
model_evaluator.launcher.input_to_tensor_name,
)
model_evaluator.load_network([{"model": ov_model}])

quantized_model = nncf.quantize(ov_model, calibration_dataset, **quantization_parameters)
return quantized_model


# pylint: disable=protected-access
def quantize_model_with_accuracy_control(
xml_path: str, bin_path: str, accuracy_checcker_config, quantization_impl: str, quantization_parameters
xml_path: str, bin_path: str, accuracy_checker_config, quantization_impl: str, quantization_parameters
):
ov_model = ov.Core().read_model(xml_path, bin_path)
model_evaluator = create_model_evaluator(accuracy_checcker_config)
model_evaluator = create_model_evaluator(accuracy_checker_config)
model_evaluator.load_network_from_ir([{"model": xml_path, "weights": bin_path}])
model_evaluator.select_dataset("")

@@ -633,9 +720,18 @@ def transform_fn(data_item):
calibration_dataset = nncf.Dataset(model_evaluator.dataset, transform_fn)
validation_dataset = nncf.Dataset(list(range(model_evaluator.dataset.full_size)))

metric_name = accuracy_checcker_config["models"][0]["datasets"][0]["metrics"][0].get("name", None)
if get_allow_reshape_input(accuracy_checker_config):
ov_model = maybe_reshape_model(
ov_model,
calibration_dataset,
quantization_parameters.get("subset_size", 300),
model_evaluator.launcher.input_to_tensor_name,
)
model_evaluator.load_network([{"model": ov_model}])

metric_name = accuracy_checker_config["models"][0]["datasets"][0]["metrics"][0].get("name", None)
if metric_name is None:
metric_name = accuracy_checcker_config["models"][0]["datasets"][0]["metrics"][0]["type"]
metric_name = accuracy_checker_config["models"][0]["datasets"][0]["metrics"][0]["type"]
validation_fn = ACValidationFunction(model_evaluator, metric_name)

name_to_quantization_impl_map = {
@@ -667,7 +763,7 @@ def main():
config.configure_params()

xml_path, bin_path = get_model_paths(config.model)
accuracy_checcker_config = get_accuracy_checker_config(config.engine)
accuracy_checker_config = get_accuracy_checker_config(config.engine)
nncf_algorithms_config = get_nncf_algorithms_config(config.compression)

set_log_file(f"{args.output_dir}/log.txt")
@@ -685,7 +781,7 @@ quantize_model_arguments = {
quantize_model_arguments = {
"xml_path": xml_path,
"bin_path": bin_path,
"accuracy_checcker_config": accuracy_checcker_config,
"accuracy_checker_config": accuracy_checker_config,
"quantization_impl": args.impl,
"quantization_parameters": algo_config["parameters"],
}
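For context, a minimal sketch (not part of the commit) of the OpenVINO reshape call that the new `maybe_reshape_model` helper ultimately issues once it has classified each input dimension; the IR path and input name are hypothetical:

```python
import openvino.runtime as ov
from openvino.runtime import Dimension, PartialShape

core = ov.Core()
model = core.read_model("model.xml")  # hypothetical IR path

# Dimensions that vary across calibration samples become dynamic (Dimension(-1));
# static dimensions that differ from the dataset are pinned to the dataset value.
model.reshape({"input": PartialShape([Dimension(-1), Dimension(3), Dimension(224), Dimension(224)])})
```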
