def pseudo_dequantize_tensor(self, w: nn.Linear, scales: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor:
    """Reconstruct a float weight matrix from group-quantized values.

    Assumes ``scales`` and ``zeros`` have shape
    ``(out_features, in_features // group_size)`` with one entry per group of
    ``group_size`` *consecutive* input columns — the layout produced by
    group-wise quantization (NOTE(review): confirm against
    ``pseudo_quantize_tensor``, whose grouping code is outside this hunk).

    Args:
        w: linear layer whose ``weight.data`` holds the quantized values.
        scales: per-group quantization scales.
        zeros: per-group zero points.

    Returns:
        Dequantized weight tensor with the same shape as ``w.weight.data``.
    """
    # Number of consecutive weight columns covered by each group entry.
    repeat_count = w.weight.data.shape[-1] // zeros.shape[-1]

    # BUGFIX: the original used ``zeros.repeat(1, repeat_count).reshape(...)``,
    # which *tiles* the group values along the last dim ([z0, z1, z0, z1, ...])
    # instead of expanding each group over its consecutive columns
    # ([z0, z0, z1, z1, ...]); the trailing reshape is a no-op since the shape
    # already matches, so it cannot repair the ordering.  ``repeat_interleave``
    # yields the correct per-column alignment (identical to the old behavior
    # in the per-channel case where zeros.shape[-1] == 1).
    zeros = zeros.repeat_interleave(repeat_count, dim=-1)
    scales = scales.repeat_interleave(repeat_count, dim=-1)

    # Dequantize: x = (q - zero_point) * scale
    w = (w.weight.data - zeros) * scales

    return w