diff --git a/auto_gptq/modeling/_utils.py b/auto_gptq/modeling/_utils.py index 2318dd6b..e3d47113 100644 --- a/auto_gptq/modeling/_utils.py +++ b/auto_gptq/modeling/_utils.py @@ -480,11 +480,14 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in model_uses_exllamav2 = False for _, submodule in model.named_modules(): - if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2": - model_uses_exllamav2 = True - device = submodule.qweight.device - scratch_fixed = submodule.scratch_space_fixed() - fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0)) + if hasattr(submodule, "QUANT_TYPE"): + if submodule.QUANT_TYPE == "exllamav2": + model_uses_exllamav2 = True + device = submodule.qweight.device + scratch_fixed = submodule.scratch_space_fixed() + fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0)) + elif submodule.QUANT_TYPE == "hpu": + submodule.post_init() if model_uses_exllamav2: from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors diff --git a/auto_gptq/nn_modules/qlinear/qlinear_hpu.py b/auto_gptq/nn_modules/qlinear/qlinear_hpu.py new file mode 100644 index 00000000..0c998da2 --- /dev/null +++ b/auto_gptq/nn_modules/qlinear/qlinear_hpu.py @@ -0,0 +1,149 @@ +import math +from logging import getLogger + +import numpy as np +import torch +import torch.nn as nn +import transformers +import habana_frameworks.torch.core as htcore + +logger = getLogger(__name__) + +def pack_tensor(input, bits = 4): + normal = input.to(torch.int32) + q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32) + i = 0 + col = 0 + while col < q.shape[1]: + for j in range(i, i + (32 // bits)): + q[:, col] |= normal[:, j] << (bits * (j - i)) + i += 32 // bits + col += 1 + q = q.to(torch.int32) + return q + +class QuantLinear(nn.Module): + QUANT_TYPE = "hpu" + + def __init__( + self, + bits, + group_size, + infeatures, + outfeatures, + bias, + use_cuda_fp16=True, + kernel_switch_threshold=128, + trainable=False, + weight_dtype=torch.float16, + ): + logger.debug(f"qlinear_hpu QuantLinear::__init__ {bits=}, {group_size=}, {infeatures=}, {outfeatures=}, {bias=}, {use_cuda_fp16=}, {kernel_switch_threshold=}, {trainable=}, {weight_dtype=}") + super().__init__() + if bits != 4: + raise NotImplementedError("Only 4 bits are supported.") + + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.group_size = group_size if group_size != -1 else infeatures + self.maxq = 2**self.bits - 1 + + self.register_buffer( + "qweight", + torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), + ) + self.register_buffer( + "qzeros", + torch.zeros( + ( + math.ceil(infeatures / self.group_size), + outfeatures // 32 * self.bits, + ), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + (math.ceil(infeatures / self.group_size), outfeatures), + dtype=weight_dtype, + ), + ) + self.register_buffer( + "g_idx", + torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32), + ) + + if bias: + self.register_buffer("bias", torch.zeros((outfeatures), dtype=weight_dtype)) + else: + self.bias = None + self.half_indim = self.infeatures // 2 + + self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0) + + def _preprocessing(self): + self.qweight = self.qweight.cpu() + weight = self.unpack_weight_from_cuda_old_format() + new_qweight = pack_tensor(weight) + self.qweight = new_qweight.to('hpu') + + # 
TODO: Support group indexing and remove the check + columns = self.qweight.shape[0] + g_idx_trivial = [i // self.group_size for i in range(columns)] + g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32) + assert torch.equal(self.g_idx, g_idx_trivial), "Non-trivial tensor g_idx is not supported" + + zeros = self.unpack_zeros_from_cuda_old_format().cpu() + new_qzeros = pack_tensor(zeros) + self.qzeros = new_qzeros.to('hpu') + + def post_init(self): + self._preprocessing() + + def pack(self, linear, scales, zeros, g_idx): + #TODO: implement + raise NotImplementedError("QuantLinear HPU currently doesn't support packing") + + def set_packed(self, qlinear_cls): + self.qweight = qlinear_cls.qweight + self.qzeros = qlinear_cls.qzeros + self.scales = qlinear_cls.scales + self.bias = qlinear_cls.bias + + def forward(self, x): + x_dtype = x.dtype + out_shape = x.shape[:-1] + (self.outfeatures,) + x = x.reshape(-1, x.shape[-1]) + scales = self.scales + qweight = self.qweight + zeros = self.qzeros + weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, x_dtype) + out = torch.matmul(x, weight) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out + + def unpack_zeros_from_cuda_old_format(self): + zeros = torch.bitwise_right_shift( + torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), + self.wf.unsqueeze(0), + ).to(torch.int16 if self.bits == 8 else torch.int8) + + zeros = zeros + 1 + zeros = torch.bitwise_and( + zeros, (2**self.bits) - 1 + ).to(self.scales.dtype) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important. + zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2]) + return zeros + + def unpack_weight_from_cuda_old_format(self): + weight = torch.bitwise_right_shift( + torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), + self.wf.unsqueeze(-1), + ).to(torch.int16 if self.bits == 8 else torch.int8) + weight = torch.bitwise_and(weight, (2**self.bits) - 1) + weight = weight.reshape((weight.shape[0]*weight.shape[1], weight.shape[2])) + return weight + +__all__ = ["QuantLinear"] diff --git a/auto_gptq/utils/import_utils.py b/auto_gptq/utils/import_utils.py index fe9aabf7..bc27a994 100644 --- a/auto_gptq/utils/import_utils.py +++ b/auto_gptq/utils/import_utils.py @@ -67,6 +67,13 @@ def dynamically_import_QuantLinear( use_marlin: bool = False, use_tritonv2: bool = False, ): + try: + import habana_frameworks.torch.hpu # noqa: F401 + except ImportError as e: + pass + else: + from ..nn_modules.qlinear.qlinear_hpu import QuantLinear + return QuantLinear if use_qigen: if not QIGEN_AVAILABLE: raise ValueError( diff --git a/auto_gptq/utils/peft_utils.py b/auto_gptq/utils/peft_utils.py index 6030ba74..6befb17f 100644 --- a/auto_gptq/utils/peft_utils.py +++ b/auto_gptq/utils/peft_utils.py @@ -13,6 +13,7 @@ from ..nn_modules.qlinear import GeneralQuantLinear from ..nn_modules.qlinear.qlinear_cuda import QuantLinear as QuantLinearCuda from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear as QuantLinearCudaOld +from ..nn_modules.qlinear.qlinear_hpu import QuantLinear as QuantLinearHpu from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as QuantLinearExllama from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as QuantLinearExllamaV2 from ..nn_modules.qlinear.qlinear_qigen import QuantLinear as QuantLinearQigen @@ -24,6 +25,7 @@ GeneralQuantLinear, QuantLinearCuda, QuantLinearCudaOld, + QuantLinearHpu, QuantLinearExllama, QuantLinearExllamaV2, QuantLinearQigen, @@ 
-145,6 +147,7 @@ def _create_new_module( GeneralQuantLinear, QuantLinearCuda, QuantLinearCudaOld, + QuantLinearHpu, QuantLinearExllama, QuantLinearExllamaV2, QuantLinearQigen, @@ -270,6 +273,7 @@ def _create_new_module( GeneralQuantLinear, QuantLinearCuda, QuantLinearCudaOld, + QuantLinearHpu, QuantLinearExllama, QuantLinearExllamaV2, QuantLinearQigen, diff --git a/tests/test_hpu_linear.py b/tests/test_hpu_linear.py new file mode 100644 index 00000000..0f141d01 --- /dev/null +++ b/tests/test_hpu_linear.py @@ -0,0 +1,178 @@ +import numpy as np +import math +import torch +import pytest +import habana_frameworks.torch.core as htcore + +def _convert_to_tensor_list(tensor_or_tensors): + if isinstance(tensor_or_tensors, tuple): + return list(tensor_or_tensors) + elif isinstance(tensor_or_tensors, list): + return tensor_or_tensors + elif isinstance(tensor_or_tensors, torch.Tensor): + # You can't return list(tensor_or_tensors), because it will fail on 0-d tensors + result_list = [] + result_list.append(tensor_or_tensors) + return result_list + else: + raise TypeError("Can not convert outputs") + +def compare_tensors(hpu_tensors, cpu_tensors, atol, rtol, assert_enable=True): + hpu_tensors = _convert_to_tensor_list(hpu_tensors) + cpu_tensors = _convert_to_tensor_list(cpu_tensors) + assert len(hpu_tensors) == len(cpu_tensors) + + hpu_tensors = [tensor.to('cpu') if tensor is not None else tensor for tensor in hpu_tensors] + + for i in range(len(hpu_tensors)): + if cpu_tensors[i] is None and hpu_tensors[i] is None: + continue + + hpu_tensors[i] = ( + hpu_tensors[i].float() + if hpu_tensors[i].dtype in [torch.bfloat16, torch.float8_e5m2, torch.float8_e4m3fn] + else hpu_tensors[i] + ) + cpu_tensors[i] = ( + cpu_tensors[i].float() + if cpu_tensors[i].dtype in [torch.bfloat16, torch.float8_e5m2, torch.float8_e4m3fn] + else cpu_tensors[i] + ) + if assert_enable: + np.testing.assert_allclose( + hpu_tensors[i].detach().numpy(), + cpu_tensors[i].detach().numpy(), + atol=atol, + rtol=rtol, + ) + else: + print("hpu_result[{}]".format(i), hpu_tensors[i].detach().numpy()) + print("cpu_result[{}]".format(i), cpu_tensors[i].detach().numpy()) + return np.allclose( + hpu_tensors[i].detach().numpy(), + cpu_tensors[i].detach().numpy(), + atol=atol, + rtol=rtol, + equal_nan=True, + ) + +# taken from AutoGPTQ/tests/test_repacking.py +def gen_quant4(k, n, groupsize=-1, bias=False): + maxq = 2 ** 4 - 1 + w = torch.randn((k, n), dtype=torch.bfloat16, device="cpu") + + original_w = w.clone() + + if groupsize != -1: + w = w.reshape((-1, groupsize, n)) + w = w.permute(1, 0, 2) + w = w.reshape((groupsize, -1)) + + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / maxq + + # Quantize. + w = torch.round(w / s).int() + + # Unsigned storage. + w += (maxq + 1) // 2 + w = torch.clamp(w, 0, maxq) + + # Dequantize. 
+ ref = (w - (maxq + 1) // 2).bfloat16() * s + + if groupsize != -1: + def reshape(w): + w = w.reshape((groupsize, -1, n)) + w = w.permute(1, 0, 2) + w = w.reshape((k, n)).contiguous() + return w + ref = reshape(ref) + w = reshape(w) + + s = s.reshape((-1, n)).contiguous() + linear = torch.nn.Linear(k, n, bias=bias) + linear.weight.data = ref.t() + + return original_w, linear, s + +@pytest.mark.parametrize("bits", [4]) +@pytest.mark.parametrize("group_size", [16, 32, 128]) +@pytest.mark.parametrize("infeatures", [64, 128, 512, 4096, 11008]) +@pytest.mark.parametrize("outfeatures", [64, 128, 512, 4096, 11008]) +@pytest.mark.parametrize("bias", [True, False], ids=["bias", "no_bias"]) +@pytest.mark.parametrize("scales_value, weight_value, zeros_value", [("normal", "normal", "normal"), ("normal", "normal", "range"), ("normal", "normal", "zeros"), ("ones", "zeros", "zeros"), ("ones", "zeros", "eights"), ("ones", "range", "zeros"), ("ones", "range", "ones"), ("ones", "7", "ones"), ("ones", "zeros", "range"),("ones", "zeros", "ones"), ("ones", "range", "range"), ("range", "range", "range"), ("range", "range", "zeros")]) +@pytest.mark.parametrize("weight_dtype", [torch.bfloat16, torch.float32], ids=["bf16", "fp32"]) +def test_qlinear_hpu(bits, group_size, infeatures, outfeatures, bias, scales_value, weight_value, zeros_value, weight_dtype): + qweight_shape_0 = infeatures // 32 * bits + qzeros_shape_0 = math.ceil(infeatures / group_size) + qzeros_shape_1 = outfeatures // 32 * bits + if qweight_shape_0 == 0 or qzeros_shape_0 == 0 or qzeros_shape_1 == 0: + pytest.skip(f"{qweight_shape_0=} == 0 or {qzeros_shape_0=} == 0 or {qzeros_shape_1=} == 0") + if infeatures < group_size: + pytest.skip(f"{infeatures=} < {group_size=}") + if infeatures != outfeatures: + pytest.skip(f"{infeatures=} != {outfeatures=}") + trainable = False + use_cuda_fp16 = False + kernel_switch_threshold = 128 + from auto_gptq.nn_modules.qlinear import qlinear_hpu, qlinear_cuda_old + quant_hpu = qlinear_hpu.QuantLinear(bits=bits, group_size=group_size, infeatures=infeatures, outfeatures=outfeatures, bias=bias, use_cuda_fp16=use_cuda_fp16, kernel_switch_threshold=kernel_switch_threshold, trainable=trainable, weight_dtype=weight_dtype).to("hpu") + # Cuda old implementation is the reference, also runs on hpu + quant_ref_cuda_old = qlinear_cuda_old.QuantLinear(bits=bits, group_size=group_size, infeatures=infeatures, outfeatures=outfeatures, bias=bias, use_cuda_fp16=use_cuda_fp16, kernel_switch_threshold=kernel_switch_threshold, trainable=trainable, weight_dtype=weight_dtype).to("hpu") + input = torch.rand((infeatures, outfeatures), dtype=weight_dtype).to("hpu") + _, linear, s = gen_quant4(infeatures, outfeatures, group_size, bias) + + if scales_value == "ones": + s = torch.ones_like(s) + if scales_value == "range": + range_t = torch.tensor(list(range(1, s.numel()+1)), dtype=torch.int32) + shape_s = s.shape + s = (torch.ones(s.numel()) * range_t).reshape(shape_s).contiguous() + + if weight_value == "ones": + linear.weight = torch.nn.Parameter(torch.ones_like(linear.weight)) + elif weight_value == "zeros": + linear.weight = torch.nn.Parameter(torch.zeros_like(linear.weight)) + elif weight_value == "range": + shape_w = linear.weight.shape + weight_local = torch.ones(shape_w, dtype=torch.int32) + range_t_weight = torch.tensor(list(range(0, 8)), dtype=torch.int32) + linear.weight = torch.nn.Parameter((torch.ones(weight_local.numel(), dtype=linear.weight.dtype).reshape(-1, 8) * range_t_weight).reshape(shape_w).contiguous()) + elif 
weight_value.isnumeric(): + linear.weight = torch.nn.Parameter(torch.full_like(linear.weight, int(weight_value))) + linear.weight = torch.nn.Parameter(linear.weight.to(weight_dtype)) + + if zeros_value == "zeros": + zeros = torch.full((infeatures // group_size, outfeatures), 0, dtype=torch.int32) + elif zeros_value == "range": + zeros = torch.ones((infeatures // group_size, outfeatures), dtype=torch.int32) + range_t_zeros = torch.tensor(list(range(1, 9)), dtype=torch.int32) + shape_z = zeros.shape + zeros = (torch.ones(zeros.numel(), dtype=torch.int32).reshape(-1, 8) * range_t_zeros).reshape(shape_z).contiguous() + elif zeros_value == "eights": + zeros = torch.full((infeatures // group_size, outfeatures), 8, dtype=torch.int32) + else: + zeros = torch.full((infeatures // group_size, outfeatures), 1, dtype=torch.int32) + + htcore.mark_step() + + quant_ref_cuda_old.pack(linear, s.clone().detach().T, zeros.clone().detach().T, g_idx=None) + htcore.mark_step() + quant_ref_cuda_old.to("hpu") + + #TODO: pack independently + quant_hpu.set_packed(quant_ref_cuda_old) + htcore.mark_step() + quant_hpu.to("hpu") + + out_ref_cuda_old = quant_ref_cuda_old(input) + htcore.mark_step() + quant_hpu.post_init() + htcore.mark_step() + out_hpu = quant_hpu(input) + htcore.mark_step() + + out_ref_cuda_old = out_ref_cuda_old.cpu() + out_hpu = out_hpu.cpu() + compare_tensors(out_hpu.cpu(), out_ref_cuda_old.cpu(), rtol = 1e-05, atol = 1e-08) diff --git a/tests/test_q4.py b/tests/test_q4.py index f96e9bd8..a0383a78 100644 --- a/tests/test_q4.py +++ b/tests/test_q4.py @@ -7,6 +7,7 @@ from auto_gptq.nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear as TritonV2QuantLinear from auto_gptq.utils.import_utils import dynamically_import_QuantLinear +import habana_frameworks.torch.core as htcore try: @@ -2117,6 +2118,7 @@ def test_bias(self): self.assertTrue(predicted_text.startswith("Today I am in Paris and I am a student of the Master's")) + class TestsQ4Triton(unittest.TestCase): def test_generation_no_act_order(self): prompt = "I am in Paris and" @@ -2193,3 +2195,138 @@ def test_generation_with_act_order(self): predicted_text = tokenizer.decode(res[0]) self.assertEqual(predicted_text, reference_output) + + +class TestQ4HPU(unittest.TestCase): + @parameterized.expand( + [ + ("hpu", torch.bfloat16), + ("hpu", torch.float), + ] + ) + def test_generation(self, in_device, model_dtype): + # Reference generated with the cuda-old kernel and TheBloke/Llama-2-7B-Chat-GPTQ + reference_output = " I am in Paris and I am feeling very sad and lonely. everybody I know is busy and I don't have any friends here. I am staying in a small apartment in the 11th arrondissement and I am feeling very isolated. 
I miss my friends and family back home and I don'"
+
+        prompt = "I am in Paris and"
+        device = torch.device(in_device)
+
+        model_id = "TheBloke/Llama-2-7B-Chat-GPTQ"
+
+        try:
+            from transformers import GPTQConfig, AutoModelForCausalLM
+            quantization_config = GPTQConfig(bits=4, use_exllama=False)
+            model_kwargs = {
+                "revision": "main",
+                "token": None
+            }
+            model_q = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs)
+            model_q = model_q.eval().to(device)
+        except ValueError as e:
+            # The HPU int4 path may be unavailable in this environment; skip instead of failing.
+            if "Can not use HPU int4 kernel" in str(e):
+                self.skipTest("HPU int4 kernel is not available")
+            else:
+                raise e
+
+        tokenizer_kwargs = {
+            "revision": "main",
+            "token": None
+        }
+        tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)
+
+        if not model_q.config.is_encoder_decoder:
+            tokenizer.padding_side = "left"
+        # Some models like GPT2 do not have a PAD token so we have to set it if necessary
+        if model_q.config.model_type == "llama":
+            # unwind broken decapoda-research config
+            model_q.generation_config.pad_token_id = 0
+            model_q.generation_config.bos_token_id = 1
+            model_q.generation_config.eos_token_id = 2
+            tokenizer.bos_token_id = model_q.generation_config.bos_token_id
+            tokenizer.eos_token_id = model_q.generation_config.eos_token_id
+            tokenizer.pad_token_id = model_q.generation_config.pad_token_id
+            tokenizer.pad_token = tokenizer.decode(tokenizer.pad_token_id)
+            tokenizer.eos_token = tokenizer.decode(tokenizer.eos_token_id)
+            tokenizer.bos_token = tokenizer.decode(tokenizer.bos_token_id)
+
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+            model_q.generation_config.pad_token_id = model_q.generation_config.eos_token_id
+
+        inp = tokenizer(prompt, return_tensors="pt").to(device)
+
+        res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
+
+        predicted_text = tokenizer.decode(res[0])
+
+        self.assertEqual(predicted_text, reference_output)
+
+    @parameterized.expand(
+        [
+            ("hpu", torch.bfloat16),
+            ("hpu", torch.float),
+        ]
+    )
+    def test_bias(self, in_device, model_dtype):
+        device = torch.device(in_device)
+        # TheBloke/Llama-2-7B-Chat-GPTQ has a bias tensor, but it is all zeros; use a checkpoint whose bias is actually non-zero.
+        model_id = "s3nh/starcoderbase-1b-GPTQ"
+        try:
+            model_kwargs = {
+                "revision": "main",
+                "token": None
+            }
+            model_q = AutoGPTQForCausalLM.from_quantized(model_id, torch_dtype=model_dtype, use_marlin=False, **model_kwargs)
+            model_q = model_q.eval().to(device)
+        except ValueError as e:
+            # Same handling as test_generation: skip when the HPU int4 path is unavailable.
+            if "Can not use HPU int4 kernel" in str(e):
+                self.skipTest("HPU int4 kernel is not available")
+            else:
+                raise e
+
+        for _, param in model_q.named_parameters():
+            self.assertTrue(param.device != torch.device("meta"))
+
+        for _, buffer in model_q.named_buffers():
+            self.assertTrue(buffer.device != torch.device("meta"))
+
+        self.assertTrue(torch.count_nonzero(model_q.model.transformer.h[0].attn.c_proj.bias) > 0)
+        self.assertTrue(torch.count_nonzero(model_q.model.transformer.h[0].attn.c_attn.bias) > 0)
+
+        tokenizer_kwargs = {
+            "revision": "main",
+            "token": None
+        }
+        tokenizer = AutoTokenizer.from_pretrained("Xenova/starcoderbase-1b", **tokenizer_kwargs)
+
+        if not model_q.config.is_encoder_decoder:
+            tokenizer.padding_side = "left"
+        # Some models like GPT2 do not have a PAD token so we have to set it if necessary
+        if model_q.config.model_type == "llama":
+            # unwind broken decapoda-research config
+            model_q.generation_config.pad_token_id = 0
+            model_q.generation_config.bos_token_id = 1
+            model_q.generation_config.eos_token_id = 2
+            tokenizer.bos_token_id = model_q.generation_config.bos_token_id
+            tokenizer.eos_token_id = model_q.generation_config.eos_token_id
+            tokenizer.pad_token_id = model_q.generation_config.pad_token_id
+            tokenizer.pad_token = tokenizer.decode(tokenizer.pad_token_id)
+            tokenizer.eos_token = tokenizer.decode(tokenizer.eos_token_id)
+            tokenizer.bos_token = tokenizer.decode(tokenizer.bos_token_id)
+
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+            model_q.generation_config.pad_token_id = model_q.generation_config.eos_token_id
+
+        prompt = "Today I am in Paris and"
+        inp = tokenizer(prompt, return_tensors="pt").to(device)
+
+        res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
+
+        predicted_text = tokenizer.decode(res[0])
+
+        self.assertTrue(predicted_text.startswith("Today I am in Paris and I am a student of the Master's"))
+
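# ---------------------------------------------------------------------------
# Reviewer note (illustrative sketch, not part of the patch): the new
# qlinear_hpu.QuantLinear._preprocessing() unpacks the CUDA-old GPTQ buffers
# and repacks eight 4-bit codes per int32 along dim 1 via pack_tensor(), then
# forward() hands the repacked qweight/qzeros to torch.ops.hpu.convert_from_uint4.
# The CPU-only round trip below demonstrates that packing scheme under those
# assumptions; pack_rows() mirrors pack_tensor() and unpack_rows() is its
# inverse. Helper names are illustrative only and do not exist in the patch.
# ---------------------------------------------------------------------------
import torch


def pack_rows(values: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Pack 32 // bits quantized codes into each int32 along dim 1."""
    values = values.to(torch.int32)
    per_int = 32 // bits
    packed = torch.zeros((values.shape[0], values.shape[1] // per_int), dtype=torch.int32)
    for col in range(packed.shape[1]):
        for j in range(per_int):
            # Place code j of this group into bit positions [bits * j, bits * (j + 1)).
            packed[:, col] |= values[:, col * per_int + j] << (bits * j)
    return packed


def unpack_rows(packed: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Recover the quantized codes packed by pack_rows()."""
    per_int = 32 // bits
    shifts = torch.arange(0, 32, bits, dtype=torch.int32)
    # Shift each nibble down to the low bits, then mask off everything else.
    values = torch.bitwise_right_shift(packed.unsqueeze(-1), shifts) & ((1 << bits) - 1)
    return values.reshape(packed.shape[0], packed.shape[1] * per_int)


if __name__ == "__main__":
    w = torch.randint(0, 16, (8, 64), dtype=torch.int32)  # random 4-bit codes
    assert torch.equal(unpack_rows(pack_rows(w)), w)  # lossless round trip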