Implement FracBitsQuantizers (openvinotoolkit#1231)
* Implement FracBitsQuantizers

 - Implement FracBitsSymmetricQuantizer and FracBitsAsymmetricQuantizer
 - Implement relevant unit test

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>

* Remove self._constructed flag

 - Fix enable_gradients() and disable_gradients() to use the base class functions
 - Fix num_bits setter

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
vinnamkim committed Jul 22, 2022
1 parent 7827f19 commit 53df688
Showing 4 changed files with 361 additions and 0 deletions.
12 changes: 12 additions & 0 deletions nncf/experimental/torch/fracbits/__init__.py
@@ -0,0 +1,12 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
216 changes: 216 additions & 0 deletions nncf/experimental/torch/fracbits/quantizer.py
@@ -0,0 +1,216 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Reference: Yang, Linjie, and Qing Jin. "Fracbits: Mixed precision quantization via fractional bit-widths."
# Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 35. No. 12. 2021.
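#
# FracBits relaxes each quantizer's bit-width to a continuous, learnable parameter b. While b is
# trainable, the quantizer output is a linear interpolation between quantization with floor(b) and
# ceil(b) bits:
#   q(x, b) = (b - floor(b)) * q(x, ceil(b)) + (ceil(b) - b) * q(x, floor(b)),
# so gradients can flow into b. For export, b is frozen and rounded to an integer bit-width.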

from typing import Dict
import torch

from nncf.experimental.torch.fracbits.structs import FracBitsQuantizationMode

from nncf.torch.layer_utils import COMPRESSION_MODULES, CompressionParameter
from nncf.torch.quantization.layers import (
QUANTIZATION_MODULES, AsymmetricQuantizer, PTQuantizerSpec, SymmetricQuantizer)
from nncf.torch.quantization.quantize_functions import asymmetric_quantize, symmetric_quantize
from nncf.torch.utils import no_jit_trace


@COMPRESSION_MODULES.register()
@QUANTIZATION_MODULES.register(FracBitsQuantizationMode.SYMMETRIC)
class FracBitsSymmetricQuantizer(SymmetricQuantizer):
def __init__(self, qspec: PTQuantizerSpec):
super().__init__(qspec)
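# Bound the learnable bit-width to roughly [0.5, 1.5] times the bit-width requested in the spec.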
self._min_num_bits = int(0.5 * qspec.num_bits)
self._max_num_bits = int(1.5 * qspec.num_bits)
self._num_bits = CompressionParameter(torch.FloatTensor([qspec.num_bits]), requires_grad=True,
compression_lr_multiplier=qspec.compression_lr_multiplier)

@property
def frac_num_bits(self):
return torch.clamp(self._num_bits, self._min_num_bits, self._max_num_bits)

@property
def num_bits(self):
if self._num_bits.dtype == torch.int32:
return super().num_bits

with no_jit_trace():
return self.frac_num_bits.round().int().item()

@num_bits.setter
def num_bits(self, num_bits: int):
if num_bits < self._min_num_bits or num_bits > self._max_num_bits:
raise RuntimeError(
f"{num_bits} should be in [{self._min_num_bits}, {self._max_num_bits}]")
self._num_bits.fill_(num_bits)

@property
def is_num_bits_frozen(self) -> bool:
return not self._num_bits.requires_grad

def set_min_max_num_bits(self, min_num_bits: int, max_num_bits: int):
if min_num_bits >= max_num_bits:
raise ValueError(
f"min_num_bits({min_num_bits}) >= max_num_bits({max_num_bits})")
self._min_num_bits = min_num_bits
self._max_num_bits = max_num_bits

def unfreeze_num_bits(self) -> None:
self._num_bits.requires_grad_(True)

def freeze_num_bits(self) -> None:
self._num_bits.requires_grad_(False)
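# Recompute the integer level ranges for the now-fixed, rounded bit-width.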
super().set_level_ranges()

def enable_gradients(self):
super().enable_gradients()
self.unfreeze_num_bits()

def disable_gradients(self):
super().disable_gradients()
self.freeze_num_bits()

def _quantize_with_n_bits(self, x, num_bits, execute_traced_op_as_identity: bool = False):
scaled_num_bits = 1 if self._half_range else 0

level_low, level_high, levels = self.calculate_level_ranges(
num_bits - scaled_num_bits, self.signed)

return symmetric_quantize(x, levels, level_low, level_high, self.scale, self.eps,
skip=execute_traced_op_as_identity)

def quantize(self, x, execute_traced_op_as_identity: bool = False):
if self.is_num_bits_frozen:
return super().quantize(x, execute_traced_op_as_identity)

fl_num_bits = self.frac_num_bits.floor().int().item()
ce_num_bits = fl_num_bits + 1
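# Quantize with the two neighbouring integer bit-widths and blend the outputs by the
# fractional part of frac_num_bits, so the result stays differentiable w.r.t. the bit-width.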

fl_q = self._quantize_with_n_bits(
x, fl_num_bits, execute_traced_op_as_identity)
ce_q = self._quantize_with_n_bits(
x, ce_num_bits, execute_traced_op_as_identity)

return (self.frac_num_bits - fl_num_bits) * ce_q + (ce_num_bits - self.frac_num_bits) * fl_q

def get_trainable_params(self) -> Dict[str, torch.Tensor]:
return {self.SCALE_PARAM_NAME: self.scale.detach(), "num_bits": self.frac_num_bits.detach()}

def _prepare_export_quantization(self, x: torch.Tensor):
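# Export requires an integer bit-width, so freeze (and thereby round) the learnable
# bit-width before delegating to the base class export path.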
self.freeze_num_bits()
return super()._prepare_export_quantization(x)

@torch.no_grad()
def get_input_range(self):
self.set_level_ranges()
input_low, input_high = self._get_input_low_input_high(
self.scale, self.level_low, self.level_high, self.eps)
return input_low, input_high


@COMPRESSION_MODULES.register()
@QUANTIZATION_MODULES.register(FracBitsQuantizationMode.ASYMMETRIC)
class FracBitsAsymmetricQuantizer(AsymmetricQuantizer):
def __init__(self, qspec: PTQuantizerSpec):
super().__init__(qspec)
self._min_num_bits = int(0.5 * qspec.num_bits)
self._max_num_bits = int(1.5 * qspec.num_bits)
self._num_bits = CompressionParameter(torch.FloatTensor([qspec.num_bits]), requires_grad=True,
compression_lr_multiplier=qspec.compression_lr_multiplier)

@property
def frac_num_bits(self):
return torch.clamp(self._num_bits, self._min_num_bits, self._max_num_bits)

@property
def num_bits(self):
if self._num_bits.dtype == torch.int32:
return super().num_bits

with no_jit_trace():
return self.frac_num_bits.round().int().item()

@num_bits.setter
def num_bits(self, num_bits: int):
if num_bits < self._min_num_bits or num_bits > self._max_num_bits:
raise RuntimeError(
f"{num_bits} should be in [{self._min_num_bits}, {self._max_num_bits}]")
self._num_bits.fill_(num_bits)

@property
def is_num_bits_frozen(self) -> bool:
return not self._num_bits.requires_grad

def set_min_max_num_bits(self, min_num_bits: int, max_num_bits: int):
if min_num_bits >= max_num_bits:
raise ValueError(
f"min_num_bits({min_num_bits}) >= max_num_bits({max_num_bits})")
self._min_num_bits = min_num_bits
self._max_num_bits = max_num_bits

def unfreeze_num_bits(self) -> None:
self._num_bits.requires_grad_(True)

def freeze_num_bits(self) -> None:
self._num_bits.requires_grad_(False)
super().set_level_ranges()

def enable_gradients(self):
super().enable_gradients()
self.unfreeze_num_bits()

def disable_gradients(self):
super().disable_gradients()
self.freeze_num_bits()

def _quantize_with_n_bits(self, x, num_bits, execute_traced_op_as_identity: bool = False):
scaled_num_bits = 1 if self._half_range else 0

level_low, level_high, levels = self.calculate_level_ranges(
num_bits - scaled_num_bits)

return asymmetric_quantize(x, levels, level_low, level_high, self.input_low, self.input_range, self.eps,
skip=execute_traced_op_as_identity)

def quantize(self, x, execute_traced_op_as_identity: bool = False):
if self.is_num_bits_frozen:
return super().quantize(x, execute_traced_op_as_identity)

fl_num_bits = self.frac_num_bits.floor().int().item()
ce_num_bits = fl_num_bits + 1

fl_q = self._quantize_with_n_bits(
x, fl_num_bits, execute_traced_op_as_identity)
ce_q = self._quantize_with_n_bits(
x, ce_num_bits, execute_traced_op_as_identity)

return (self.frac_num_bits - fl_num_bits) * ce_q + (ce_num_bits - self.frac_num_bits) * fl_q

def get_trainable_params(self) -> Dict[str, torch.Tensor]:
return {self.INPUT_LOW_PARAM_NAME: self.input_low.detach(),
self.INPUT_RANGE_PARAM_NAME: self.input_range.detach(),
"num_bits": self._num_bits.detach()}

def _prepare_export_quantization(self, x: torch.Tensor):
self.freeze_num_bits()
return super()._prepare_export_quantization(x)

@torch.no_grad()
def get_input_range(self):
self.set_level_ranges()
input_low, input_high = self._get_input_low_input_high(self.input_range,
self.input_low,
self.levels,
self.eps)
return input_low, input_high
16 changes: 16 additions & 0 deletions nncf/experimental/torch/fracbits/structs.py
@@ -0,0 +1,16 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

class FracBitsQuantizationMode:
SYMMETRIC = 'fracbits_symmetric'
ASYMMETRIC = 'fracbits_asymmetric'
117 changes: 117 additions & 0 deletions tests/torch/experimental/fracbits/test_quantizer.py
@@ -0,0 +1,117 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import pytest
import torch
from torch import nn

from nncf.common.utils.logger import logger as nncf_logger
from nncf.experimental.torch.fracbits.quantizer import FracBitsAsymmetricQuantizer, FracBitsSymmetricQuantizer
from nncf.experimental.torch.fracbits.structs import FracBitsQuantizationMode
from nncf.torch.quantization.layers import PTQuantizerSpec

#pylint: disable=redefined-outer-name


def set_manual_seed():
torch.manual_seed(3003)


@pytest.fixture(scope="function")
def linear_problem(num_bits: int = 4, sigma: float = 0.2):
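# Noisy linear-regression data: w is exactly representable with `num_bits` bits and lies in
# [-0.5, 0.5), and the targets are y = w @ x.T plus Gaussian noise with standard deviation `sigma`.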
set_manual_seed()

levels = 2 ** num_bits
w = 1 / levels * (torch.randint(0, levels, size=[100, 10]) - levels // 2)
x = torch.randn([1000, 10])
y = w.mm(x.t())
y += sigma * torch.randn_like(y)

return w, x, y, num_bits, sigma


@pytest.fixture()
def qspec(request):
return PTQuantizerSpec(num_bits=8,
mode=request.param,
signedness_to_force=None,
scale_shape=(1, 1),
narrow_range=False,
half_range=False,
logarithm_scale=False)


@pytest.mark.parametrize("add_bitwidth_loss", [True, False])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("qspec",
[FracBitsQuantizationMode.ASYMMETRIC, FracBitsQuantizationMode.SYMMETRIC], indirect=["qspec"])
def test_quantization(linear_problem, qspec, device, add_bitwidth_loss):
"""
Test quantization on a simple linear problem.
The weight is filled with random integers in the range [-2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1],
then scaled by 1 / 2 ** bit_width, so its values end up in [-0.5, 0.5].
We initialize the quantizer's input_low and input_high to a range narrower than [-0.5, 0.5]
by multiplying both limits by 0.1.
An SGD optimizer then learns the quantizer parameters with an MSE loss for the linear model.
Check that input_low and input_high expand towards [-0.5, 0.5] to compensate for quantization errors
and that the MSE loss is minimized. If the target bit-width loss is added,
also check that the quantizer's learnable bit_width converges to the target bit_width.
"""
w, x, y, bit_width, sigma = linear_problem

w, x, y = w.to(device=device), x.to(device=device), y.to(device=device)

quant = FracBitsAsymmetricQuantizer(
qspec) if qspec.mode == FracBitsQuantizationMode.ASYMMETRIC else FracBitsSymmetricQuantizer(qspec)

init_input_low = torch.FloatTensor([w.min() * 0.1])
init_input_high = torch.FloatTensor([w.max() * 0.1])

quant.apply_minmax_init(init_input_low, init_input_high)
quant = quant.to(w.device)
criteria = nn.MSELoss()

optim = torch.optim.SGD(quant.parameters(), lr=1e-1)

for _ in range(100):
optim.zero_grad()
loss = criteria(y, quant(w).mm(x.t()))

if add_bitwidth_loss:
loss += criteria(bit_width *
torch.ones_like(quant.frac_num_bits), quant.frac_num_bits)

loss.backward()
optim.step()

eps = 0.05
ub_mse_loss = 1.1 * (sigma ** 2)
ub_left_q_w = w.min() + eps
lb_right_q_w = w.max() - eps

with torch.no_grad():
loss = criteria(y, quant(w).mm(x.t())).item()
nncf_logger.debug(
f"loss={loss:.3f} should be lower than ub_mse_loss={ub_mse_loss:.3f}.")
assert loss < ub_mse_loss

left_q_w, right_q_w = quant.get_input_range()
left_q_w, right_q_w = left_q_w.item(), right_q_w.item()

nncf_logger.debug(f"[ub_left_q_w, lb_right_q_w] = [{ub_left_q_w:.3f}, {lb_right_q_w:.3f}] should be included in "
f"[left_q_w, right_q_w] = [{left_q_w:.3f}, {right_q_w:.3f}].")

assert left_q_w < ub_left_q_w
assert lb_right_q_w < right_q_w

if add_bitwidth_loss:
assert quant.num_bits == bit_width
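
For reference, a minimal usage sketch of the new symmetric quantizer (not part of this commit). It reuses the PTQuantizerSpec arguments from the qspec fixture above and only calls the API exercised by the test; the toy weight tensor and the 0.1 bit-width penalty coefficient are illustrative assumptions.

import torch

from nncf.experimental.torch.fracbits.quantizer import FracBitsSymmetricQuantizer
from nncf.experimental.torch.fracbits.structs import FracBitsQuantizationMode
from nncf.torch.quantization.layers import PTQuantizerSpec

# Same spec arguments as the qspec fixture above.
qspec = PTQuantizerSpec(num_bits=8, mode=FracBitsQuantizationMode.SYMMETRIC, signedness_to_force=None,
                        scale_shape=(1, 1), narrow_range=False, half_range=False, logarithm_scale=False)
quant = FracBitsSymmetricQuantizer(qspec)

w = torch.randn(100, 10)  # toy weight tensor, for illustration only
quant.apply_minmax_init(torch.FloatTensor([w.min()]), torch.FloatTensor([w.max()]))

# While the bit-width is trainable, quantize() blends the floor/ceil bit-width outputs, so both the
# scale and frac_num_bits receive gradients. The 0.1 penalty coefficient is an arbitrary example.
loss = torch.nn.functional.mse_loss(quant(w), w) + 0.1 * quant.frac_num_bits.sum()
loss.backward()

quant.freeze_num_bits()  # round and fix the bit-width, as done before export
print(quant.num_bits, quant.get_input_range())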
