Fix marlin quantization using BaseQuantizeConfig from _pretrained AutoGPTQ#581 (AutoGPTQ#586)

* Fix two issues: (1) model.quantize() skipped the marlin config; (2) in qlinear_marlin.pack(), linear.bias may be None

* Add marlin quantization test

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Qubitium and fxmarty committed Mar 19, 2024
1 parent 09289d8 commit 04bca0d
Showing 4 changed files with 19 additions and 5 deletions.
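For orientation, a hedged end-to-end sketch of the workflow this commit fixes: quantizing a model into marlin format via BaseQuantizeConfig and reloading it with the marlin kernel. The checkpoint name, config values, and the is_marlin_format/use_marlin flags mirror the updated test below; the calibration example and output directory are illustrative only.

```python
from transformers import AutoTokenizer

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "saibo/llama-1B"  # small checkpoint, as in the test

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

# is_marlin_format is the flag that model.quantize() previously dropped.
quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=128,
    desc_act=False,
    is_marlin_format=True,
)

model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

# A single toy calibration example; real quantization needs a representative set.
examples = [tokenizer("auto-gptq is an easy-to-use model quantization library.")]
model.quantize(examples)

model.save_pretrained("llama-1b-4bit-128g-marlin")  # illustrative output path

# Reload through the marlin kernel, mirroring the assertion added to the test.
model = AutoGPTQForCausalLM.from_quantized(
    "llama-1b-4bit-128g-marlin", device="cuda:0", use_marlin=True
)
```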
1 change: 1 addition & 0 deletions auto_gptq/modeling/_base.py
@@ -501,6 +501,7 @@ def tmp(_, inp, out):
desc_act=self.quantize_config.desc_act,
warmup_triton=autotune_warmup_after_quantized,
force_layer_back_to_cpu=force_layer_back_to_cpu,
+ is_marlin_format=self.quantize_config.is_marlin_format,
)
if device_map:
self.model = remove_hook_from_module(self.model, recurse=True)
9 changes: 7 additions & 2 deletions auto_gptq/modeling/_utils.py
@@ -247,6 +247,7 @@ def pack_model(
desc_act=False,
warmup_triton: bool = False,
force_layer_back_to_cpu: bool = False,
+ is_marlin_format: bool = False,
):
QuantLinear = dynamically_import_QuantLinear(
use_triton=use_triton,
@@ -255,7 +256,7 @@
bits=bits,
disable_exllama=False,
disable_exllamav2=True,
- disable_marlin=True,
+ disable_marlin=not is_marlin_format,
)

if force_layer_back_to_cpu:
@@ -274,6 +275,7 @@
desc_act=desc_act,
disable_exllama=False,
disable_exllamav2=True,
+ use_marlin=is_marlin_format,
)
qlayers = find_layers(model, [QuantLinear])
for name in qlayers:
@@ -288,7 +290,10 @@
zero.to(CPU),
g_idx.to(CPU),
)
- qlayers[name].pack(layers[name], scale, zero, g_idx)
+ if QuantLinear.QUANT_TYPE == "marlin":
+     qlayers[name].pack(layers[name], scale)
+ else:
+     qlayers[name].pack(layers[name], scale, zero, g_idx)
qlayers[name].to(layer_device)
logger.info("Model packed.")
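A note on the branch above: the marlin qlinear packs only weights and scales, while the other GPTQ kernels also consume zero points and the g_idx reordering tensor, hence the two pack() signatures. As a hedged summary of when the marlin path typically applies, here is a small sketch; the helper name and exact constraints are assumptions about the marlin kernel, not taken from this diff.

```python
def can_pack_with_marlin(bits: int, sym: bool, desc_act: bool) -> bool:
    # Assumed eligibility check: marlin targets symmetric 4-bit quantization
    # without activation reordering, so zero points are implicit and g_idx is
    # never consulted while packing.
    return bits == 4 and sym and not desc_act
```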

7 changes: 5 additions & 2 deletions auto_gptq/nn_modules/qlinear/qlinear_marlin.py
@@ -169,8 +169,11 @@ def pack(self, linear, scales):
q = torch.from_numpy(q.astype(np.int32)).to(w.device)
self.B[:, :] = q.to(self.B.device)
self.s[:, :] = s.to(self.s.device)
- if self.bias is not None:
-     self.bias[:] = linear.bias.data.to(self.bias.device)
+ if linear.bias is not None:
+     if self.bias is not None:
+         self.bias[:] = linear.bias.data.to(self.bias.device)
+     else:
+         self.bias = linear.bias.clone()

def forward(self, A):
A = A.half()
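To make the bias change above concrete: the old code dereferenced linear.bias.data whenever the quantized layer had a bias buffer, which raises AttributeError when the source layer was built without a bias. A minimal, hedged illustration of that situation (layer sizes are arbitrary):

```python
import torch.nn as nn

# A source layer created without a bias: before this fix, pack() could still
# reach `linear.bias.data` here and fail on None.
source = nn.Linear(256, 256, bias=False)
assert source.bias is None  # the new guard skips the bias copy in this case
```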
7 changes: 6 additions & 1 deletion tests/test_quantization.py
@@ -1,13 +1,15 @@
import tempfile
import unittest

+ from parameterized import parameterized
from transformers import AutoTokenizer

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig


class TestQuantization(unittest.TestCase):
- def test_quantize(self):
+ @parameterized.expand([(False,), (True,)])
+ def test_quantize(self, use_marlin: bool):
pretrained_model_dir = "saibo/llama-1B"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
@@ -20,6 +22,7 @@ def test_quantize(self):
bits=4,
group_size=128,
desc_act=False,
+ is_marlin_format=use_marlin,
)

model = AutoGPTQForCausalLM.from_pretrained(
@@ -32,3 +35,5 @@

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

+ model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device="cuda:0", use_marlin=use_marlin)
