Skip to content

Commit

Permalink
[Test][PTQ] Fix transform_fn for AA in calibrate.py (openvinotoolki…
Browse files Browse the repository at this point in the history
…t#1786)

### Changes

AA `transform_fn` aligned with standard quantize `tranform_fn` 

### Reason for changes

To align AA and standard quantize in `calibrate.py`

### Related tickets



### Tests
  • Loading branch information
daniil-lyakhov committed May 17, 2023
1 parent feb5b04 commit 2d76d7a
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions tests/openvino/tools/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,18 @@ def maybe_reshape_model(model, dataset, subset_size, input_to_tensor_name):


# pylint: disable=protected-access
def get_transform_fn(model_evaluator: ModelEvaluator):
def transform_fn(data_item):
_, batch_annotation, batch_input, _ = data_item
filled_inputs, _, _ = model_evaluator._get_batch_input(batch_input, batch_annotation)
input_data = {}
for name, value in filled_inputs[0].items():
input_data[model_evaluator.launcher.input_to_tensor_name[name]] = value
return input_data

return transform_fn


def quantize_model(xml_path, bin_path, accuracy_checker_config, quantization_impl, quantization_parameters):
ov_model = ov.Core().read_model(model=xml_path, weights=bin_path)
model_evaluator = create_model_evaluator(accuracy_checker_config)
Expand All @@ -680,14 +692,7 @@ def quantize_model(xml_path, bin_path, accuracy_checker_config, quantization_imp
raise NotImplementedError()
quantization_parameters["advanced_parameters"] = advanced_parameters

def transform_fn(data_item):
_, batch_annotation, batch_input, _ = data_item
filled_inputs, _, _ = model_evaluator._get_batch_input(batch_input, batch_annotation)
input_data = {}
for name, value in filled_inputs[0].items():
input_data[model_evaluator.launcher.input_to_tensor_name[name]] = value
return input_data

transform_fn = get_transform_fn(model_evaluator)
calibration_dataset = nncf.Dataset(model_evaluator.dataset, transform_fn)

if get_allow_reshape_input(accuracy_checker_config):
Expand All @@ -703,7 +708,6 @@ def transform_fn(data_item):
return quantized_model


# pylint: disable=protected-access
def quantize_model_with_accuracy_control(
xml_path: str, bin_path: str, accuracy_checker_config, quantization_impl: str, quantization_parameters
):
Expand All @@ -712,11 +716,7 @@ def quantize_model_with_accuracy_control(
model_evaluator.load_network_from_ir([{"model": xml_path, "weights": bin_path}])
model_evaluator.select_dataset("")

def transform_fn(data_item):
_, batch_annotation, batch_input, _ = data_item
filled_inputs, _, _ = model_evaluator._get_batch_input(batch_input, batch_annotation)
return filled_inputs[0]

transform_fn = get_transform_fn(model_evaluator)
calibration_dataset = nncf.Dataset(model_evaluator.dataset, transform_fn)
validation_dataset = nncf.Dataset(list(range(model_evaluator.dataset.full_size)))

Expand Down

0 comments on commit 2d76d7a

Please sign in to comment.