diff --git a/examples/resnet/requirements.txt b/examples/resnet/requirements.txt
index 9be1dabe9..71e30b9b1 100644
--- a/examples/resnet/requirements.txt
+++ b/examples/resnet/requirements.txt
@@ -1,7 +1,7 @@
 azure-ai-ml
 azure-identity
 azureml-fsspec
-onnxruntime
+onnxruntime<=1.15.1
 pytorch-lightning
 scipy
 tabulate
diff --git a/examples/test/test_resnet_vitis_ai_ptq_cpu.py b/examples/test/test_resnet_vitis_ai_ptq_cpu.py
index 488c9a0ed..5d16c4293 100644
--- a/examples/test/test_resnet_vitis_ai_ptq_cpu.py
+++ b/examples/test/test_resnet_vitis_ai_ptq_cpu.py
@@ -6,6 +6,8 @@
 from pathlib import Path
 
 import pytest
+from onnxruntime import __version__ as OrtVersion
+from packaging import version
 from utils import check_output, patch_config
 
 from olive.common.utils import retry_func, run_subprocess
@@ -30,6 +32,10 @@ def setup():
 @pytest.mark.parametrize("execution_order", ["pass-by-pass"])
 @pytest.mark.parametrize("system", ["local_system", "aml_system"])
 @pytest.mark.parametrize("olive_json", ["resnet_vitis_ai_ptq_cpu.json"])
+@pytest.mark.skipif(
+    version.parse(OrtVersion) >= version.parse("1.16.0"),
+    reason="VitisAIQuantization is not supported in ORT 1.16.0 with TensorsData",
+)
 def test_resnet(search_algorithm, execution_order, system, olive_json):
     # TODO: add gpu e2e test
     from olive.workflows import run as olive_run
diff --git a/examples/whisper/requirements.txt b/examples/whisper/requirements.txt
index cbb33d84d..de2e74988 100644
--- a/examples/whisper/requirements.txt
+++ b/examples/whisper/requirements.txt
@@ -1,7 +1,7 @@
 neural-compressor
 onnx==1.14.0
 onnxruntime>=1.15.0
-onnxruntime-extensions>=0.8.0
+onnxruntime-extensions==0.8.0
 tabulate
 torch>=1.13.1
 transformers>=4.23.1
diff --git a/olive/passes/onnx/quant_pre_process.py b/olive/passes/onnx/quant_pre_process.py
new file mode 100644
index 000000000..08f5d6ec9
--- /dev/null
+++ b/olive/passes/onnx/quant_pre_process.py
@@ -0,0 +1,164 @@
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft, Intel Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import shutil
+import tempfile
+import traceback
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Optional
+
+import onnx
+import onnxruntime
+from onnxruntime.quantization.quant_utils import add_pre_process_metadata
+from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
+
+logger = logging.getLogger(__name__)
+
+
+def quant_pre_process(
+    input_model_path: str,
+    output_model_path: str,
+    skip_optimization: bool = False,
+    skip_onnx_shape: bool = False,
+    skip_symbolic_shape: bool = False,
+    auto_merge: bool = False,
+    int_max: int = 2**31 - 1,
+    guess_output_rank: bool = False,
+    verbose: int = 0,
+    save_as_external_data: bool = False,
+    all_tensors_to_one_file: bool = False,
+    external_data_location: Optional[str] = None,
+    external_data_size_threshold: int = 1024,
+) -> None:
+    """Shape inference and model optimization, in preparation for quantization.
+
+    Args:
+        input_model_path: Path to the input model file
+        output_model_path: Path to the output model file
+        skip_optimization: Skip model optimization step if true. This may result in ONNX shape
+                           inference failure for some models.
+        skip_onnx_shape: Skip ONNX shape inference. Symbolic shape inference is most effective
+                         with transformer based models. Skipping all shape inferences may
+                         reduce the effectiveness of quantization, as a tensor with unknown
+                         shape can not be quantized.
+        skip_symbolic_shape: Skip symbolic shape inference. Symbolic shape inference is most
+                             effective with transformer based models. Skipping all shape
+                             inferences may reduce the effectiveness of quantization, as a tensor
+                             with unknown shape can not be quantized.
+        auto_merge: For symbolic shape inference, automatically merge symbolic dims when
+                    conflict happens.
+        int_max: For symbolic shape inference, specify the maximum value for integer to be
+                 treated as boundless for ops like slice
+        guess_output_rank: Guess output rank to be the same as input 0 for unknown ops
+        verbose: Logs detailed info of inference, 0: turn off, 1: warnings, 3: detailed
+        save_as_external_data: Saving an ONNX model to external data
+        all_tensors_to_one_file: Saving all the external data to one file
+        external_data_location: The file location to save the external file
+        external_data_size_threshold: The size threshold for external data
+    """
+    with TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir:
+        temp_path = Path(quant_tmp_dir)
+        model = None
+
+        if not skip_symbolic_shape:
+            logger.info("Performing symbolic shape inference...")
+            model = SymbolicShapeInference.infer_shapes(
+                onnx.load(input_model_path),
+                int_max,
+                auto_merge,
+                guess_output_rank,
+                verbose,
+            )
+
+        if not skip_optimization:
+            # Use ORT optimizers (native code) to optimize model
+            if not skip_symbolic_shape:
+                # Need to save the inferenced model to file so as to run the optimizer
+                input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
+                if save_as_external_data:
+                    onnx.save_model(
+                        model,
+                        input_model_path,
+                        save_as_external_data=True,
+                        all_tensors_to_one_file=all_tensors_to_one_file,
+                        size_threshold=external_data_size_threshold,
+                        convert_attribute=False,
+                    )
+                else:
+                    onnx.save(model, input_model_path)
+                model = None
+
+            opt_model_path = str(temp_path / "optimized.onnx")
+            try:
+                sess_option = onnxruntime.SessionOptions()
+                sess_option.optimized_model_filepath = opt_model_path
+                sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
+                _ = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"])
+            except Exception:
+                logger.error(
+                    "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'."
+                )
+                logger.error(traceback.format_exc())
+
+            input_model_path = opt_model_path
+
+        if not skip_onnx_shape:
+            # ONNX shape inference.
+            # According to docs, infer_shapes_path should be used for 2G+ models.
+            # If the skip optimization is specified, we could be dealing with a
+            # large model. So be on the safe side, save the model
+            if model is not None:
+                input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
+                if save_as_external_data:
+                    onnx.save_model(
+                        model,
+                        input_model_path,
+                        save_as_external_data=True,
+                        all_tensors_to_one_file=all_tensors_to_one_file,
+                        size_threshold=external_data_size_threshold,
+                        convert_attribute=False,
+                    )
+                else:
+                    onnx.save(model, input_model_path)
+                model = None
+
+            inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
+            onnx.shape_inference.infer_shapes_path(input_model_path, inferred_model_path)
+            model = onnx.load(inferred_model_path)
+
+        if model is None:
+            model = onnx.load(input_model_path)
+
+        add_pre_process_metadata(model)
+
+        if save_as_external_data:
+            onnx.save_model(
+                model,
+                output_model_path,
+                save_as_external_data=True,
+                all_tensors_to_one_file=all_tensors_to_one_file,
+                location=external_data_location,
+                size_threshold=external_data_size_threshold,
+                convert_attribute=False,
+            )
+        else:
+            onnx.save(model, output_model_path)
+
+
+@contextmanager
+def TemporaryDirectory(**kwargs):
+    # TODO: this is a workaround for issue https://github.com/microsoft/onnxruntime/issues/17627
+    # on Windows.
+    name = tempfile.mkdtemp(**kwargs)
+    try:
+        yield name
+    finally:
+        try:
+            shutil.rmtree(name)
+        except OSError:
+            logger.warning(f"Failed to remove: {name}", exc_info=True)
diff --git a/olive/passes/onnx/quantization.py b/olive/passes/onnx/quantization.py
index 42fa4b22e..9a6e2b708 100644
--- a/olive/passes/onnx/quantization.py
+++ b/olive/passes/onnx/quantization.py
@@ -436,9 +436,10 @@ def _run_for_config(
         return model_proto_to_olive_model(onnx_model, output_model_path, config)
 
     def _quant_preprocess(self, model: ONNXModel, output_model_path: Union[str, Path]) -> ONNXModel:
-        from onnxruntime.quantization.preprocess import quant_pre_process
+        from olive.passes.onnx.quant_pre_process import quant_pre_process
 
         try:
+            # TODO: use ORT version once the Windows issue is fixed
             quant_pre_process(
                 input_model_path=model.model_path,
                 output_model_path=str(output_model_path),
@@ -451,7 +452,9 @@ def _quant_preprocess(self, model: ONNXModel, output_model_path: Union[str, Path
             # there are some problems with the path to where the external data is saved
             # need to find out why before enabling this
 
-            logger.warning(f"Failed to run quantization preprocessing with error of {e}. Using original model.")
+            logger.warning(
+                f"Failed to run quantization preprocessing with error of {e}. Using original model.", exc_info=True
+            )
             # save original model to output path
             onnx_model = onnx.load(model.model_path)
             model_proto_to_file(
diff --git a/olive/passes/onnx/vitis_ai_quantization.py b/olive/passes/onnx/vitis_ai_quantization.py
index 35e350fcc..5d4924f06 100644
--- a/olive/passes/onnx/vitis_ai_quantization.py
+++ b/olive/passes/onnx/vitis_ai_quantization.py
@@ -9,8 +9,6 @@
 from typing import Any, Callable, Dict, Union
 
 import onnx
-from onnxruntime.quantization.preprocess import quant_pre_process
-from onnxruntime.quantization.quant_utils import QuantFormat, QuantType
 
 from olive.cache import get_local_path_from_root
 from olive.common.utils import hash_string
@@ -18,8 +16,6 @@
 from olive.model import ONNXModel
 from olive.passes import Pass
 from olive.passes.onnx.common import get_external_data_config, model_proto_to_file, model_proto_to_olive_model
-from olive.passes.onnx.vitis_ai import quantize_static
-from olive.passes.onnx.vitis_ai.quant_utils import PowerOfTwoMethod
 from olive.passes.pass_config import ParamCategory, PassConfigParam
 from olive.resource_path import OLIVE_RESOURCE_ANNOTATIONS, LocalFile
 from olive.strategy.search_parameter import Boolean, Categorical, Conditional
@@ -258,6 +254,11 @@ def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigPa
     def _run_for_config(
         self, model: ONNXModel, data_root: str, config: Dict[str, Any], output_model_path: str
     ) -> ONNXModel:
+        from onnxruntime.quantization.quant_utils import QuantFormat, QuantType
+
+        from olive.passes.onnx.vitis_ai import quantize_static
+        from olive.passes.onnx.vitis_ai.quant_utils import PowerOfTwoMethod
+
         # start with a copy of the config
         run_config = deepcopy(config)
 
@@ -360,6 +361,8 @@ def _run_for_config(
         return model_proto_to_olive_model(onnx_model, output_model_path, config)
 
     def _quant_preprocess(self, model: ONNXModel, output_model_path: str) -> ONNXModel:
+        from olive.passes.onnx.quant_pre_process import quant_pre_process
+
         try:
             quant_pre_process(input_model_path=model.model_path, output_model_path=output_model_path, auto_merge=True)
         except Exception as e:
diff --git a/test/unit_test/passes/vitis_ai/test_vitis_ai_quantization.py b/test/unit_test/passes/vitis_ai/test_vitis_ai_quantization.py
index b18ce5cf1..f99db13dc 100644
--- a/test/unit_test/passes/vitis_ai/test_vitis_ai_quantization.py
+++ b/test/unit_test/passes/vitis_ai/test_vitis_ai_quantization.py
@@ -7,7 +7,10 @@
 from test.unit_test.utils import get_onnx_model
 
 import numpy as np
+import pytest
+from onnxruntime import __version__ as OrtVersion
 from onnxruntime.quantization.calibrate import CalibrationDataReader
+from packaging import version
 
 from olive.passes.olive_pass import create_pass_from_dict
 from olive.passes.onnx.vitis_ai_quantization import VitisAIQuantization
@@ -34,6 +37,10 @@ def dummy_calibration_reader(data_dir=None, batch_size=1, *args, **kwargs):
     return RandomDataReader()
 
 
+@pytest.mark.skipif(
+    version.parse(OrtVersion) >= version.parse("1.16.0"),
+    reason="VitisAIQuantization is not supported in ORT 1.16.0 with TensorsData",
+)
 def test_vitis_ai_quantization_pass(tmp_path):
     # setup
     input_model = get_onnx_model()
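
For reference, a minimal sketch of how the vendored helper introduced above might be exercised on its own. The model paths are hypothetical placeholders, not files touched by this change; the function and its parameters come from the new olive/passes/onnx/quant_pre_process.py module.

# Minimal usage sketch (assumed paths; not part of the patch above).
# olive.passes.onnx.quant_pre_process.quant_pre_process mirrors the ORT
# quant_pre_process helper but uses a mkdtemp-based TemporaryDirectory to
# work around the Windows cleanup issue referenced in the TODO comments.
from olive.passes.onnx.quant_pre_process import quant_pre_process

quant_pre_process(
    input_model_path="model.onnx",           # hypothetical input model
    output_model_path="model_preproc.onnx",  # hypothetical output path
    auto_merge=True,                         # merge conflicting symbolic dims, as the Vitis AI pass does
)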