Batched quantization #516

Merged · 21 commits · Jul 2, 2024 · changes shown from 11 commits
24 changes: 24 additions & 0 deletions awq/models/base.py
@@ -155,6 +155,27 @@ def quantize(
"Whether to apply clipping to the model during quantization. Some models may perform better with this set to False."
),
] = True,
n_parallel_calib_samples: Annotated[
int,
Doc(
"The number of parallel samples to run through the model. "
"A high number of parallel samples can result in OOM during quantization if max_calib_samples is high enough. "
"If None, runs through all samples at the same time. "
"You can set this to a low number for more memory efficient quantization."
),
] = None,
max_calib_samples: Annotated[
int,
Doc(
"The maximum number of samples to run through the model."
)
] = 128,
max_calib_seq_len: Annotated[
int,
Doc(
"The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len."
)
] = 512
):
"""
The main quantization function that you can use to quantize your model.
@@ -193,6 +214,9 @@ def quantize(
modules_to_not_convert=self.quant_config.modules_to_not_convert,
export_compatible=export_compatible,
apply_clip=apply_clip,
n_parallel_calib_samples=n_parallel_calib_samples,
max_calib_samples=max_calib_samples,
max_calib_seq_len=max_calib_seq_len,
)
self.quantizer.quantize()

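A minimal sketch of how the new arguments are passed from user code; the values below are placeholders, and the full runnable example is in the docs/examples.md change further down.

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Placeholder model path; any AutoAWQ-supported causal LM works the same way.
model_path = "Qwen/Qwen2-7B-Instruct"
model = AutoAWQForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(
    tokenizer,
    quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"},
    n_parallel_calib_samples=32,  # forward 32 calibration samples per pass; lower this on OOM
    max_calib_samples=256,        # cap on the number of calibration samples
    max_calib_seq_len=1024,       # cap on the calibration sequence length
)
```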
117 changes: 98 additions & 19 deletions awq/quantize/quantizer.py
@@ -41,6 +41,9 @@ def __init__(
modules_to_not_convert=None,
export_compatible=False,
apply_clip=True,
n_parallel_calib_samples=None,
max_calib_samples=128,
max_calib_seq_len=512,
) -> None:
self.awq_model = awq_model
self.model = model
@@ -55,10 +58,17 @@
self.duo_scaling = duo_scaling
self.export_compatible = export_compatible
self.apply_clip = apply_clip
self.n_parallel_calib_samples = n_parallel_calib_samples
self.max_calib_samples = max_calib_samples
self.max_calib_seq_len = max_calib_seq_len
self.max_chunk_memory = 1024 * 1024 * 1024 // 10
self.modules_to_not_convert = (
modules_to_not_convert if modules_to_not_convert is not None else []
)
self.modules, self.module_kwargs, self.inps = self.init_quant()
self.modules, self.module_kwargs, self.inps = self.init_quant(
n_samples=self.max_calib_samples,
max_seq_len=self.max_calib_seq_len
)

def pseudo_quantize_tensor(self, w: torch.Tensor):
org_w_shape = w.shape
@@ -207,7 +217,7 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):

elif self.version == "marlin":
q_linear_module = WQLinear_Marlin

elif self.version == "gemv_fast":
q_linear_module = WQLinear_GEMVFast

@@ -228,6 +238,32 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
set_op_by_name(module, name, q_linear)
clear_memory()

@torch.no_grad()
def _module_forward(
self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict
) -> torch.Tensor:
if self.n_parallel_calib_samples is None:
# runs through all samples at once
module_output = module(x, **module_kwargs)
if isinstance(module_output, tuple):
module_output = module_output[0]
else:
# memory efficiently runs through all calibration samples
# but only n_parallel_calib_samples at a time
module_output = []
for i in range(0, x.shape[0], self.n_parallel_calib_samples):
x_partial = x[i : i + self.n_parallel_calib_samples]
partial_output = module(x_partial, **module_kwargs)

if isinstance(partial_output, tuple):
partial_output = partial_output[0]

module_output.append(partial_output)

module_output = torch.cat(module_output, dim=0)

return module_output
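A standalone sketch of the chunking pattern that `_module_forward` applies when `n_parallel_calib_samples` is set (the toy module and sizes are illustrative, not part of the PR): forwarding the batch in slices and concatenating the outputs reproduces the full-batch result while bounding peak activation memory.

```python
import torch
import torch.nn as nn

# Illustrative only: a toy layer standing in for a transformer block.
layer = nn.Linear(64, 64)
x = torch.randn(128, 16, 64)  # (calibration samples, seq_len, hidden)

# n_parallel_calib_samples=None: one pass over everything.
full_output = layer(x)

# n_parallel_calib_samples=32: same result, smaller peak activation footprint.
chunks = [layer(x[i : i + 32]) for i in range(0, x.shape[0], 32)]
chunked_output = torch.cat(chunks, dim=0)

assert torch.allclose(full_output, chunked_output)
```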

@torch.no_grad()
def _search_best_scale(
self,
@@ -254,7 +290,7 @@ def _search_best_scale(
org_shape = weight.shape
# The weights are reshaped to be organised by quantization group
weight = weight.view(-1, self.group_size)
# Calculates the relative magnitude of the weights within each of the quantization groups,
# and rescales each group individually so that each group has weights on a 0-1 scale.
w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
# Resizes the rescaled weight matrix back up to its original dimensions
Expand All @@ -263,16 +299,27 @@ def _search_best_scale(
w_mean = w_scale.mean(0)
clear_memory(weight)

# [STEP 2]: Compute per-channel mean of the input activation
x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)
# [STEP 2]: Compute per-channel mean of the input activation with chunking
inp_flat = inp.abs().view(-1, inp.shape[-1])
num_elements = inp_flat.size(0)
num_channels = inp_flat.size(1)
element_size_bytes = inp_flat.element_size()

# Calculate chunk size dynamically based on max_chunk_memory
chunk_size = self.max_chunk_memory // (element_size_bytes * num_channels)
chunk_size = min(chunk_size, num_elements)

x_mean = torch.zeros(num_channels, dtype=inp.dtype, device=inp.device)
for i in range(0, num_elements, chunk_size):
end = min(i + chunk_size, num_elements)
x_mean += inp_flat[i:end].sum(dim=0)

x_mean /= num_elements

# [STEP 3]: Compute output of module
with torch.no_grad():
module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

fp16_output = module2inspect(inp, **module_kwargs)
if isinstance(fp16_output, tuple):
fp16_output = fp16_output[0]
fp16_output = self._module_forward(inp, module2inspect, module_kwargs)

# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
@@ -292,7 +339,7 @@ def _compute_best_scale(
x_mean,
module2inspect,
linears2scale: List[nn.Linear],
fp16_output,
fp16_output: torch.Tensor,
kwargs={},
):
"""
@@ -328,6 +375,10 @@ def _compute_best_scale(
scales = scales / (scales.max() * scales.min()).sqrt()
scales_view = scales.view(1, -1).to(device)

# avoid scaling values that overflow
scales[torch.isinf(scales)] = 1
scales[torch.isnan(scales)] = 1

# Q(W * s)
for fc in linears2scale:
fc.weight.mul_(scales_view)
@@ -336,14 +387,10 @@ def _compute_best_scale(
)

# W * X
int_w_output = module2inspect(x, **kwargs)
if isinstance(int_w_output, tuple):
int_w_output = int_w_output[0]
int_w_output = self._module_forward(x, module2inspect, kwargs)

# compute mean squared error (L2 norm)
loss = (
(fp16_output - int_w_output).float().pow(2).mean().item()
) # NOTE: float prevents overflow
loss = self._compute_loss(fp16_output, int_w_output)

history.append(loss)
if loss < best_error:
@@ -360,6 +407,38 @@ def _compute_best_scale(

return best_scales.detach().cpu()

@torch.no_grad()
def _compute_loss(
self,
fp16_output: torch.Tensor,
int_w_output: torch.Tensor,
):
loss = 0.0
fp16_output_flat = fp16_output.view(-1)
int_w_output_flat = int_w_output.view(-1)
num_elements = fp16_output_flat.size(0)
element_size_bytes = fp16_output.element_size()

# Calculate chunk size dynamically based on max_chunk_memory
# Divide the max_chunk_memory by twice the element size
chunk_size = self.max_chunk_memory // (element_size_bytes * 2)
chunk_size = min(chunk_size, num_elements)

# Chunk the computation
for i in range(0, num_elements, chunk_size):
fp16_chunk = fp16_output_flat[i:i+chunk_size]
int_w_chunk = int_w_output_flat[i:i+chunk_size]

# Compute the loss for the chunk
chunk_loss = (fp16_chunk - int_w_chunk).float().pow(2).sum().item()
loss += chunk_loss

# Normalize the loss by the total number of elements
loss /= num_elements

return loss
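The same idea applies to the loss: summing squared errors chunk by chunk and dividing by the total element count matches the one-shot MSE without materialising a full-size difference tensor. A quick check with made-up tensors (sizes are arbitrary):

```python
import torch

fp16_output = torch.randn(100_000)
int_w_output = torch.randn(100_000)

# One-shot mean squared error.
one_shot = (fp16_output - int_w_output).float().pow(2).mean().item()

# Chunked version: accumulate squared-error sums, normalize at the end.
chunk_size = 30_000  # in the PR this is derived from max_chunk_memory
total = 0.0
for i in range(0, fp16_output.numel(), chunk_size):
    diff = fp16_output[i : i + chunk_size] - int_w_output[i : i + chunk_size]
    total += diff.float().pow(2).sum().item()
chunked = total / fp16_output.numel()

assert abs(chunked - one_shot) < 1e-4
```

With the hard-coded `max_chunk_memory` of `1024 * 1024 * 1024 // 10` (about 107 MB) and fp16 outputs (2 bytes per element, two tensors per chunk), `_compute_loss` processes roughly 26.8M elements per chunk.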


@torch.no_grad()
def _search_best_clip(self, layer, named_linears, input_feat):
clip_list = []
@@ -436,13 +515,13 @@ def _compute_best_clip(

return best_max_val.squeeze(1)

def init_quant(self, n_samples=128, seqlen=512):
def init_quant(self, n_samples=128, max_seq_len=512):
modules = self.awq_model.get_model_layers(self.model)
samples = get_calib_dataset(
data=self.calib_data,
tokenizer=self.tokenizer,
n_samples=n_samples,
block_size=seqlen,
max_seq_len=max_seq_len,
split=self.split,
text_column=self.text_column,
)
@@ -536,7 +615,7 @@ def cache_input_hook(m, x, y, name, feat_dict):
# Useful for trust_remote_code models.
module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)

self.inps = layer(self.inps, **module_kwargs)[0]
self.inps = self._module_forward(self.inps, layer, module_kwargs)
for h in handles:
h.remove()
# now solve for scaling and clipping
12 changes: 6 additions & 6 deletions awq/utils/calib_data.py
@@ -7,8 +7,8 @@
def get_calib_dataset(
data: Union[str, List[str], List[List[int]]] = "pileval",
tokenizer=None,
n_samples=512,
block_size=512,
n_samples=128,
max_seq_len=512,
split="train",
text_column="text",
):
@@ -47,7 +47,7 @@ def get_calib_dataset(
line = data[text_column]
line = line.strip()
line_encoded = tokenizer.encode(line)
if len(line_encoded) > 512:
if len(line_encoded) > max_seq_len:
continue
sample = torch.tensor([line_encoded])
if sample.numel() == 0:
@@ -56,10 +56,10 @@ def get_calib_dataset(
n_run += 1
if n_run == n_samples:
break
# now concatenate all samples and split according to block size
# now concatenate all samples and split according to max sequence length
cat_samples = torch.cat(samples, dim=1)
n_split = cat_samples.shape[1] // block_size
n_split = cat_samples.shape[1] // max_seq_len
logging.debug(f" * Split into {n_split} blocks")
return [
cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split)
cat_samples[:, i * max_seq_len : (i + 1) * max_seq_len] for i in range(n_split)
]
47 changes: 47 additions & 0 deletions docs/examples.md
@@ -78,6 +78,53 @@ tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
```

#### Long-context and thousands of calibration samples

For this example, we use HuggingFaceTB/cosmopedia-100k: it is a high-quality dataset and lets us
filter samples directly by token count. We quantize Qwen2 7B Instruct, one of the newer
high-performing models supported by AutoAWQ.

NOTE: Make sure to adjust `n_parallel_calib_samples` to avoid OOM. Tuning this parameter is
especially important when the sequence length is long and the number of calibration samples is large.

```python
from datasets import load_dataset
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'Qwen/Qwen2-7B-Instruct'
quant_path = 'qwen2-7b-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

def load_cosmopedia():
data = load_dataset('HuggingFaceTB/cosmopedia-100k', split="train")
data = data.filter(lambda x: x["text_token_length"] >= 2048)

return [text for text in data["text"]]

# Quantize
model.quantize(
tokenizer,
quant_config=quant_config,
calib_data=load_cosmopedia(),
n_parallel_calib_samples=32,
max_calib_samples=1000,
max_calib_seq_len=4096
)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
```

### GGUF Export

This computes AWQ scales and applies them to the model without running real quantization.