
Commit

Accuracy results for scale estimation algorithm in docs. (openvinotoolkit#2706)

### Changes

Added scale estimation accuracy results to documentation.
andreyanufr committed Jun 4, 2024
1 parent 77a8867 commit cd64227
Showing 1 changed file with 94 additions and 4 deletions: docs/usage/post_training_compression/weights_compression/Usage.md
```python
nncf_dataset = nncf.Dataset(data_source, transform_fn)
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM, ratio=0.8, dataset=nncf_dataset) # model is openvino.Model object
```
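The `data_source` and `transform_fn` used above are defined earlier in the document and are not shown in this excerpt. Purely as an illustration, a minimal calibration pipeline might look like the sketch below; the dataset choice, model id, and input names are placeholders, not part of the original example.

```python
import nncf
from datasets import load_dataset
from transformers import AutoTokenizer

# Hypothetical calibration source and tokenizer (placeholders, replace with real ones).
data_source = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:128]")
tokenizer = AutoTokenizer.from_pretrained("<model_id>")  # replace with a real checkpoint

def transform_fn(item):
    # Convert one raw text record into the dict of inputs the openvino.Model expects.
    # The input names depend on the model and are assumptions here.
    tokens = tokenizer(item["text"], return_tensors="np")
    return {
        "input_ids": tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
    }

nncf_dataset = nncf.Dataset(data_source, transform_fn)
```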

- Accuracy of the 4-bit compressed models can also be improved by applying the AWQ algorithm or the Scale Estimation algorithm on top of the data-based mixed-precision algorithm. These algorithms adjust a subset of weights to minimize the difference between the original precision and 4-bit.
Below is an example of how to compress 80% of the layers to 4-bit integers with the default data-based mixed-precision algorithm plus AWQ and Scale Estimation.
It requires setting `awq` to `True` and `scale_estimation` to `True` in addition to the data-based mixed-precision algorithm.
Both algorithms, AWQ and Scale Estimation, can be enabled together or separately.

```python
from datasets import load_dataset
# ... (model/tokenizer setup and the transform_func / get_input_shapes definitions are collapsed in this diff view)
input_shapes = get_input_shapes(model)
nncf_dataset = Dataset(dataset, partial(transform_func, tokenizer=tokenizer,
input_shapes=input_shapes))

model.model = compress_weights(model.model,
                               mode=CompressWeightsMode.INT4_SYM,
                               ratio=0.8,
                               dataset=nncf_dataset,
                               awq=True,
                               scale_estimation=True)

model.save_pretrained(...)
```
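The model and tokenizer setup and the bodies of `transform_func` and `get_input_shapes` referenced above are collapsed in this diff view. Below is a minimal sketch of what such helpers might look like, assuming `model` is an `optimum.intel` `OVModelForCausalLM` (so `model.model` is an `openvino.Model`) and each dataset record has a `text` field; it is an illustration, not the exact code from the document.

```python
import numpy as np

def get_input_shapes(model, batch_size=1):
    # Collect a concrete shape for every input of the wrapped openvino.Model,
    # pinning the batch dimension to batch_size.
    shapes = {}
    for inp in model.model.inputs:
        shape = list(inp.partial_shape.get_min_shape())
        shape[0] = batch_size
        shapes[inp.any_name] = shape
    return shapes

def transform_func(item, tokenizer, input_shapes):
    # Tokenize one raw record and fill any remaining model inputs with zeros
    # of the recorded shape so the calibration sample is complete.
    tokens = tokenizer(item["text"])
    inputs = {
        "input_ids": np.expand_dims(np.array(tokens["input_ids"]), 0),
        "attention_mask": np.expand_dims(np.array(tokens["attention_mask"]), 0),
    }
    for name, shape in input_shapes.items():
        if name not in inputs:
            inputs[name] = np.zeros(shape)
    return inputs
```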

Below are the perplexity and accuracy results of data-free and data-aware mixed-precision INT4-INT8 weight compression for different language models on the [lambada openai dataset](https://huggingface.co/datasets/EleutherAI/lambada_openai).
The `_scale` suffix refers to data-aware mixed-precision compression with the Scale Estimation algorithm.
`r100` means that the embeddings and `lm_head` are kept in INT8 precision, while all other linear layers use INT4 precision. A sketch mapping these mode names onto `compress_weights` arguments follows the table.
<table>
<tr bgcolor='#B4B5BB'>
<td>Model</td>
<td>Mode</td>
<td>Acc (↑)</td>
<td>Ppl (↓)</td>
</tr>
<tr>
<td>stabilityai_stablelm-2-zephyr-1_6b</td>
<td>fp32</td>
<td>0.5925</td>
<td>6.3024</td>
</tr>
<tr>
<td></td>
<td>int4_sym_r100_gs64_scale</td>
<td>0.5795</td>
<td>7.3245</td>
</tr>
<tr>
<td></td>
<td>int4_sym_r100_gs64</td>
<td>0.5465</td>
<td>8.649</td>
</tr>
<tr>
<td>stable-zephyr-3b-dpo</td>
<td>fp32</td>
<td>0.6099</td>
<td>6.7151</td>
</tr>
<tr>
<td></td>
<td>int4_sym_r100_gs64_scale</td>
<td>0.595</td>
<td>7.037</td>
</tr>
<tr>
<td></td>
<td>int4_sym_r100_gs64</td>
<td>0.5639</td>
<td>9.349</td>
</tr>
<tr>
<td>microsoft_Phi-3-mini-4k-instruct</td>
<td>fp32</td>
<td>0.6839</td>
<td>4.1681</td>
</tr>
<tr>
<td></td>
<td>int4_sym_r100_gs128_scale</td>
<td>0.6736</td>
<td>4.4711</td>
</tr>
<tr>
<td></td>
<td>int4_sym_r100_gs128</td>
<td>0.6342</td>
<td>5.3419</td>
</tr>
<tr>
<td>mistralai_Mistral-7B-v0.1</td>
<td>fp32</td>
<td>0.7592</td>
<td>3.1898</td>
</tr>
<tr>
<td></td>
<td>int4_sym_r100_gs128_scale</td>
<td>0.7479</td>
<td>3.3527</td>
</tr>
<tr>
<td></td>
<td>int4_sym_r100_gs128</td>
<td>0.7421</td>
<td>3.4932</td>
</tr>
</table>
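As a reading aid, the mode names in the table can be mapped onto `compress_weights` arguments. Below is a hedged sketch of one plausible configuration for `int4_sym_r100_gs64_scale`; the mapping of `r100` to `ratio=1.0` and `gs64` to `group_size=64` follows the naming convention described above and is an assumption, not a verbatim reproduction of the benchmark script.

```python
from nncf import CompressWeightsMode, compress_weights

# Assumed mapping of the table's mode naming scheme onto compress_weights arguments
# (illustrative only):
#   int4_sym -> mode=CompressWeightsMode.INT4_SYM
#   r100     -> ratio=1.0 (all non-embedding/lm_head linear layers in INT4)
#   gs64     -> group_size=64
#   _scale   -> scale_estimation=True (data-aware, so a calibration dataset is needed)
compressed_model = compress_weights(
    model,                   # openvino.Model prepared as in the examples above
    mode=CompressWeightsMode.INT4_SYM,
    ratio=1.0,
    group_size=64,
    dataset=nncf_dataset,    # calibration dataset built as shown earlier
    scale_estimation=True,
)
```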

#### Limitations

- The algorithm is supported for OpenVINO and PyTorch models.
