Supporting uint4 inference of pre-quantized models in HPU (AutoGPTQ#689)
* Supporting llama uint4 quantization using AutoGPTQ (AutoGPTQ#1)

* Supporting llama int4 quantization using AutoGPTQ

* Running only PT code (similar to cuda_old) on HPU

* Testing convert_from_int4

* Started cleanup

* code cleanup

* Added weight reshape in preprocessing
Added llama7b generation hpu test

* Changed reshape to match matmul (still not accurate) and fixed q4 test

* Fixing zero points

* Update pack function

* Fixed accuracy

* Uncommented exllama

* Marlin test fix + added hpu bias test

* Review comments

* Removed hpu pack until we implement it in HPU

---------

Co-authored-by: yan tomsinsky <ytomsinsky@habana.ai>

* Added assert when g_idx is not trivial (AutoGPTQ#2)

---------

Co-authored-by: yan tomsinsky <ytomsinsky@habana.ai>
HolyFalafel and Yantom1 committed Jun 24, 2024
1 parent ea829c7 commit b57bea0
Showing 6 changed files with 483 additions and 5 deletions.
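
For context, a minimal usage sketch of the inference path this commit enables. It assumes the standard AutoGPTQ `AutoGPTQForCausalLM.from_quantized` entry point and an environment with habana_frameworks installed; the checkpoint name, device argument, and generation settings are illustrative, not taken from this commit.

import torch
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

model_id = "TheBloke/Llama-2-7B-GPTQ"  # any pre-quantized 4-bit GPTQ checkpoint (hypothetical choice)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# With habana_frameworks importable, dynamically_import_QuantLinear resolves to the HPU
# QuantLinear added below, and autogptq_post_init() repacks its weights via post_init().
model = AutoGPTQForCausalLM.from_quantized(model_id, device="hpu", use_triton=False)
inputs = tokenizer("Hello, my name is", return_tensors="pt").to("hpu")
with torch.no_grad():
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))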
13 changes: 8 additions & 5 deletions auto_gptq/modeling/_utils.py
@@ -480,11 +480,14 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in
model_uses_exllamav2 = False

for _, submodule in model.named_modules():
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
model_uses_exllamav2 = True
device = submodule.qweight.device
scratch_fixed = submodule.scratch_space_fixed()
fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0))
if hasattr(submodule, "QUANT_TYPE"):
if submodule.QUANT_TYPE == "exllamav2":
model_uses_exllamav2 = True
device = submodule.qweight.device
scratch_fixed = submodule.scratch_space_fixed()
fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0))
elif submodule.QUANT_TYPE == "hpu":
submodule.post_init()

if model_uses_exllamav2:
from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors
149 changes: 149 additions & 0 deletions auto_gptq/nn_modules/qlinear/qlinear_hpu.py
@@ -0,0 +1,149 @@
import math
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import transformers
import habana_frameworks.torch.core as htcore

logger = getLogger(__name__)

def pack_tensor(input, bits = 4):
normal = input.to(torch.int32)
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q

class QuantLinear(nn.Module):
QUANT_TYPE = "hpu"

def __init__(
self,
bits,
group_size,
infeatures,
outfeatures,
bias,
use_cuda_fp16=True,
kernel_switch_threshold=128,
trainable=False,
weight_dtype=torch.float16,
):
logger.debug(f"qlinear_hpu QuantLinear::__init__ {bits=}, {group_size=}, {infeatures=}, {outfeatures=}, {bias=}, {use_cuda_fp16=}, {kernel_switch_threshold=}, {trainable=}, {weight_dtype=}")
super().__init__()
if bits != 4:
raise NotImplementedError("Only 4 bits are supported.")

self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
self.maxq = 2**self.bits - 1

self.register_buffer(
"qweight",
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
math.ceil(infeatures / self.group_size),
outfeatures // 32 * self.bits,
),
dtype=torch.int32,
),
)
self.register_buffer(
"scales",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
dtype=weight_dtype,
),
)
self.register_buffer(
"g_idx",
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
)

if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=weight_dtype))
else:
self.bias = None
self.half_indim = self.infeatures // 2

self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)

def _preprocessing(self):
self.qweight = self.qweight.cpu()
weight = self.unpack_weight_from_cuda_old_format()
new_qweight = pack_tensor(weight)
self.qweight = new_qweight.to('hpu')

# TODO: Support group indexing and remove the check
columns = self.qweight.shape[0]
g_idx_trivial = [i // self.group_size for i in range(columns)]
g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32)
assert torch.equal(self.g_idx, g_idx_trivial), "Non-trivial tensor g_idx is not supported"

zeros = self.unpack_zeros_from_cuda_old_format().cpu()
new_qzeros = pack_tensor(zeros)
self.qzeros = new_qzeros.to('hpu')

def post_init(self):
self._preprocessing()

def pack(self, linear, scales, zeros, g_idx):
#TODO: implement
raise NotImplementedError("QuantLinear HPU currently doesn't support packing")

def set_packed(self, qlinear_cls):
self.qweight = qlinear_cls.qweight
self.qzeros = qlinear_cls.qzeros
self.scales = qlinear_cls.scales
self.bias = qlinear_cls.bias

def forward(self, x):
x_dtype = x.dtype
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
scales = self.scales
qweight = self.qweight
zeros = self.qzeros
weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, x_dtype)
out = torch.matmul(x, weight)
out = out.reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out

def unpack_zeros_from_cuda_old_format(self):
zeros = torch.bitwise_right_shift(
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
self.wf.unsqueeze(0),
).to(torch.int16 if self.bits == 8 else torch.int8)

zeros = zeros + 1
zeros = torch.bitwise_and(
zeros, (2**self.bits) - 1
).to(self.scales.dtype) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
return zeros

def unpack_weight_from_cuda_old_format(self):
weight = torch.bitwise_right_shift(
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
self.wf.unsqueeze(-1),
).to(torch.int16 if self.bits == 8 else torch.int8)
weight = torch.bitwise_and(weight, (2**self.bits) - 1)
weight = weight.reshape((weight.shape[0]*weight.shape[1], weight.shape[2]))
return weight

__all__ = ["QuantLinear"]
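
A side note on the layout produced by pack_tensor above: with bits=4, eight 4-bit codes are packed into every int32 entry, with code j placed at bit offset 4*j within its column group. A standalone CPU-only illustration of that shift-and-mask scheme (not part of the commit):

import torch

vals = torch.arange(8, dtype=torch.int32).reshape(1, 8)       # eight 4-bit codes 0..7 in one row
packed = torch.zeros((1, 1), dtype=torch.int32)
for j in range(8):
    packed[:, 0] |= vals[:, j] << (4 * j)                      # same shifts as pack_tensor
# packed now holds 0x76543210; unpacking reverses the shifts and masks with 0xF,
# the same idea the unpack_*_from_cuda_old_format helpers apply along the other axis.
unpacked = (packed >> (4 * torch.arange(8, dtype=torch.int32))) & 0xF
assert torch.equal(unpacked, vals)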
7 changes: 7 additions & 0 deletions auto_gptq/utils/import_utils.py
@@ -67,6 +67,13 @@ def dynamically_import_QuantLinear(
use_marlin: bool = False,
use_tritonv2: bool = False,
):
try:
import habana_frameworks.torch.hpu # noqa: F401
except ImportError as e:
pass
else:
from ..nn_modules.qlinear.qlinear_hpu import QuantLinear
return QuantLinear
if use_qigen:
if not QIGEN_AVAILABLE:
raise ValueError(
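
The early return above means the HPU QuantLinear takes precedence over every other backend whenever habana_frameworks is importable. A hedged sketch of the resulting behavior (the leading keyword arguments are assumed from the upstream signature and may differ; check dynamically_import_QuantLinear before relying on them):

from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

# On a Gaudi machine this returns the HPU QuantLinear regardless of the remaining flags;
# elsewhere the usual CUDA / ExLlama / Triton / Marlin selection runs instead.
QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=128, bits=4)
print(QuantLinear.QUANT_TYPE)  # "hpu" when the Habana stack is present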
4 changes: 4 additions & 0 deletions auto_gptq/utils/peft_utils.py
@@ -13,6 +13,7 @@
from ..nn_modules.qlinear import GeneralQuantLinear
from ..nn_modules.qlinear.qlinear_cuda import QuantLinear as QuantLinearCuda
from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear as QuantLinearCudaOld
from ..nn_modules.qlinear.qlinear_hpu import QuantLinear as QuantLinearHpu
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as QuantLinearExllama
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear as QuantLinearExllamaV2
from ..nn_modules.qlinear.qlinear_qigen import QuantLinear as QuantLinearQigen
@@ -24,6 +25,7 @@
GeneralQuantLinear,
QuantLinearCuda,
QuantLinearCudaOld,
QuantLinearHpu,
QuantLinearExllama,
QuantLinearExllamaV2,
QuantLinearQigen,
@@ -145,6 +147,7 @@ def _create_new_module(
GeneralQuantLinear,
QuantLinearCuda,
QuantLinearCudaOld,
QuantLinearHpu,
QuantLinearExllama,
QuantLinearExllamaV2,
QuantLinearQigen,
@@ -270,6 +273,7 @@ def _create_new_module(
GeneralQuantLinear,
QuantLinearCuda,
QuantLinearCudaOld,
QuantLinearHpu,
QuantLinearExllama,
QuantLinearExllamaV2,
QuantLinearQigen,
178 changes: 178 additions & 0 deletions tests/test_hpu_linear.py
@@ -0,0 +1,178 @@
import numpy as np
import math
import torch
import pytest
import habana_frameworks.torch.core as htcore

def _convert_to_tensor_list(tensor_or_tensors):
if isinstance(tensor_or_tensors, tuple):
return list(tensor_or_tensors)
elif isinstance(tensor_or_tensors, list):
return tensor_or_tensors
elif isinstance(tensor_or_tensors, torch.Tensor):
# You can't return list(tensor_or_tensors), because it will fail on 0-d tensors
result_list = []
result_list.append(tensor_or_tensors)
return result_list
else:
raise TypeError("Can not convert outputs")

def compare_tensors(hpu_tensors, cpu_tensors, atol, rtol, assert_enable=True):
hpu_tensors = _convert_to_tensor_list(hpu_tensors)
cpu_tensors = _convert_to_tensor_list(cpu_tensors)
assert len(hpu_tensors) == len(cpu_tensors)

hpu_tensors = [tensor.to('cpu') if tensor is not None else tensor for tensor in hpu_tensors]

for i in range(len(hpu_tensors)):
if cpu_tensors[i] is None and hpu_tensors[i] is None:
continue

hpu_tensors[i] = (
hpu_tensors[i].float()
if hpu_tensors[i].dtype in [torch.bfloat16, torch.float8_e5m2, torch.float8_e4m3fn]
else hpu_tensors[i]
)
cpu_tensors[i] = (
cpu_tensors[i].float()
if cpu_tensors[i].dtype in [torch.bfloat16, torch.float8_e5m2, torch.float8_e4m3fn]
else cpu_tensors[i]
)
if assert_enable:
np.testing.assert_allclose(
hpu_tensors[i].detach().numpy(),
cpu_tensors[i].detach().numpy(),
atol=atol,
rtol=rtol,
)
else:
print("hpu_result[{}]".format(i), hpu_tensors[i].detach().numpy())
print("cpu_result[{}]".format(i), cpu_tensors[i].detach().numpy())
return np.allclose(
hpu_tensors[i].detach().numpy(),
cpu_tensors[i].detach().numpy(),
atol=atol,
rtol=rtol,
equal_nan=True,
)

# taken from AutoGPTQ/tests/test_repacking.py
def gen_quant4(k, n, groupsize=-1, bias=False):
maxq = 2 ** 4 - 1
w = torch.randn((k, n), dtype=torch.bfloat16, device="cpu")

original_w = w.clone()

if groupsize != -1:
w = w.reshape((-1, groupsize, n))
w = w.permute(1, 0, 2)
w = w.reshape((groupsize, -1))

s = torch.max(torch.abs(w), 0, keepdim=True)[0]
s *= 2 / maxq

# Quantize.
w = torch.round(w / s).int()

# Unsigned storage.
w += (maxq + 1) // 2
w = torch.clamp(w, 0, maxq)

# Dequantize.
ref = (w - (maxq + 1) // 2).bfloat16() * s

if groupsize != -1:
def reshape(w):
w = w.reshape((groupsize, -1, n))
w = w.permute(1, 0, 2)
w = w.reshape((k, n)).contiguous()
return w
ref = reshape(ref)
w = reshape(w)

s = s.reshape((-1, n)).contiguous()
linear = torch.nn.Linear(k, n, bias=bias)
linear.weight.data = ref.t()

return original_w, linear, s

@pytest.mark.parametrize("bits", [4])
@pytest.mark.parametrize("group_size", [16, 32, 128])
@pytest.mark.parametrize("infeatures", [64, 128, 512, 4096, 11008])
@pytest.mark.parametrize("outfeatures", [64, 128, 512, 4096, 11008])
@pytest.mark.parametrize("bias", [True, False], ids=["bias", "no_bias"])
@pytest.mark.parametrize("scales_value, weight_value, zeros_value", [("normal", "normal", "normal"), ("normal", "normal", "range"), ("normal", "normal", "zeros"), ("ones", "zeros", "zeros"), ("ones", "zeros", "eights"), ("ones", "range", "zeros"), ("ones", "range", "ones"), ("ones", "7", "ones"), ("ones", "zeros", "range"),("ones", "zeros", "ones"), ("ones", "range", "range"), ("range", "range", "range"), ("range", "range", "zeros")])
@pytest.mark.parametrize("weight_dtype", [torch.bfloat16, torch.float32], ids=["bf16", "fp32"])
def test_qlinear_hpu(bits, group_size, infeatures, outfeatures, bias, scales_value, weight_value, zeros_value, weight_dtype):
qweight_shape_0 = infeatures // 32 * bits
qzeros_shape_0 = math.ceil(infeatures / group_size)
qzeros_shape_1 = outfeatures // 32 * bits
if qweight_shape_0 == 0 or qzeros_shape_0 == 0 or qzeros_shape_1 == 0:
pytest.skip(f"{qweight_shape_0=} == 0 or {qzeros_shape_0=} == 0 or {qzeros_shape_1=} == 0")
if infeatures < group_size:
pytest.skip(f"{infeatures=} < {group_size=}")
if infeatures != outfeatures:
pytest.skip(f"{infeatures=} != {outfeatures=}")
trainable = False
use_cuda_fp16 = False
kernel_switch_threshold = 128
from auto_gptq.nn_modules.qlinear import qlinear_hpu, qlinear_cuda_old
quant_hpu = qlinear_hpu.QuantLinear(bits=bits, group_size=group_size, infeatures=infeatures, outfeatures=outfeatures, bias=bias, use_cuda_fp16=use_cuda_fp16, kernel_switch_threshold=kernel_switch_threshold, trainable=trainable, weight_dtype=weight_dtype).to("hpu")
# Cuda old implementation is the reference, also runs on hpu
quant_ref_cuda_old = qlinear_cuda_old.QuantLinear(bits=bits, group_size=group_size, infeatures=infeatures, outfeatures=outfeatures, bias=bias, use_cuda_fp16=use_cuda_fp16, kernel_switch_threshold=kernel_switch_threshold, trainable=trainable, weight_dtype=weight_dtype).to("hpu")
input = torch.rand((infeatures, outfeatures), dtype=weight_dtype).to("hpu")
_, linear, s = gen_quant4(infeatures, outfeatures, group_size, bias)

if scales_value == "ones":
s = torch.ones_like(s)
if scales_value == "range":
range_t = torch.tensor(list(range(1, s.numel()+1)), dtype=torch.int32)
shape_s = s.shape
s = (torch.ones(s.numel()) * range_t).reshape(shape_s).contiguous()

if weight_value == "ones":
linear.weight = torch.nn.Parameter(torch.ones_like(linear.weight))
elif weight_value == "zeros":
linear.weight = torch.nn.Parameter(torch.zeros_like(linear.weight))
elif weight_value == "range":
shape_w = linear.weight.shape
weight_local = torch.ones(shape_w, dtype=torch.int32)
range_t_weight = torch.tensor(list(range(0, 8)), dtype=torch.int32)
linear.weight = torch.nn.Parameter((torch.ones(weight_local.numel(), dtype=linear.weight.dtype).reshape(-1, 8) * range_t_weight).reshape(shape_w).contiguous())
elif weight_value.isnumeric():
linear.weight = torch.nn.Parameter(torch.full_like(linear.weight, int(weight_value)))
linear.weight = torch.nn.Parameter(linear.weight.to(weight_dtype))

if zeros_value == "zeros":
zeros = torch.full((infeatures // group_size, outfeatures), 0, dtype=torch.int32)
elif zeros_value == "range":
zeros = torch.ones((infeatures // group_size, outfeatures), dtype=torch.int32)
range_t_zeros = torch.tensor(list(range(1, 9)), dtype=torch.int32)
shape_z = zeros.shape
zeros = (torch.ones(zeros.numel(), dtype=torch.int32).reshape(-1, 8) * range_t_zeros).reshape(shape_z).contiguous()
elif zeros_value == "eights":
zeros = torch.full((infeatures // group_size, outfeatures), 8, dtype=torch.int32)
else:
zeros = torch.full((infeatures // group_size, outfeatures), 1, dtype=torch.int32)

htcore.mark_step()

quant_ref_cuda_old.pack(linear, s.clone().detach().T, zeros.clone().detach().T, g_idx=None)
htcore.mark_step()
quant_ref_cuda_old.to("hpu")

#TODO: pack independently
quant_hpu.set_packed(quant_ref_cuda_old)
htcore.mark_step()
quant_hpu.to("hpu")

out_ref_cuda_old = quant_ref_cuda_old(input)
htcore.mark_step()
quant_hpu.post_init()
htcore.mark_step()
out_hpu = quant_hpu(input)
htcore.mark_step()

out_ref_cuda_old = out_ref_cuda_old.cpu()
out_hpu = out_hpu.cpu()
compare_tensors(out_hpu.cpu(), out_ref_cuda_old.cpu(), rtol = 1e-05, atol = 1e-08)
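
For reference, both paths compared by the test boil down to a dequantize-then-matmul of the form out ≈ x @ ((w_q - zeros) * scales) + bias. A CPU-only sketch of that reference computation, with shapes matching the smallest parametrization above (this deliberately ignores the +1 zero-point offset and the int32 packing handled inside the QuantLinear layers):

import torch

k, n, group = 64, 64, 16                                      # infeatures, outfeatures, group_size
w_q = torch.randint(0, 16, (k, n), dtype=torch.int32)          # unsigned 4-bit codes
zeros = torch.full((k // group, n), 8, dtype=torch.int32)      # per-group zero points ("eights" case)
scales = torch.rand((k // group, n), dtype=torch.float32)
g_idx = torch.arange(k) // group                               # trivial group index, as asserted in _preprocessing
deq = (w_q - zeros[g_idx]).to(torch.float32) * scales[g_idx]   # dequantized weight, shape (k, n)
x = torch.rand((2, k), dtype=torch.float32)
out = x @ deq                                                  # reference output both kernels should match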