forked from openvinotoolkit/nncf
Commit
Implement FracBitsQuantizers (openvinotoolkit#1231)
* Implement FracBitsQuantizers
  - Implement FracBitsSymmetricQuantizer and FracBitsAsymmetricQuantizer
  - Implement relevant unit test

  Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>

* Remove self._constructed flag
  - Fix enable_gradients() and disable_gradients() to use the base class functions
  - Fix num_bits setter

  Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
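For context, here is a minimal usage sketch of the new quantizers, assembled from the calls exercised by the unit test in this commit (the parameter values are illustrative, not a recommended configuration):

import torch
from nncf.experimental.torch.fracbits.quantizer import FracBitsSymmetricQuantizer
from nncf.experimental.torch.fracbits.structs import FracBitsQuantizationMode
from nncf.torch.quantization.layers import PTQuantizerSpec

qspec = PTQuantizerSpec(num_bits=8,
                        mode=FracBitsQuantizationMode.SYMMETRIC,
                        signedness_to_force=None,
                        scale_shape=(1, 1),
                        narrow_range=False,
                        half_range=False,
                        logarithm_scale=False)
quant = FracBitsSymmetricQuantizer(qspec)
quant.apply_minmax_init(torch.FloatTensor([-0.5]), torch.FloatTensor([0.5]))

w = torch.randn(100, 10)
w_q = quant(w)                 # forward pass; the bit-width is still learnable
print(quant.frac_num_bits)     # fractional bit-width, a trainable parameter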
Showing 4 changed files with 361 additions and 0 deletions.
@@ -0,0 +1,12 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
     http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
@@ -0,0 +1,216 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
     http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Reference: Yang, Linjie, and Qing Jin. "Fracbits: Mixed precision quantization via fractional bit-widths."
# Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 35. No. 12. 2021.
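# The fractional bit-width b is stored as a learnable floating-point parameter (self._num_bits).
# While it is trainable, quantize() interpolates between quantization at floor(b) and ceil(b)
# bits, weighted by the fractional part of b:
#     q_b(x) = (b - floor(b)) * q_ceil(b)(x) + (ceil(b) - b) * q_floor(b)(x)
# which keeps the output differentiable with respect to the bit-width b.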

from typing import Dict
import torch

from nncf.experimental.torch.fracbits.structs import FracBitsQuantizationMode

from nncf.torch.layer_utils import COMPRESSION_MODULES, CompressionParameter
from nncf.torch.quantization.layers import (
    QUANTIZATION_MODULES, AsymmetricQuantizer, PTQuantizerSpec, SymmetricQuantizer)
from nncf.torch.quantization.quantize_functions import asymmetric_quantize, symmetric_quantize
from nncf.torch.utils import no_jit_trace


@COMPRESSION_MODULES.register()
@QUANTIZATION_MODULES.register(FracBitsQuantizationMode.SYMMETRIC)
class FracBitsSymmetricQuantizer(SymmetricQuantizer):
    def __init__(self, qspec: PTQuantizerSpec):
        super().__init__(qspec)
        self._min_num_bits = int(0.5 * qspec.num_bits)
        self._max_num_bits = int(1.5 * qspec.num_bits)
        self._num_bits = CompressionParameter(torch.FloatTensor([qspec.num_bits]), requires_grad=True,
                                              compression_lr_multiplier=qspec.compression_lr_multiplier)

    @property
    def frac_num_bits(self):
        return torch.clamp(self._num_bits, self._min_num_bits, self._max_num_bits)

    @property
    def num_bits(self):
        if self._num_bits.dtype == torch.int32:
            return super().num_bits

        with no_jit_trace():
            return self.frac_num_bits.round().int().item()

    @num_bits.setter
    def num_bits(self, num_bits: int):
        if num_bits < self._min_num_bits or num_bits > self._max_num_bits:
            raise RuntimeError(
                f"{num_bits} should be in [{self._min_num_bits}, {self._max_num_bits}]")
        self._num_bits.fill_(num_bits)

    @property
    def is_num_bits_frozen(self) -> bool:
        return not self._num_bits.requires_grad

    def set_min_max_num_bits(self, min_num_bits: int, max_num_bits: int):
        if min_num_bits >= max_num_bits:
            raise ValueError(
                f"min_num_bits({min_num_bits}) >= max_num_bits({max_num_bits})")
        self._min_num_bits = min_num_bits
        self._max_num_bits = max_num_bits

    def unfreeze_num_bits(self) -> None:
        self._num_bits.requires_grad_(True)

    def freeze_num_bits(self) -> None:
        self._num_bits.requires_grad_(False)
        super().set_level_ranges()

    def enable_gradients(self):
        super().enable_gradients()
        self.unfreeze_num_bits()

    def disable_gradients(self):
        super().disable_gradients()
        self.freeze_num_bits()

    def _quantize_with_n_bits(self, x, num_bits, execute_traced_op_as_identity: bool = False):
        scaled_num_bits = 1 if self._half_range else 0

        level_low, level_high, levels = self.calculate_level_ranges(
            num_bits - scaled_num_bits, self.signed)

        return symmetric_quantize(x, levels, level_low, level_high, self.scale, self.eps,
                                  skip=execute_traced_op_as_identity)

    def quantize(self, x, execute_traced_op_as_identity: bool = False):
        if self.is_num_bits_frozen:
            return super().quantize(x, execute_traced_op_as_identity)

        fl_num_bits = self.frac_num_bits.floor().int().item()
        ce_num_bits = fl_num_bits + 1

        fl_q = self._quantize_with_n_bits(
            x, fl_num_bits, execute_traced_op_as_identity)
        ce_q = self._quantize_with_n_bits(
            x, ce_num_bits, execute_traced_op_as_identity)

        return (self.frac_num_bits - fl_num_bits) * ce_q + (ce_num_bits - self.frac_num_bits) * fl_q

    def get_trainable_params(self) -> Dict[str, torch.Tensor]:
        return {self.SCALE_PARAM_NAME: self.scale.detach(), "num_bits": self.frac_num_bits.detach()}

    def _prepare_export_quantization(self, x: torch.Tensor):
        self.freeze_num_bits()
        return super()._prepare_export_quantization(x)

    @torch.no_grad()
    def get_input_range(self):
        self.set_level_ranges()
        input_low, input_high = self._get_input_low_input_high(
            self.scale, self.level_low, self.level_high, self.eps)
        return input_low, input_high


@COMPRESSION_MODULES.register()
@QUANTIZATION_MODULES.register(FracBitsQuantizationMode.ASYMMETRIC)
class FracBitsAsymmetricQuantizer(AsymmetricQuantizer):
    def __init__(self, qspec: PTQuantizerSpec):
        super().__init__(qspec)
        self._min_num_bits = int(0.5 * qspec.num_bits)
        self._max_num_bits = int(1.5 * qspec.num_bits)
        self._num_bits = CompressionParameter(torch.FloatTensor([qspec.num_bits]), requires_grad=True,
                                              compression_lr_multiplier=qspec.compression_lr_multiplier)

    @property
    def frac_num_bits(self):
        return torch.clamp(self._num_bits, self._min_num_bits, self._max_num_bits)

    @property
    def num_bits(self):
        if self._num_bits.dtype == torch.int32:
            return super().num_bits

        with no_jit_trace():
            return self.frac_num_bits.round().int().item()

    @num_bits.setter
    def num_bits(self, num_bits: int):
        if num_bits < self._min_num_bits or num_bits > self._max_num_bits:
            raise RuntimeError(
                f"{num_bits} should be in [{self._min_num_bits}, {self._max_num_bits}]")
        self._num_bits.fill_(num_bits)

    @property
    def is_num_bits_frozen(self) -> bool:
        return not self._num_bits.requires_grad

    def set_min_max_num_bits(self, min_num_bits: int, max_num_bits: int):
        if min_num_bits >= max_num_bits:
            raise ValueError(
                f"min_num_bits({min_num_bits}) >= max_num_bits({max_num_bits})")
        self._min_num_bits = min_num_bits
        self._max_num_bits = max_num_bits

    def unfreeze_num_bits(self) -> None:
        self._num_bits.requires_grad_(True)

    def freeze_num_bits(self) -> None:
        self._num_bits.requires_grad_(False)
        super().set_level_ranges()

    def enable_gradients(self):
        super().enable_gradients()
        self.unfreeze_num_bits()

    def disable_gradients(self):
        super().disable_gradients()
        self.freeze_num_bits()

    def _quantize_with_n_bits(self, x, num_bits, execute_traced_op_as_identity: bool = False):
        scaled_num_bits = 1 if self._half_range else 0

        level_low, level_high, levels = self.calculate_level_ranges(
            num_bits - scaled_num_bits)

        return asymmetric_quantize(x, levels, level_low, level_high, self.input_low, self.input_range, self.eps,
                                   skip=execute_traced_op_as_identity)

    def quantize(self, x, execute_traced_op_as_identity: bool = False):
        if self.is_num_bits_frozen:
            return super().quantize(x, execute_traced_op_as_identity)

        fl_num_bits = self.frac_num_bits.floor().int().item()
        ce_num_bits = fl_num_bits + 1

        fl_q = self._quantize_with_n_bits(
            x, fl_num_bits, execute_traced_op_as_identity)
        ce_q = self._quantize_with_n_bits(
            x, ce_num_bits, execute_traced_op_as_identity)

        return (self.frac_num_bits - fl_num_bits) * ce_q + (ce_num_bits - self.frac_num_bits) * fl_q

    def get_trainable_params(self) -> Dict[str, torch.Tensor]:
        return {self.INPUT_LOW_PARAM_NAME: self.input_low.detach(),
                self.INPUT_RANGE_PARAM_NAME: self.input_range.detach(),
                "num_bits": self._num_bits.detach()}

    def _prepare_export_quantization(self, x: torch.Tensor):
        self.freeze_num_bits()
        return super()._prepare_export_quantization(x)

    @torch.no_grad()
    def get_input_range(self):
        self.set_level_ranges()
        input_low, input_high = self._get_input_low_input_high(self.input_range,
                                                               self.input_low,
                                                               self.levels,
                                                               self.eps)
        return input_low, input_high
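To make the role of the interpolation in quantize() concrete, here is a small self-contained sketch (a toy rounding function stands in for the NNCF quantization kernels; names and values are illustrative only) showing that the blended output carries a gradient back to the fractional bit-width:

import torch

def fake_quant(x, num_bits):
    # Toy stand-in for quantization at a fixed integer bit-width.
    levels = 2 ** num_bits
    return torch.round(x * levels) / levels

b = torch.tensor([4.3], requires_grad=True)  # learnable fractional bit-width
x = torch.randn(8)

fl = int(b.floor().item())                   # 4
q = (b - fl) * fake_quant(x, fl + 1) + (fl + 1 - b) * fake_quant(x, fl)
q.sum().backward()
print(b.grad)                                # non-zero, so SGD can adjust the bit-width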
@@ -0,0 +1,16 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
     http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


class FracBitsQuantizationMode:
    SYMMETRIC = 'fracbits_symmetric'
    ASYMMETRIC = 'fracbits_asymmetric'
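These mode strings are what tie a quantizer configuration to a concrete class: the quantizer module registers FracBitsSymmetricQuantizer and FracBitsAsymmetricQuantizer under them in QUANTIZATION_MODULES. As a rough sketch, selection by mode can be done the same way the unit test below does it explicitly (make_fracbits_quantizer is a hypothetical helper, not part of this commit):

from nncf.experimental.torch.fracbits.quantizer import FracBitsAsymmetricQuantizer, FracBitsSymmetricQuantizer
from nncf.experimental.torch.fracbits.structs import FracBitsQuantizationMode

def make_fracbits_quantizer(qspec):
    # Hypothetical helper: pick the FracBits quantizer class from the qspec mode string.
    if qspec.mode == FracBitsQuantizationMode.ASYMMETRIC:
        return FracBitsAsymmetricQuantizer(qspec)
    return FracBitsSymmetricQuantizer(qspec)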
@@ -0,0 +1,117 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
     http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import pytest
import torch
from torch import nn

from nncf.common.utils.logger import logger as nncf_logger
from nncf.experimental.torch.fracbits.quantizer import FracBitsAsymmetricQuantizer, FracBitsSymmetricQuantizer
from nncf.experimental.torch.fracbits.structs import FracBitsQuantizationMode
from nncf.torch.quantization.layers import PTQuantizerSpec

#pylint: disable=redefined-outer-name


def set_manual_seed():
    torch.manual_seed(3003)


@pytest.fixture(scope="function")
def linear_problem(num_bits: int = 4, sigma: float = 0.2):
    set_manual_seed()

    levels = 2 ** num_bits
    w = 1 / levels * (torch.randint(0, levels, size=[100, 10]) - levels // 2)
    x = torch.randn([1000, 10])
    y = w.mm(x.t())
    y += sigma * torch.randn_like(y)

    return w, x, y, num_bits, sigma


@pytest.fixture()
def qspec(request):
    return PTQuantizerSpec(num_bits=8,
                           mode=request.param,
                           signedness_to_force=None,
                           scale_shape=(1, 1),
                           narrow_range=False,
                           half_range=False,
                           logarithm_scale=False)


@pytest.mark.parametrize("add_bitwidth_loss", [True, False])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("qspec",
                         [FracBitsQuantizationMode.ASYMMETRIC, FracBitsQuantizationMode.SYMMETRIC], indirect=["qspec"])
def test_quantization(linear_problem, qspec, device, add_bitwidth_loss):
    """
    Test quantization on a simple linear problem.
    The weight is filled with random integers in [-2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1],
    then scaled by 1 / 2 ** bit_width, so it ends up in [-0.5, 0.5).
    The quantizer's input_low and input_high are initialized narrower than [-0.5, 0.5]
    by multiplying both limits by 0.1.
    An SGD optimizer then learns the quantizer parameters with an MSE loss for the linear model.
    Check that input_low and input_high expand back towards [-0.5, 0.5] to compensate for
    quantization errors and that the MSE loss is minimized. If the target bit-width loss is added,
    also check that the quantizer's learnable bit-width converges to the target bit-width.
    """
    w, x, y, bit_width, sigma = linear_problem

    w, x, y = w.to(device=device), x.to(device=device), y.to(device=device)

    quant = FracBitsAsymmetricQuantizer(
        qspec) if qspec.mode == FracBitsQuantizationMode.ASYMMETRIC else FracBitsSymmetricQuantizer(qspec)

    init_input_low = torch.FloatTensor([w.min() * 0.1])
    init_input_high = torch.FloatTensor([w.max() * 0.1])

    quant.apply_minmax_init(init_input_low, init_input_high)
    quant = quant.to(w.device)
    criteria = nn.MSELoss()

    optim = torch.optim.SGD(quant.parameters(), lr=1e-1)

    for _ in range(100):
        optim.zero_grad()
        loss = criteria(y, quant(w).mm(x.t()))

        if add_bitwidth_loss:
            loss += criteria(bit_width *
                             torch.ones_like(quant.frac_num_bits), quant.frac_num_bits)

        loss.backward()
        optim.step()

    eps = 0.05
    ub_mse_loss = 1.1 * (sigma ** 2)
    ub_left_q_w = w.min() + eps
    lb_right_q_w = w.max() - eps

    with torch.no_grad():
        loss = criteria(y, quant(w).mm(x.t())).item()
        nncf_logger.debug(
            f"loss={loss:.3f} should be lower than ub_mse_loss={ub_mse_loss:.3f}.")
        assert loss < ub_mse_loss

        left_q_w, right_q_w = quant.get_input_range()
        left_q_w, right_q_w = left_q_w.item(), right_q_w.item()
        nncf_logger.debug(f"[left_q_w, right_q_w] = [{left_q_w:.3f}, {right_q_w:.3f}] should include "
                          f"[ub_left_q_w, lb_right_q_w] = [{ub_left_q_w:.3f}, {lb_right_q_w:.3f}].")

        assert left_q_w < ub_left_q_w
        assert lb_right_q_w < right_q_w

        if add_bitwidth_loss:
            assert quant.num_bits == bit_width