Fix subgraph quantization regression in onnxruntime 1.17 (#19421)
As per title, fixes #19418.

ONNX Runtime 1.17 broke the quantization of ONNX models that contain subgraphs, in the case where an initializer is placed on the top-level graph and shared by several subgraphs.
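
For context, the failing shape can be sketched as follows. This is an illustrative reproduction, not the repository's test: the file names, the `shared.weight` initializer name, and the branch names are all made up. It builds an `If` model whose two branch subgraphs both `Gather` from a single initializer stored on the top-level graph, then quantizes it with subgraph quantization enabled:

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

from onnxruntime.quantization import quantize_dynamic

# One initializer on the top-level graph, shared by both branch subgraphs.
shared = numpy_helper.from_array(np.random.rand(10, 4).astype(np.float32), name="shared.weight")

def make_branch(name):
    # The branch reads `shared.weight` and `ids` from the outer scope.
    gather = helper.make_node("Gather", ["shared.weight", "ids"], [name + "_out"])
    out = helper.make_tensor_value_info(name + "_out", TensorProto.FLOAT, [None, 4])
    return helper.make_graph([gather], name, [], [out])

if_node = helper.make_node(
    "If", ["cond"], ["y"], then_branch=make_branch("then_g"), else_branch=make_branch("else_g")
)
graph = helper.make_graph(
    [if_node],
    "main",
    [
        helper.make_tensor_value_info("cond", TensorProto.BOOL, []),
        helper.make_tensor_value_info("ids", TensorProto.INT64, [None]),
    ],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [None, 4])],
    initializer=[shared],
)
onnx.save(helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]), "if_shared_init.onnx")

# Quantizing with subgraph support enabled is the path that regressed in 1.17:
# the subgraph quantizer looks up the scale initializer locally, does not find
# it (it lives on the top-level graph), and crashed before this fix (see #19418).
quantize_dynamic(
    model_input="if_shared_init.onnx",
    model_output="if_shared_init_quantized.onnx",
    op_types_to_quantize=["Gather"],
    extra_options={"EnableSubgraph": True},
)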
fxmarty authored Feb 13, 2024
1 parent 5c7e6b2 commit 1e10cdb
Showing 2 changed files with 72 additions and 2 deletions.
10 changes: 8 additions & 2 deletions onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -1332,9 +1332,15 @@ def _dequantize_value(self, value_name):
         if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names):
             quantized_value = self.quantized_value_map[value_name]
             # Add DequantizeLinear Node for this input
+
             scale_init = find_by_name(quantized_value.scale_name, self.model.initializer())
-            # axis is not specified so scale_init must be a scalar.
-            assert onnx.numpy_helper.to_array(scale_init).size == 1
+
+            # In case we are working with subgraphs, the graph `producer_name` is set to `"onnx-quantizer"` in the `quantize_subgraph` method. In this case, the scale initializer may be on the top-level graph, so the check below cannot be done.
+            if self.model.model.producer_name != "onnx-quantizer" or (
+                self.model.model.producer_name == "onnx-quantizer" and scale_init is not None
+            ):
+                # axis is not specified so scale_init must be a scalar.
+                assert onnx.numpy_helper.to_array(scale_init).size == 1
 
             dqlinear_name = value_name + "_DequantizeLinear"
             dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph())
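
In short, the new branch only relaxes the scalar-scale assertion for subgraph quantizers, which `quantize_subgraph` marks by setting the graph's `producer_name` to `"onnx-quantizer"`. A standalone paraphrase of the guard, for illustration only (the helper name is made up; the real check lives inline in `ONNXQuantizer._dequantize_value`):

# Illustrative paraphrase, not the committed code.
def should_assert_scalar_scale(producer_name, scale_init):
    if producer_name != "onnx-quantizer":
        return True  # top-level model: the scale initializer is always local
    # Subgraph quantizer: a scale not found locally lives on an outer graph,
    # so skip the assertion instead of crashing on `None`.
    return scale_init is not None

assert should_assert_scalar_scale("pytorch", "scale") is True
assert should_assert_scalar_scale("onnx-quantizer", None) is False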
64 changes: 64 additions & 0 deletions onnxruntime/test/python/quantization/test_subgraph.py
@@ -0,0 +1,64 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import os
import tempfile
import unittest
import urllib.request

import onnx

from onnxruntime.quantization import quantize_dynamic


class TestDynamicQuantizationSubgraph(unittest.TestCase):
    def test_dynamic_quantization_subgraph(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            onnx_path = os.path.join(tmpdir, "decoder_model_merged.onnx")
            quantized_onnx_path = os.path.join(tmpdir, "decoder_model_merged_quantized.onnx")
            urllib.request.urlretrieve(
                "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx", onnx_path
            )

            quantize_dynamic(
                model_input=onnx_path,
                model_output=quantized_onnx_path,
                per_channel=True,
                op_types_to_quantize=[
                    "Conv",
                    "MatMul",
                    "Attention",
                    "LSTM",
                    "Gather",
                    "Transpose",
                    "EmbedLayerNormalization",
                ],
                extra_options={"EnableSubgraph": True},
            )
            model = onnx.load(quantized_onnx_path)

            # The initializer `shared.weight_merged_0` is attached to the top-level graph and used in a
            # Gather node in each subgraph. We expect the quantized Gather initializer (after which a
            # DequantizeLinear node is inserted) to be attached to the top-level graph as well.
            found_gather_quantized = False
            for initializer in model.graph.initializer:
                if initializer.name == "shared.weight_merged_0_quantized":
                    found_gather_quantized = True
                    break
            self.assertTrue(found_gather_quantized)

            found_gather_scale = False
            for initializer in model.graph.initializer:
                if initializer.name == "shared.weight_merged_0_scale":
                    found_gather_scale = True
                    break
            self.assertTrue(found_gather_scale)

            # No initializer related to the Gather node should be attached to the subgraphs.
            for node in model.graph.node:
                for attr in node.attribute:
                    if attr.type == onnx.AttributeProto.GRAPH:
                        for initializer in attr.g.initializer:
                            self.assertTrue("shared.weight" not in initializer.name)
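
The final loop generalizes into a small debugging helper. A sketch (not part of the commit; `initializer_locations` is a made-up name) for checking where a quantized model's initializers ended up:

import onnx

def initializer_locations(model_path):
    # Map each initializer name to where it is attached: the top-level graph
    # or a first-level subgraph (nested subgraphs are not traversed, matching
    # the depth the test above inspects).
    model = onnx.load(model_path)
    locations = {init.name: "top-level graph" for init in model.graph.initializer}
    for node in model.graph.node:
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:
                for init in attr.g.initializer:
                    locations[init.name] = f"subgraph {attr.g.name!r} of {node.op_type}"
    return locations

# After the fix, `shared.weight_merged_0_quantized` and
# `shared.weight_merged_0_scale` should both map to "top-level graph".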
