
onnxruntime 1.16 support #584

Merged
10 commits merged on Sep 20, 2023
2 changes: 1 addition & 1 deletion examples/resnet/requirements.txt
@@ -1,7 +1,7 @@
 azure-ai-ml
 azure-identity
 azureml-fsspec
-onnxruntime
+onnxruntime<=1.15.1
 pytorch-lightning
 scipy
 tabulate
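The upper bound only takes effect at install time; a run-time sanity check in the same spirit would look like this (a sketch, not part of this PR — the error message is illustrative):

# Sketch: verify the installed onnxruntime satisfies the new <=1.15.1 pin.
from onnxruntime import __version__ as OrtVersion
from packaging import version

if version.parse(OrtVersion) > version.parse("1.15.1"):
    raise RuntimeError(f"onnxruntime {OrtVersion} is newer than the pinned <=1.15.1")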
6 changes: 6 additions & 0 deletions examples/test/test_resnet_vitis_ai_ptq_cpu.py
@@ -6,6 +6,8 @@
 from pathlib import Path

 import pytest
+from onnxruntime import __version__ as OrtVersion
+from packaging import version
 from utils import check_output, patch_config

 from olive.common.utils import retry_func, run_subprocess
@@ -30,6 +32,10 @@ def setup():
 @pytest.mark.parametrize("execution_order", ["pass-by-pass"])
 @pytest.mark.parametrize("system", ["local_system", "aml_system"])
 @pytest.mark.parametrize("olive_json", ["resnet_vitis_ai_ptq_cpu.json"])
+@pytest.mark.skipif(
+    version.parse(OrtVersion) >= version.parse("1.16.0"),
+    reason="VitisAIQuantization is not supported in ORT 1.16.0 with TensorsData",
+)
 def test_resnet(search_algorithm, execution_order, system, olive_json):
     # TODO: add gpu e2e test
     from olive.workflows import run as olive_run
2 changes: 1 addition & 1 deletion examples/whisper/requirements.txt
@@ -1,7 +1,7 @@
 neural-compressor
 onnx==1.14.0
 onnxruntime>=1.15.0
-onnxruntime-extensions>=0.8.0
+onnxruntime-extensions==0.8.0
 tabulate
 torch>=1.13.1
 transformers>=4.23.1
164 changes: 164 additions & 0 deletions olive/passes/onnx/quant_pre_process.py
@@ -0,0 +1,164 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft, Intel Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import shutil
import tempfile
import traceback
from contextlib import contextmanager
from pathlib import Path
from typing import Optional

import onnx
import onnxruntime
from onnxruntime.quantization.quant_utils import add_pre_process_metadata
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

logger = logging.getLogger(__name__)


def quant_pre_process(
    input_model_path: str,
    output_model_path: str,
    skip_optimization: bool = False,
    skip_onnx_shape: bool = False,
    skip_symbolic_shape: bool = False,
    auto_merge: bool = False,
    int_max: int = 2**31 - 1,
    guess_output_rank: bool = False,
    verbose: int = 0,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = False,
    external_data_location: Optional[str] = None,
    external_data_size_threshold: int = 1024,
) -> None:
    """Shape inference and model optimization, in preparation for quantization.

    Args:
        input_model_path: Path to the input model file.
        output_model_path: Path to the output model file.
        skip_optimization: Skip the model optimization step if true. This may result in ONNX
            shape inference failure for some models.
        skip_onnx_shape: Skip ONNX shape inference. Symbolic shape inference is most effective
            with transformer-based models. Skipping all shape inference may reduce the
            effectiveness of quantization, as a tensor with an unknown shape cannot be
            quantized.
        skip_symbolic_shape: Skip symbolic shape inference. Symbolic shape inference is most
            effective with transformer-based models. Skipping all shape inference may reduce
            the effectiveness of quantization, as a tensor with an unknown shape cannot be
            quantized.
        auto_merge: For symbolic shape inference, automatically merge symbolic dims when a
            conflict happens.
        int_max: For symbolic shape inference, the maximum integer value to be treated as
            boundless for ops like slice.
        guess_output_rank: Guess the output rank to be the same as input 0 for unknown ops.
        verbose: Log detail level of inference. 0: off, 1: warnings, 3: detailed.
        save_as_external_data: Save the ONNX model with external data.
        all_tensors_to_one_file: Save all the external data to one file.
        external_data_location: The file location to save the external data.
        external_data_size_threshold: The size threshold for external data.
    """
    with TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir:
        temp_path = Path(quant_tmp_dir)
        model = None

        if not skip_symbolic_shape:
            logger.info("Performing symbolic shape inference...")
            model = SymbolicShapeInference.infer_shapes(
                onnx.load(input_model_path),
                int_max,
                auto_merge,
                guess_output_rank,
                verbose,
            )

        if not skip_optimization:
            # Use ORT optimizers (native code) to optimize the model
            if not skip_symbolic_shape:
                # Need to save the shape-inferred model to file to run the optimizer
                input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
                if save_as_external_data:
                    onnx.save_model(
                        model,
                        input_model_path,
                        save_as_external_data=True,
                        all_tensors_to_one_file=all_tensors_to_one_file,
                        size_threshold=external_data_size_threshold,
                        convert_attribute=False,
                    )
                else:
                    onnx.save(model, input_model_path)
                model = None

            opt_model_path = str(temp_path / "optimized.onnx")
            try:
                sess_option = onnxruntime.SessionOptions()
                sess_option.optimized_model_filepath = opt_model_path
                sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
                _ = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"])
            except Exception:
                logger.error(
                    "ONNX Runtime model optimization failed! Consider rerunning with `skip_optimization=True`."
                )
                logger.error(traceback.format_exc())

            input_model_path = opt_model_path

        if not skip_onnx_shape:
            # ONNX shape inference.
            # According to the docs, infer_shapes_path should be used for models larger than 2GB.
            # If skip_optimization is specified, we could be dealing with a large model,
            # so to be on the safe side, save the model to file first.
            if model is not None:
                input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
                if save_as_external_data:
                    onnx.save_model(
                        model,
                        input_model_path,
                        save_as_external_data=True,
                        all_tensors_to_one_file=all_tensors_to_one_file,
                        size_threshold=external_data_size_threshold,
                        convert_attribute=False,
                    )
                else:
                    onnx.save(model, input_model_path)
                model = None

            inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
            onnx.shape_inference.infer_shapes_path(input_model_path, inferred_model_path)
            model = onnx.load(inferred_model_path)

        if model is None:
            model = onnx.load(input_model_path)

        add_pre_process_metadata(model)

        if save_as_external_data:
            onnx.save_model(
                model,
                output_model_path,
                save_as_external_data=True,
                all_tensors_to_one_file=all_tensors_to_one_file,
                location=external_data_location,
                size_threshold=external_data_size_threshold,
                convert_attribute=False,
            )
        else:
            onnx.save(model, output_model_path)


@contextmanager
def TemporaryDirectory(**kwargs):
    # TODO: this is a workaround for https://github.com/microsoft/onnxruntime/issues/17627
    # on Windows, where cleanup of the temporary directory can fail.
    name = tempfile.mkdtemp(**kwargs)
    try:
        yield name
    finally:
        try:
            shutil.rmtree(name)
        except OSError:
            logger.warning(f"Failed to remove: {name}", exc_info=True)
7 changes: 5 additions & 2 deletions olive/passes/onnx/quantization.py
@@ -436,9 +436,10 @@ def _run_for_config(
         return model_proto_to_olive_model(onnx_model, output_model_path, config)

     def _quant_preprocess(self, model: ONNXModel, output_model_path: Union[str, Path]) -> ONNXModel:
-        from onnxruntime.quantization.preprocess import quant_pre_process
+        from olive.passes.onnx.quant_pre_process import quant_pre_process

         try:
+            # TODO: use ORT version once the Windows issue is fixed
             quant_pre_process(
                 input_model_path=model.model_path,
                 output_model_path=str(output_model_path),
@@ -451,7 +452,9 @@ def _quant_preprocess(self, model: ONNXModel, output_model_path: Union[str, Path
             # there are some problems with the path to where the external data is saved
             # need to find out why before enabling this

-            logger.warning(f"Failed to run quantization preprocessing with error of {e}. Using original model.")
+            logger.warning(
+                f"Failed to run quantization preprocessing with error of {e}. Using original model.", exc_info=True
+            )
             # save original model to output path
             onnx_model = onnx.load(model.model_path)
             model_proto_to_file(
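Condensed, this hunk makes preprocessing best effort: try the vendored helper, and on any failure log the traceback (the new exc_info=True) and fall back to the unmodified model. A simplified standalone sketch of that flow (a free function for illustration; the real code is the method above):

# Sketch: best-effort preprocessing with fallback to the original model.
import logging

import onnx

logger = logging.getLogger(__name__)


def preprocess_or_passthrough(input_path: str, output_path: str) -> None:
    from olive.passes.onnx.quant_pre_process import quant_pre_process

    try:
        quant_pre_process(input_model_path=input_path, output_model_path=output_path)
    except Exception as e:
        # exc_info=True attaches the traceback, mirroring the change in this PR
        logger.warning(f"Preprocessing failed with {e}. Using original model.", exc_info=True)
        onnx.save(onnx.load(input_path), output_path)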
11 changes: 7 additions & 4 deletions olive/passes/onnx/vitis_ai_quantization.py
@@ -9,17 +9,13 @@
 from typing import Any, Callable, Dict, Union

 import onnx
-from onnxruntime.quantization.preprocess import quant_pre_process
-from onnxruntime.quantization.quant_utils import QuantFormat, QuantType

 from olive.cache import get_local_path_from_root
 from olive.common.utils import hash_string
 from olive.hardware import AcceleratorSpec
 from olive.model import ONNXModel
 from olive.passes import Pass
 from olive.passes.onnx.common import get_external_data_config, model_proto_to_file, model_proto_to_olive_model
-from olive.passes.onnx.vitis_ai import quantize_static
-from olive.passes.onnx.vitis_ai.quant_utils import PowerOfTwoMethod
 from olive.passes.pass_config import ParamCategory, PassConfigParam
 from olive.resource_path import OLIVE_RESOURCE_ANNOTATIONS, LocalFile
 from olive.strategy.search_parameter import Boolean, Categorical, Conditional
@@ -258,6 +254,11 @@ def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigPa
     def _run_for_config(
         self, model: ONNXModel, data_root: str, config: Dict[str, Any], output_model_path: str
     ) -> ONNXModel:
+        from onnxruntime.quantization.quant_utils import QuantFormat, QuantType
+
+        from olive.passes.onnx.vitis_ai import quantize_static
+        from olive.passes.onnx.vitis_ai.quant_utils import PowerOfTwoMethod
+
         # start with a copy of the config
         run_config = deepcopy(config)

@@ -360,6 +361,8 @@ def _run_for_config(
         return model_proto_to_olive_model(onnx_model, output_model_path, config)

     def _quant_preprocess(self, model: ONNXModel, output_model_path: str) -> ONNXModel:
+        from olive.passes.onnx.quant_pre_process import quant_pre_process
+
         try:
             quant_pre_process(input_model_path=model.model_path, output_model_path=output_model_path, auto_merge=True)
         except Exception as e:
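Moving these imports from module scope into the methods means olive.passes.onnx.vitis_ai_quantization can still be imported (e.g., during pass discovery) on ORT 1.16, where the Vitis AI quantizer is unusable; the failure is deferred until the pass actually runs. A minimal sketch of the pattern (the class name here is illustrative):

# Sketch: defer a version-sensitive dependency import from module scope to call time.
class LazyImportPass:
    def run(self):
        # Imported here rather than at the top of the file, so importing this
        # module never fails just because the dependency is incompatible.
        from onnxruntime.quantization.quant_utils import QuantFormat, QuantType

        return QuantFormat.QDQ, QuantType.QInt8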
7 changes: 7 additions & 0 deletions test/unit_test/passes/vitis_ai/test_vitis_ai_quantization.py
@@ -7,7 +7,10 @@
 from test.unit_test.utils import get_onnx_model

 import numpy as np
+import pytest
+from onnxruntime import __version__ as OrtVersion
 from onnxruntime.quantization.calibrate import CalibrationDataReader
+from packaging import version

 from olive.passes.olive_pass import create_pass_from_dict
 from olive.passes.onnx.vitis_ai_quantization import VitisAIQuantization
@@ -34,6 +37,10 @@ def dummy_calibration_reader(data_dir=None, batch_size=1, *args, **kwargs):
     return RandomDataReader()


+@pytest.mark.skipif(
+    version.parse(OrtVersion) >= version.parse("1.16.0"),
+    reason="VitisAIQuantization is not supported in ORT 1.16.0 with TensorsData",
+)
 def test_vitis_ai_quantization_pass(tmp_path):
     # setup
     input_model = get_onnx_model()