Skip to content

Commit

Permalink
[quant] Add example for lowering quantized dynamic linear pattern thr…
Browse files Browse the repository at this point in the history
…ough delegation (pytorch#90640)

Summary: Only the pattern part, will leave the delegation example to Chen

Test Plan: buck run executorch/exir/tests:quant_lowering_custom_backend_pass -- "executorch.exir.tests.test_quant_lowering_custom_backend_pass.TestQuantLoweringCustomBackendPass.test_quantized_linear_dynamic"

Reviewed By: cccclai

Pull Request resolved: pytorch#90640
Approved by: https://github.com/cccclai
  • Loading branch information
jerryzh168 authored and pytorchmergebot committed Dec 13, 2022
1 parent b6f114c commit 94b9bb3
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion torch/ao/quantization/fx/_decomposed.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def choose_qparams_tensor(
quant_min: int,
quant_max: int,
dtype: torch.dtype
) -> Tuple[float, int]:
) -> Tuple[torch.Tensor, torch.Tensor]:
""" Given an input Tensor, derive the per tensor affine quantization parameter
(scale and zero_point) for target quantized Tensor from the Tensor
Expand All @@ -211,6 +211,17 @@ def choose_qparams_tensor(
scale, zero_point = observer.calculate_qparams()
return (scale, zero_point)

@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta")
def choose_qparams_tensor_meta(
input: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
assert quant_min < quant_max, f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}"
return torch.empty(1, dtype=torch.float, device=input.device), torch.empty(1, dtype=torch.int32, device=input.device)

# Helper function used to implement per-channel quantization against any axis
def _permute_to_axis_zero(x, axis):
new_axis_list = list(range(x.dim()))
Expand Down

0 comments on commit 94b9bb3

Please sign in to comment.