forked from AutoGPTQ/AutoGPTQ
Supporting uint4 inference of pre-quantized models in HPU (AutoGPTQ#689)
* Supporting llama uint4 quantization using AutoGPTQ (AutoGPTQ#1)
  * Supporting llama int4 quantization using AutoGPTQ
  * Running only PT code (similar to cuda_old) on HPU
  * Testing convert_from_int4
  * Started cleanup
  * Code cleanup
  * Added weight reshape in preprocessing; added llama7b generation HPU test
  * Changed reshape to match matmul (still not accurate) and fixed q4 test
  * Fixing zero points
  * Updated pack function
  * Fixed accuracy
  * Uncommented exllama
  * Marlin test fix + added HPU bias test
  * Review comments
  * Removed HPU pack until we implement it on HPU

  Co-authored-by: yan tomsinsky <ytomsinsky@habana.ai>

* Added assert when g_idx is not trivial (AutoGPTQ#2)

Co-authored-by: yan tomsinsky <ytomsinsky@habana.ai>
1 parent ea829c7, commit b57bea0
Showing 6 changed files with 483 additions and 5 deletions.
@@ -0,0 +1,149 @@
import math
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import transformers
import habana_frameworks.torch.core as htcore

logger = getLogger(__name__)


def pack_tensor(input, bits=4):
    # Pack groups of (32 // bits) unsigned `bits`-wide integers into single int32
    # words along dim 1; the value at the lowest input index lands in the lowest bits.
    normal = input.to(torch.int32)
    q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
    i = 0
    col = 0
    while col < q.shape[1]:
        for j in range(i, i + (32 // bits)):
            q[:, col] |= normal[:, j] << (bits * (j - i))
        i += 32 // bits
        col += 1
    q = q.to(torch.int32)
    return q

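# Worked example (illustration only, not part of the original commit): with bits=4,
# pack_tensor packs eight 4-bit values into each int32 word, so a row of 32 values
# produces 4 words (shape[1] // 32 * bits output columns). If the first eight values
# of a row are 1, 2, ..., 8, the first packed word has bit pattern 0x87654321
# (value 1 in the lowest nibble, 8 in the highest), which prints as a negative
# int32 because the top bit is set.
#
#   row = torch.arange(1, 33).reshape(1, 32) % 16      # one row of 32 small values
#   packed = pack_tensor(row)                          # -> shape (1, 4), dtype int32
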
class QuantLinear(nn.Module):
    QUANT_TYPE = "hpu"

    def __init__(
        self,
        bits,
        group_size,
        infeatures,
        outfeatures,
        bias,
        use_cuda_fp16=True,
        kernel_switch_threshold=128,
        trainable=False,
        weight_dtype=torch.float16,
    ):
        logger.debug(f"qlinear_hpu QuantLinear::__init__ {bits=}, {group_size=}, {infeatures=}, {outfeatures=}, {bias=}, {use_cuda_fp16=}, {kernel_switch_threshold=}, {trainable=}, {weight_dtype=}")
        super().__init__()
        if bits != 4:
            raise NotImplementedError("Only 4 bits are supported.")

        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.group_size = group_size if group_size != -1 else infeatures
        self.maxq = 2**self.bits - 1

        self.register_buffer(
            "qweight",
            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (
                    math.ceil(infeatures / self.group_size),
                    outfeatures // 32 * self.bits,
                ),
                dtype=torch.int32,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (math.ceil(infeatures / self.group_size), outfeatures),
                dtype=weight_dtype,
            ),
        )
        self.register_buffer(
            "g_idx",
            torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
        )
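        # Illustrative buffer shapes (hypothetical numbers, not from the diff): with
        # bits=4, infeatures=128, group_size=32, outfeatures=256 the buffers are
        #   qweight: (128 // 32 * 4, 256) = (16, 256) int32 (8 packed weights per word)
        #   qzeros:  (ceil(128 / 32), 256 // 32 * 4) = (4, 32) int32
        #   scales:  (4, 256) in weight_dtype, one row per quantization group
        #   g_idx:   (128,), mapping each input feature to its group
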
        if bias:
            self.register_buffer("bias", torch.zeros((outfeatures), dtype=weight_dtype))
        else:
            self.bias = None
        self.half_indim = self.infeatures // 2

        # Bit offsets (0, 4, ..., 28) of each packed value inside an int32 word,
        # used when unpacking the cuda_old-format buffers below.
        self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)

    def _preprocessing(self):
        # Repack the cuda_old-format buffers into the layout consumed by the HPU
        # dequantization op: unpack to plain integer values on CPU, re-pack with
        # pack_tensor, and move the result to the HPU device.
        self.qweight = self.qweight.cpu()
        weight = self.unpack_weight_from_cuda_old_format()
        new_qweight = pack_tensor(weight)
        self.qweight = new_qweight.to('hpu')

        # TODO: Support group indexing and remove the check
        columns = self.qweight.shape[0]
        g_idx_trivial = [i // self.group_size for i in range(columns)]
        g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32)
        assert torch.equal(self.g_idx, g_idx_trivial), "Non-trivial tensor g_idx is not supported"

        zeros = self.unpack_zeros_from_cuda_old_format().cpu()
        new_qzeros = pack_tensor(zeros)
        self.qzeros = new_qzeros.to('hpu')

    def post_init(self):
        self._preprocessing()

    def pack(self, linear, scales, zeros, g_idx):
        # TODO: implement
        raise NotImplementedError("QuantLinear HPU currently doesn't support packing")

    def set_packed(self, qlinear_cls):
        self.qweight = qlinear_cls.qweight
        self.qzeros = qlinear_cls.qzeros
        self.scales = qlinear_cls.scales
        self.bias = qlinear_cls.bias

    def forward(self, x):
        x_dtype = x.dtype
        out_shape = x.shape[:-1] + (self.outfeatures,)
        x = x.reshape(-1, x.shape[-1])
        scales = self.scales
        qweight = self.qweight
        zeros = self.qzeros
        # Dequantize the packed uint4 weights into x's dtype on the fly, then run a
        # regular dense matmul.
        weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, x_dtype)
        out = torch.matmul(x, weight)
        out = out.reshape(out_shape)
        out = out + self.bias if self.bias is not None else out
        return out

    def unpack_zeros_from_cuda_old_format(self):
        zeros = torch.bitwise_right_shift(
            torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
            self.wf.unsqueeze(0),
        ).to(torch.int16 if self.bits == 8 else torch.int8)

        zeros = zeros + 1
        zeros = torch.bitwise_and(
            zeros, (2**self.bits) - 1
        ).to(self.scales.dtype)  # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
        zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
        return zeros

    def unpack_weight_from_cuda_old_format(self):
        weight = torch.bitwise_right_shift(
            torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
            self.wf.unsqueeze(-1),
        ).to(torch.int16 if self.bits == 8 else torch.int8)
        weight = torch.bitwise_and(weight, (2**self.bits) - 1)
        weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
        return weight


__all__ = ["QuantLinear"]
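For orientation, the following is a minimal usage sketch of the new module; it is not part of the commit. It assumes an HPU device and the habana_frameworks PyTorch bridge are available. With freshly constructed all-zero buffers the arithmetic is trivial, but the call order is the one the test below exercises: optionally copy packed buffers with set_packed(), call post_init() once to repack them for the HPU kernel, then run forward().

import torch
import habana_frameworks.torch.core as htcore
from auto_gptq.nn_modules.qlinear import qlinear_hpu

layer = qlinear_hpu.QuantLinear(
    bits=4, group_size=128, infeatures=4096, outfeatures=4096,
    bias=False, weight_dtype=torch.bfloat16,
).to("hpu")
# layer.set_packed(reference_layer)  # optional: reuse buffers packed in the cuda_old layout
layer.post_init()                    # repack qweight/qzeros for torch.ops.hpu.convert_from_uint4
x = torch.rand(8, 4096, dtype=torch.bfloat16, device="hpu")
y = layer(x)                         # shape (8, 4096)
htcore.mark_step()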
@@ -0,0 +1,178 @@
import numpy as np
import math
import torch
import pytest
import habana_frameworks.torch.core as htcore


def _convert_to_tensor_list(tensor_or_tensors):
    if isinstance(tensor_or_tensors, tuple):
        return list(tensor_or_tensors)
    elif isinstance(tensor_or_tensors, list):
        return tensor_or_tensors
    elif isinstance(tensor_or_tensors, torch.Tensor):
        # You can't return list(tensor_or_tensors), because it will fail on 0-d tensors
        result_list = []
        result_list.append(tensor_or_tensors)
        return result_list
    else:
        raise TypeError("Can not convert outputs")


def compare_tensors(hpu_tensors, cpu_tensors, atol, rtol, assert_enable=True):
    hpu_tensors = _convert_to_tensor_list(hpu_tensors)
    cpu_tensors = _convert_to_tensor_list(cpu_tensors)
    assert len(hpu_tensors) == len(cpu_tensors)

    hpu_tensors = [tensor.to('cpu') if tensor is not None else tensor for tensor in hpu_tensors]

    for i in range(len(hpu_tensors)):
        if cpu_tensors[i] is None and hpu_tensors[i] is None:
            continue

        hpu_tensors[i] = (
            hpu_tensors[i].float()
            if hpu_tensors[i].dtype in [torch.bfloat16, torch.float8_e5m2, torch.float8_e4m3fn]
            else hpu_tensors[i]
        )
        cpu_tensors[i] = (
            cpu_tensors[i].float()
            if cpu_tensors[i].dtype in [torch.bfloat16, torch.float8_e5m2, torch.float8_e4m3fn]
            else cpu_tensors[i]
        )
        if assert_enable:
            np.testing.assert_allclose(
                hpu_tensors[i].detach().numpy(),
                cpu_tensors[i].detach().numpy(),
                atol=atol,
                rtol=rtol,
            )
        else:
            print("hpu_result[{}]".format(i), hpu_tensors[i].detach().numpy())
            print("cpu_result[{}]".format(i), cpu_tensors[i].detach().numpy())
            return np.allclose(
                hpu_tensors[i].detach().numpy(),
                cpu_tensors[i].detach().numpy(),
                atol=atol,
                rtol=rtol,
                equal_nan=True,
            )


# taken from AutoGPTQ/tests/test_repacking.py
def gen_quant4(k, n, groupsize=-1, bias=False):
    maxq = 2 ** 4 - 1
    w = torch.randn((k, n), dtype=torch.bfloat16, device="cpu")

    original_w = w.clone()

    if groupsize != -1:
        w = w.reshape((-1, groupsize, n))
        w = w.permute(1, 0, 2)
        w = w.reshape((groupsize, -1))

    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
    s *= 2 / maxq

    # Quantize.
    w = torch.round(w / s).int()

    # Unsigned storage.
    w += (maxq + 1) // 2
    w = torch.clamp(w, 0, maxq)

    # Dequantize.
    ref = (w - (maxq + 1) // 2).bfloat16() * s

    if groupsize != -1:
        def reshape(w):
            w = w.reshape((groupsize, -1, n))
            w = w.permute(1, 0, 2)
            w = w.reshape((k, n)).contiguous()
            return w
        ref = reshape(ref)
        w = reshape(w)

    s = s.reshape((-1, n)).contiguous()
    linear = torch.nn.Linear(k, n, bias=bias)
    linear.weight.data = ref.t()

    return original_w, linear, s

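# Worked example of gen_quant4's symmetric 4-bit scheme (numbers are illustrative,
# not from the diff): with maxq = 15 the scale is s = 2 * max(|w|) / 15, taken over
# dim 0 of the (possibly group-reshaped) weight; values are stored unsigned as
# q = clamp(round(w / s) + 8, 0, 15), and the reference weight placed into the
# returned nn.Linear is the dequantized (q - 8) * s. For instance, with s = 0.2 a
# weight of 0.45 becomes q = round(0.45 / 0.2) + 8 = 10, which dequantizes back to
# (10 - 8) * 0.2 = 0.4, i.e. the nearest point on the 0.2-spaced grid.
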
@pytest.mark.parametrize("bits", [4]) | ||
@pytest.mark.parametrize("group_size", [16, 32, 128]) | ||
@pytest.mark.parametrize("infeatures", [64, 128, 512, 4096, 11008]) | ||
@pytest.mark.parametrize("outfeatures", [64, 128, 512, 4096, 11008]) | ||
@pytest.mark.parametrize("bias", [True, False], ids=["bias", "no_bias"]) | ||
@pytest.mark.parametrize("scales_value, weight_value, zeros_value", [("normal", "normal", "normal"), ("normal", "normal", "range"), ("normal", "normal", "zeros"), ("ones", "zeros", "zeros"), ("ones", "zeros", "eights"), ("ones", "range", "zeros"), ("ones", "range", "ones"), ("ones", "7", "ones"), ("ones", "zeros", "range"),("ones", "zeros", "ones"), ("ones", "range", "range"), ("range", "range", "range"), ("range", "range", "zeros")]) | ||
@pytest.mark.parametrize("weight_dtype", [torch.bfloat16, torch.float32], ids=["bf16", "fp32"]) | ||
def test_qlinear_hpu(bits, group_size, infeatures, outfeatures, bias, scales_value, weight_value, zeros_value, weight_dtype): | ||
qweight_shape_0 = infeatures // 32 * bits | ||
qzeros_shape_0 = math.ceil(infeatures / group_size) | ||
qzeros_shape_1 = outfeatures // 32 * bits | ||
if qweight_shape_0 == 0 or qzeros_shape_0 == 0 or qzeros_shape_1 == 0: | ||
pytest.skip(f"{qweight_shape_0=} == 0 or {qzeros_shape_0=} == 0 or {qzeros_shape_1=} == 0") | ||
if infeatures < group_size: | ||
pytest.skip(f"{infeatures=} < {group_size=}") | ||
if infeatures != outfeatures: | ||
pytest.skip(f"{infeatures=} != {outfeatures=}") | ||
trainable = False | ||
use_cuda_fp16 = False | ||
kernel_switch_threshold = 128 | ||
from auto_gptq.nn_modules.qlinear import qlinear_hpu, qlinear_cuda_old | ||
quant_hpu = qlinear_hpu.QuantLinear(bits=bits, group_size=group_size, infeatures=infeatures, outfeatures=outfeatures, bias=bias, use_cuda_fp16=use_cuda_fp16, kernel_switch_threshold=kernel_switch_threshold, trainable=trainable, weight_dtype=weight_dtype).to("hpu") | ||
# Cuda old implementation is the reference, also runs on hpu | ||
quant_ref_cuda_old = qlinear_cuda_old.QuantLinear(bits=bits, group_size=group_size, infeatures=infeatures, outfeatures=outfeatures, bias=bias, use_cuda_fp16=use_cuda_fp16, kernel_switch_threshold=kernel_switch_threshold, trainable=trainable, weight_dtype=weight_dtype).to("hpu") | ||
input = torch.rand((infeatures, outfeatures), dtype=weight_dtype).to("hpu") | ||
_, linear, s = gen_quant4(infeatures, outfeatures, group_size, bias) | ||
|
||
    if scales_value == "ones":
        s = torch.ones_like(s)
    if scales_value == "range":
        range_t = torch.tensor(list(range(1, s.numel() + 1)), dtype=torch.int32)
        shape_s = s.shape
        s = (torch.ones(s.numel()) * range_t).reshape(shape_s).contiguous()

    if weight_value == "ones":
        linear.weight = torch.nn.Parameter(torch.ones_like(linear.weight))
    elif weight_value == "zeros":
        linear.weight = torch.nn.Parameter(torch.zeros_like(linear.weight))
    elif weight_value == "range":
        shape_w = linear.weight.shape
        weight_local = torch.ones(shape_w, dtype=torch.int32)
        range_t_weight = torch.tensor(list(range(0, 8)), dtype=torch.int32)
        linear.weight = torch.nn.Parameter((torch.ones(weight_local.numel(), dtype=linear.weight.dtype).reshape(-1, 8) * range_t_weight).reshape(shape_w).contiguous())
    elif weight_value.isnumeric():
        linear.weight = torch.nn.Parameter(torch.full_like(linear.weight, int(weight_value)))
    linear.weight = torch.nn.Parameter(linear.weight.to(weight_dtype))

    if zeros_value == "zeros":
        zeros = torch.full((infeatures // group_size, outfeatures), 0, dtype=torch.int32)
    elif zeros_value == "range":
        zeros = torch.ones((infeatures // group_size, outfeatures), dtype=torch.int32)
        range_t_zeros = torch.tensor(list(range(1, 9)), dtype=torch.int32)
        shape_z = zeros.shape
        zeros = (torch.ones(zeros.numel(), dtype=torch.int32).reshape(-1, 8) * range_t_zeros).reshape(shape_z).contiguous()
    elif zeros_value == "eights":
        zeros = torch.full((infeatures // group_size, outfeatures), 8, dtype=torch.int32)
    else:
        zeros = torch.full((infeatures // group_size, outfeatures), 1, dtype=torch.int32)

    htcore.mark_step()

    quant_ref_cuda_old.pack(linear, s.clone().detach().T, zeros.clone().detach().T, g_idx=None)
    htcore.mark_step()
    quant_ref_cuda_old.to("hpu")

    # TODO: pack independently
    quant_hpu.set_packed(quant_ref_cuda_old)
    htcore.mark_step()
    quant_hpu.to("hpu")

    out_ref_cuda_old = quant_ref_cuda_old(input)
    htcore.mark_step()
    quant_hpu.post_init()
    htcore.mark_step()
    out_hpu = quant_hpu(input)
    htcore.mark_step()

    out_ref_cuda_old = out_ref_cuda_old.cpu()
    out_hpu = out_hpu.cpu()
    compare_tensors(out_hpu, out_ref_cuda_old, rtol=1e-05, atol=1e-08)
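The full parametrization needs HPU hardware; a single configuration can be picked out with pytest's -k filter. The test file's location is not shown in the diff, so the path below is an assumption.

import pytest

# Hypothetical path; point this at wherever the test file lives in the repository.
pytest.main(["-q", "tests/test_qlinear_hpu.py", "-k", "bf16 and bias"])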