FEAT: Add possibility of skipping modules when quantizing #248

Merged (1 commit), Dec 11, 2023
awq/models/_config.py (7 changes: 5 additions & 2 deletions)

@@ -1,7 +1,7 @@
 import os
 import json
 import logging
-from typing import Dict
+from typing import Dict, Optional, List
 from dataclasses import dataclass, field, fields
 from transformers.utils.hub import PushToHubMixin, cached_file

@@ -13,6 +13,7 @@ class AwqConfig(PushToHubMixin):
     w_bit: int = field(default=4)
     version: str = field(default="GEMM")
     config_file_name = "quant_config.json"
+    modules_to_not_convert: Optional[List] = None

     def save_pretrained(self, save_dir: str, **kwargs):
         logging.warning(
@@ -76,7 +77,8 @@ def to_dict(self):
             "zero_point": self.zero_point,
             "q_group_size": self.q_group_size,
             "w_bit": self.w_bit,
-            "version": self.version
+            "version": self.version,
+            "modules_to_not_convert": self.modules_to_not_convert,
         }

     def to_transformers_dict(self):
@@ -86,4 +88,5 @@ def to_transformers_dict(self):
             "group_size": self.q_group_size,
             "bits": self.w_bit,
             "version": self.version.lower(),
+            "modules_to_not_convert": self.modules_to_not_convert,
         }
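With the new field, the config serialized by `to_dict()` (the same keys saved to `quant_config.json`) carries the exclusion list alongside the existing entries. A minimal sketch of the expected shape; the values and the "gate" pattern are illustrative, not defaults:

```python
# Illustrative quantization config after this change; "gate" is an example pattern.
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
    "modules_to_not_convert": ["gate"],
}
```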
awq/models/base.py (4 changes: 2 additions & 2 deletions)

@@ -49,12 +49,12 @@ def generate(self, *args, **kwargs):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={},
                  calib_data: Union[str, List[str]]="pileval",
-                 split="train", text_column="text", duo_scaling=True):
+                 split="train", text_column="text", duo_scaling=True, modules_to_not_convert=None):
         self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

         quantizer = AwqQuantizer(
             self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
-            self.quant_config.version, calib_data, split, text_column, duo_scaling
+            self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert
         )
         quantizer.quantize()
         self.is_quantized = True
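The new `modules_to_not_convert` keyword on `quantize()` is forwarded directly to `AwqQuantizer`. A minimal usage sketch, assuming the usual AutoAWQ entry points; the model path and the excluded pattern ("gate", e.g. to keep MoE router gates in full precision) are illustrative:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mixtral-8x7B-v0.1"  # illustrative model
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Every linear layer whose name contains "gate" is left unquantized.
model.quantize(
    tokenizer,
    quant_config=quant_config,
    modules_to_not_convert=["gate"],
)
```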
awq/quantize/quantizer.py (14 changes: 13 additions & 1 deletion)

@@ -14,7 +14,7 @@

 class AwqQuantizer:
     def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
-                 calib_data, split, text_column, duo_scaling) -> None:
+                 calib_data, split, text_column, duo_scaling, modules_to_not_convert=None) -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
@@ -25,6 +25,7 @@ def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
         self.split = split
         self.text_column = text_column
         self.duo_scaling = duo_scaling
+        self.modules_to_not_convert = modules_to_not_convert if modules_to_not_convert is not None else []
         self.modules, self.module_kwargs, self.inps = self.init_quant()

     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
@@ -68,6 +69,13 @@ def pseudo_dequantize_tensor(self, w: nn.Linear, scales: torch.Tensor, zeros: to

         return w

+    def _exclude_layers_to_not_quantize(self, linear_layers):
+        filtered_layers = {}
+        for name, linear_layer in linear_layers.items():
+            if not any(key in name for key in self.modules_to_not_convert):
+                filtered_layers[name] = linear_layer
+        return filtered_layers
+
     def quantize(self):
         for i in tqdm(range(len(self.modules)), desc="AWQ"):
             # Move module and inputs to correct device
@@ -80,6 +88,10 @@ def quantize(self):

             # [STEP 1]: Get layer, extract linear modules, extract input features
             named_linears = get_named_linears(self.modules[i])
+
+            # Filter out the linear layers we do not want to quantize
+            named_linears = self._exclude_layers_to_not_quantize(named_linears)
+
             input_feat = self._get_input_feat(self.modules[i], named_linears)
             clear_memory()

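The exclusion in `_exclude_layers_to_not_quantize` is substring-based: an entry in `modules_to_not_convert` drops every linear layer whose qualified name contains it. A self-contained sketch of the same matching rule, with illustrative layer names shaped like what `get_named_linears()` returns for one block:

```python
import torch.nn as nn

modules_to_not_convert = ["gate"]

# Illustrative per-block linear layers.
named_linears = {
    "block_sparse_moe.gate": nn.Linear(4096, 8),
    "block_sparse_moe.experts.0.w1": nn.Linear(4096, 14336),
    "self_attn.q_proj": nn.Linear(4096, 4096),
}

# Keep only the layers whose name matches none of the excluded patterns.
filtered = {
    name: layer
    for name, layer in named_linears.items()
    if not any(key in name for key in modules_to_not_convert)
}

print(list(filtered))
# ['block_sparse_moe.experts.0.w1', 'self_attn.q_proj']
```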
awq/quantize/scale.py (6 changes: 4 additions & 2 deletions)

@@ -53,8 +53,10 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         # apply the scaling to input feat if given; prepare it for clipping
         if input_feat_dict is not None:
             for layer_name in layer_names:
-                inp = input_feat_dict[layer_name]
-                inp.div_(scales.view(1, -1).to(inp.device))
+                # Skip the modules that are not quantized
+                if layer_name in input_feat_dict:
+                    inp = input_feat_dict[layer_name]
+                    inp.div_(scales.view(1, -1).to(inp.device))

         prev_op.cpu()
         for layer in layers:
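The guard in `apply_scale` is needed because input features are only collected for the layers that survive the filter above, so an excluded layer's name can be missing from `input_feat_dict`. A minimal sketch of the guarded loop with illustrative names and tensors:

```python
import torch

scales = torch.tensor([2.0, 2.0, 2.0, 2.0])
layer_names = ["self_attn.q_proj", "block_sparse_moe.gate"]  # illustrative

# Activations were only recorded for quantized layers, so the excluded
# "block_sparse_moe.gate" has no entry here.
input_feat_dict = {"self_attn.q_proj": torch.ones(2, 4)}

for layer_name in layer_names:
    # Skip the modules that are not quantized (same check as the patch above).
    if layer_name in input_feat_dict:
        inp = input_feat_dict[layer_name]
        inp.div_(scales.view(1, -1).to(inp.device))

print(input_feat_dict["self_attn.q_proj"])  # every element scaled to 0.5
```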