Skip to content

Commit

Permalink
Fix static quantization for QDQ and Percentile distribution (#17649)
Browse files Browse the repository at this point in the history
### Description
One quantization case was not covered by the current list of unit tests.
This PR adds a unit test to cover that case with the fix. It fixes the
issue #17619.



### Motivation and Context
  • Loading branch information
xadupre authored Sep 25, 2023
1 parent df15a3a commit 905faea
Show file tree
Hide file tree
Showing 7 changed files with 13,909 additions and 7 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cpu/quantization/qlinearconv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class QLinearConv : public OpKernel {
W_zero_point_value = W_zero_point_data[0];
for (int64_t i = 1; i < W_zero_point_size; i++) {
ORT_ENFORCE(W_zero_point_data[i] == W_zero_point_value,
"QLinearConv : zero point of per-channel filter must be same");
"QLinearConv : zero point of per-channel filter must be same. "
"This happens by design if the quantization is symmetric.");
}
}

Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/python/tools/quantization/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


class TensorData:
_allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges"])
_allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"])

def __init__(self, **kwargs):
for k, v in kwargs.items():
Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(self, calibration_method, data: Dict[str, Union[TensorData, Tuple]]
self.data[k] = TensorData(lowest=v[0], highest=v[1])
continue
if len(v) == 4:
self.data[k] = TensorData(lowest=v[0], highest=v[1], histogram=v[2], bins=v[3])
self.data[k] = TensorData(lowest=v[0], highest=v[1], hist=v[2], bins=v[3])
continue
raise TypeError(f"Unexpected tuple for {k:r}, it has {len(v)} elements: {v}.")
if not isinstance(v, TensorData):
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/python/tools/quantization/operators/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def quantize(self):
nodes,
) = self.quantizer.quantize_activation(node, [0])
quant_weight_tuple = self.quantizer.quantize_weight_per_channel(
node.input[1], onnx_proto.TensorProto.INT8, 0
node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
)
quantized_input_names.append(quant_weight_tuple[0])
zero_point_names.append(quant_weight_tuple[1])
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/python/tools/quantization/operators/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def quantize(self):
R.dims[0] = R_num_dir * R_4_hidden_size

quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel(
node.input[1], onnx_proto.TensorProto.INT8, 0
node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
)
quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel(
node.input[2], onnx_proto.TensorProto.INT8, 0
node.input[2], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
)

W_quant_weight = model.get_initializer(quant_input_weight_tuple[0]) # noqa: N806
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/python/tools/quantization/qdq_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,13 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None):
raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.")
q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel(
weight_name,
self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType,
# Quantization type is forced to be TensorProto.INT8.
# when the expected value would be (see below)
# self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType.
# QLinearConv expects to have a unique value for all channels.
# This code does not enforce that but it is necessarily the case when the
# quantization is symmetric (as for INT8).
onnx_proto.TensorProto.INT8,
axis,
keep_float_weight=self.add_qdq_pair_to_weight,
)
Expand Down
Loading

0 comments on commit 905faea

Please sign in to comment.