Skip to content

Commit

Permalink
[quant] Fix per channel weight observer (pytorch#85883)
Browse files Browse the repository at this point in the history
Summary: `per_channel_weight_observer_range_neg_127_to_127` now correctly uses `PerChannelMinMaxObserver` instead of `MinMaxObserver`

Test Plan:
Adds a new test `quantization.core.test_top_level_apis` to instantiate and run `forward()` on all `default` observers

Differential Revision: D39916482

Pull Request resolved: pytorch#85883
Approved by: https://github.com/salilsdesai
  • Loading branch information
digantdesai authored and pytorchmergebot committed Sep 30, 2022
1 parent 6a5550f commit 071f875
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 7 deletions.
61 changes: 61 additions & 0 deletions test/quantization/core/test_top_level_apis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Owner(s): ["oncall: quantization"]

import torch
import torch.ao.quantization
from torch.testing._internal.common_utils import TestCase


class TestDefaultObservers(TestCase):
observers = [
"default_affine_fixed_qparams_observer",
"default_debug_observer",
"default_dynamic_quant_observer",
"default_placeholder_observer",
"default_fixed_qparams_range_0to1_observer",
"default_fixed_qparams_range_neg1to1_observer",
"default_float_qparams_observer",
"default_float_qparams_observer_4bit",
"default_histogram_observer",
"default_observer",
"default_per_channel_weight_observer",
"default_reuse_input_observer",
"default_symmetric_fixed_qparams_observer",
"default_weight_observer",
"per_channel_weight_observer_range_neg_127_to_127",
"weight_observer_range_neg_127_to_127",
]

fake_quants = [
"default_affine_fixed_qparams_fake_quant",
"default_dynamic_fake_quant",
"default_embedding_fake_quant",
"default_embedding_fake_quant_4bit",
"default_fake_quant",
"default_fixed_qparams_range_0to1_fake_quant",
"default_fixed_qparams_range_neg1to1_fake_quant",
"default_fused_act_fake_quant",
"default_fused_per_channel_wt_fake_quant",
"default_fused_wt_fake_quant",
"default_histogram_fake_quant",
"default_per_channel_weight_fake_quant",
"default_symmetric_fixed_qparams_fake_quant",
"default_weight_fake_quant",
"fused_per_channel_wt_fake_quant_range_neg_127_to_127",
"fused_wt_fake_quant_range_neg_127_to_127",
]

def _get_observer_ins(self, observer):
obs_func = getattr(torch.ao.quantization, observer)
return obs_func()

def test_observers(self) -> None:
t = torch.rand(1, 2, 3, 4)
for observer in self.observers:
obs = self._get_observer_ins(observer)
obs.forward(t)

def test_fake_quants(self) -> None:
t = torch.rand(1, 2, 3, 4)
for observer in self.fake_quants:
obs = self._get_observer_ins(observer)
obs.forward(t)
14 changes: 8 additions & 6 deletions torch/ao/quantization/fake_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,14 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""

fused_per_channel_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
quant_min=-127,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_channel_symmetric,
eps=2 ** -12)
fused_per_channel_wt_fake_quant_range_neg_127_to_127 = \
FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
quant_min=-127,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_channel_symmetric,
eps=2 ** -12)

"""
Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""
Expand Down
2 changes: 1 addition & 1 deletion torch/ao/quantization/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,7 +1543,7 @@ def load_observer_state_dict(mod, obs_dict):
weight quantization is supported, such as `fbgemm`.
"""

per_channel_weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args(
per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_channel_symmetric,
quant_min=-127, quant_max=127, eps=2 ** -12)
"""
Expand Down

0 comments on commit 071f875

Please sign in to comment.