Skip to content

Commit

Permalink
bitsandbytes - Linear8bitLt integration into transformers models (
Browse files Browse the repository at this point in the history
huggingface#17901)

* first commit

* correct replace function

* add final changes

- works like charm!
- cannot implement tests yet
- tested

* clean up a bit

* add bitsandbytes dependencies

* working version

- added import function
- added bitsandbytes utils file

* small fix

* small fix

- fix import issue

* fix import issues

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* refactor a bit

- move bitsandbytes utils to utils
- change comments on functions

* reformat docstring

- reformat docstring on init_empty_weights_8bit

* Update src/transformers/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* revert bad formatting

* change to bitsandbytes

* refactor a bit

- remove init8bit since it is useless

* more refactoring

- fixed init empty weights issue
- added threshold param

* small hack to make it work

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* revmoe the small hack

* modify utils file

* make style + refactor a bit

* create correctly device map

* add correct dtype for device map creation

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply suggestions

- remove with torch.grad
- do not rely on Python bool magic!

* add docstring

 - add docstring for new kwargs

* add docstring

- comment `replace_8bit_linear` function
- fix weird formatting

* - added more documentation
- added new utility function for memory footprint tracking
- colab demo to add

* few modifs

- typo doc
- force cast into float16 when load_in_8bit is enabled

* added colab link

* add test architecture + docstring a bit

* refactor a bit testing class

* make style + refactor a bit

* enhance checks

- add more checks
- start writing saving test

* clean up a bit

* male style

* add more details on doc

* add more tests

- still needs to fix 2 tests

* replace by "or"

- could not fix it from GitHub GUI

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* refactor a bit testing code + add readme

* make style

* fix import issue

* Update src/transformers/modeling_utils.py

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* add few comments

* add more doctring + make style

* more docstring

* raise error when loaded in 8bit

* make style

* add warning if loaded on CPU

* add small sanity check

* fix small comment

* add bitsandbytes on dockerfile

* Improve documentation

- improve documentation from comments

* add few comments

* slow tests pass on the VM but not on the CI VM

* Fix merge conflict

* make style

* another test should pass on a multi gpu setup

* fix bad import in testing file

* Fix slow tests

- remove dummy batches
- no more CUDA illegal memory errors

* odify dockerfile

* Update docs/source/en/main_classes/model.mdx

* Update Dockerfile

* Update model.mdx

* Update Dockerfile

* Apply suggestions from code review

* few modifications

- lm head can stay on disk/cpu
- change model name so that test pass

* change test value

- change test value to the correct output
- torch bmm changed to baddmm in bloom modeling when merging

* modify installation guidelines

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* replace `n`by `name`

* merge `load_in_8bit` and `low_cpu_mem_usage`

* first try - keep the lm head in full precision

* better check

- check the attribute `base_model_prefix` instead of computing the number of parameters

* added more tests

* Update src/transformers/utils/bitsandbytes.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Merge branch 'integration-8bit' of https://github.com/younesbelkada/transformers into integration-8bit

* improve documentation

- fix typos for installation
- change title in the documentation

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
  • Loading branch information
3 people authored and amyeroberts committed Oct 18, 2022
1 parent 20e866e commit 586d858
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 176 deletions.
3 changes: 3 additions & 0 deletions docker/transformers-all-latest-gpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0"

RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate

# Add bitsandbytes for mixed int8 testing
RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes==0.31.5

RUN python3 -m pip install --no-cache-dir decord

# When installing in editable mode, `transformers` is not recognized as a package.
Expand Down
39 changes: 39 additions & 0 deletions docs/source/en/main_classes/model.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,45 @@ model = AutoModel.from_config(config)

Due to Pytorch design, this functionality is only available for floating dtypes.

### `bitsandbytes` integration for Int8 mixed-precision matrix decomposition

From the paper `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale`, we suport HuggingFace 🤗 integration for all models in the Hub with few lines of code.
For models trained in half-precision (aka, either `float16` or `bfloat16`) or full precision. This method aims to reduce `nn.Linear` size by 2 (if trained in half precision) or by 4 if trained in full precision, without affecting too much quality by operating on the outliers in half-precision.
This technique is useful and works well for billion scale models (>1B parameters) therefore we advice you to use it only for models of that scale. This method has been tested for 2-billion to 176-billion scale models and supports only PyTorch models.

![HFxbitsandbytes.png](https://s3.amazonaws.com/moonup/production/uploads/1659861207959-62441d1d9fdefb55a0b7d12c.png)

Int8 mixed-precision matrix decomposition works by separating a matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16 (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no predictive degradation is possible for very large models (>=176B parameters).
Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning).

Note also that you would require a GPU to run mixed-8bit models as the kernels has been compiled for GPUs only. Make sure that you have enough GPU RAM to store the quarter (or half if your model is natively in half precision) of the model before using this feature.

Below are some notes to help you use this module, or follow this demo on Google colab: [![Open In Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing)

#### Requirements

- Make sure you run that on a NVIDIA GPU that supports 8-bit tensor cores (Turing or Ampere GPUs - e.g. T4, RTX20s RTX30s, A40-A100). Note that previous generations of NVIDIA GPUs do not support 8-bit tensor cores.
- Install the correct version of `bitsandbytes` by running:
`pip install -i https://test.pypi.org/simple/ bitsandbytes`
- Install `accelerate`:
`pip install accelerate`

#### Running mixed-int8 models

After carefully installing the required libraries, the way to load your mixed 8-bit model is as follows:
```py
model_name = "bigscience/bloom-2b5"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
```
The implementation supports multi-GPU setup thanks to `accelerate` as backend. If you want to control the GPU memory you want to allocate for each GPU, you can use the `max_memory` argument as follows:
(If allocating `1GB` into GPU-0 and `2GB` into GPU-1, you can use `max_memory={0:"1GB", 1:"2GB"}`)
```py
max_memory_mapping = {0: "1GB", 1: "2GB"}
model_name = "bigscience/bloom-3b"
model_8bit = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping
)
```

## ModuleUtilsMixin

Expand Down
39 changes: 10 additions & 29 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,8 @@
else:
get_balanced_memory = None

if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
if is_bitsandbytes_available():
from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -1789,7 +1787,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are
not compiled and adapted for CPUs.
load_in_8bit_threshold (`float`, *optional*, defaults to 6):
int8_threshold (`float`, *optional*, defaults to 6):
Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
states value that is above this threshold will be considered an outlier and the operation on those
Expand All @@ -1799,9 +1797,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
(small models, fine-tuning).
load_in_8bit_skip_modules (`List[str]`, *optional*):
An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such
as Jukebox that has several heads in different places and not necessarily at the last position.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
Expand Down Expand Up @@ -1893,8 +1888,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
load_in_8bit_threshold = kwargs.pop("load_in_8bit_threshold", 6.0)
load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
int8_threshold = kwargs.pop("int8_threshold", 6.0)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

Expand Down Expand Up @@ -2207,18 +2201,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model = cls(config, *model_args, **model_kwargs)

if load_in_8bit:
from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear

logger.info("Detected 8-bit loading: activating 8-bit loading for this model")

# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
if load_in_8bit_skip_modules is None:
modules_to_not_convert = get_keys_to_not_convert(model)
else:
modules_to_not_convert = load_in_8bit_skip_modules
model = replace_8bit_linear(
model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
)
# We never convert lm_head or any last modules for numerical stability reasons
modules_to_not_convert = get_key_to_not_convert(model)
model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)

if isinstance(device_map, str):
if model._no_split_modules is None:
Expand Down Expand Up @@ -2249,18 +2236,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)

if load_in_8bit:
# The LM head / tied weights or any last module can stay on disk / CPU
# The LM head can stay on disk / CPU
device_map_without_lm_head = {
key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
}
if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
raise ValueError(
"""
Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
the quantized model. If you have set a value for `max_memory` you should increase that. To have
an idea of the modules that are set on the CPU or RAM you can print model.hf_device_map.
"""
)
raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
del device_map_without_lm_head

if from_tf:
Expand Down
33 changes: 6 additions & 27 deletions src/transformers/utils/bitsandbytes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from copy import deepcopy

from transformers.utils import is_accelerate_available, is_bitsandbytes_available


Expand All @@ -11,7 +9,6 @@

if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters


def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
Expand Down Expand Up @@ -114,7 +111,7 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
if len(list(module.children())) > 0:
replace_8bit_linear(module, threshold, modules_to_not_convert)

if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
if isinstance(module, nn.Linear) and name != modules_to_not_convert:
with init_empty_weights():
model._modules[name] = bnb.nn.Linear8bitLt(
module.in_features,
Expand All @@ -126,38 +123,20 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
return model


def get_keys_to_not_convert(model):
def get_key_to_not_convert(model):
r"""
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
int8.
we may want to keep the lm_head in full precision for numerical stability reasons.
Parameters:
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model and tie the weights, then
# check if it contains tied weights
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
tied_model.tie_weights()

tied_keys = list(find_tied_parameters(tied_model).values())
has_tied_params = len(tied_keys) > 0

# Check if it is a base model
is_base_model = not hasattr(model, model.base_model_prefix)

# Ignore this for base models (BertModel, GPT2Model, etc.)
if (not has_tied_params) and is_base_model:
if not hasattr(model, model.base_model_prefix):
return ""

# otherwise they have an attached head
list_modules = list(model.named_parameters())
list_last_module = [list_modules[-1][0]]

# add last module together with tied weights
intersection = set(list_last_module) - set(tied_keys)
list_untouched = tied_keys + list(intersection)

return [module_name.split(".")[0] for module_name in list_untouched]
last_name = list_modules[-1][0]
return last_name.split(".")[0]
117 changes: 17 additions & 100 deletions tests/mixed_int8/README.md
Original file line number Diff line number Diff line change
@@ -1,120 +1,37 @@
# Testing mixed int8 quantization

![HFxbitsandbytes.png](https://s3.amazonaws.com/moonup/production/uploads/1660567705337-62441d1d9fdefb55a0b7d12c.png)

The following is the recipe on how to effectively debug `bitsandbytes` integration on Hugging Face `transformers`.

## Library requirements

+ `transformers>=4.22.0`
+ `accelerate>=0.12.0`
+ `bitsandbytes>=0.31.5`.
## Hardware requirements

The following instructions are tested with 2 NVIDIA-Tesla T4 GPUs. To run successfully `bitsandbytes` you would need a 8-bit core tensor supported GPU. Note that Turing, Ampere or newer architectures - e.g. T4, RTX20s RTX30s, A40-A100, A6000 should be supported.
I am using a setup of 2 GPUs that are NVIDIA-Tesla T4 15GB

## Virutal envs

```bash
conda create --name int8-testing python==3.8
pip install bitsandbytes>=0.31.5
pip install accelerate>=0.12.0
pip install transformers>=4.23.0
```
if `transformers>=4.23.0` is not released yet, then use:
```
pip install git+https://github.com/huggingface/transformers.git
```

## Troubleshooting

A list of common errors:
```conda create --name int8-testing python==3.8```
```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit```
```pip install -e ".[dev]"```
```pip install -i https://test.pypi.org/simple/ bitsandbytes```
```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1```

### Torch does not correctly do the operations on GPU

First check that:
## Trobleshooting

```py
import torch

vec = torch.randn(1, 2, 3).to(0)
```
```conda create --name int8-testing python==3.8```
```pip install -i https://test.pypi.org/simple/ bitsandbytes```
```conda install pytorch torchvision torchaudio -c pytorch```
```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit```
```pip install -e ".[dev]"```
```pip install git+https://github.com/huggingface/accelerate.git@b52b793ea8bac108ba61192eead3cf11ca02433d```

Works without any error. If not, install torch using `conda` like:
### Check driver settings:

```bash
conda create --name int8-testing python==3.8
conda install pytorch torchvision torchaudio cudatoolkit=11.6 -c pytorch -c conda-forge
pip install bitsandbytes>=0.31.5
pip install accelerate>=0.12.0
pip install transformers>=4.23.0
```
For the latest pytorch instructions please see [this](https://pytorch.org/get-started/locally/)

and the snippet above should work.

### ` bitsandbytes operations are not supported under CPU!`

This happens when some Linear weights are set to the CPU when using `accelerate`. Please check carefully `model.hf_device_map` and make sure that there is no `Linear` module that is assigned to CPU. It is fine to have the last module (usually the Lm_head) set on CPU.

### `To use the type as a Parameter, please correct the detach() semantics defined by __torch_dispatch__() implementation.`

Use the latest version of `accelerate` with a command such as: `pip install -U accelerate` and the problem should be solved.

### `Parameter has no attribue .CB`

Same solution as above.

### `RuntimeError: CUDA error: an illegal memory access was encountered ... consider passing CUDA_LAUNCH_BLOCKING=1`

Run your script by pre-pending `CUDA_LAUNCH_BLOCKING=1` and you should observe an error as described in the next section.

### `CUDA illegal memory error: an illegal memory access at line...`:

Check the CUDA verisons with:
```
nvcc --version
```
and confirm it is the same version as the one detected by `bitsandbytes`. If not, run:
```
ls -l $CONDA_PREFIX/lib/libcudart.so
```
or
```
ls -l $LD_LIBRARY_PATH
```
Check if `libcudart.so` has a correct symlink that is set. Sometimes `nvcc` detects the correct CUDA version but `bitsandbytes` doesn't. You have to make sure that the symlink that is set for the file `libcudart.so` is redirected to the correct CUDA file.

Here is an example of a badly configured CUDA installation:

`nvcc --version` gives:

![Screenshot 2022-08-15 at 15.12.23.png](https://s3.amazonaws.com/moonup/production/uploads/1660569220888-62441d1d9fdefb55a0b7d12c.png)

which means that the detected CUDA version is 11.3 but `bitsandbytes` outputs:

![image.png](https://s3.amazonaws.com/moonup/production/uploads/1660569284243-62441d1d9fdefb55a0b7d12c.png)

First check:

```bash
echo $LD_LIBRARY_PATH
```

If this contains multiple paths separated by `:`. Then you have to make sure that the correct CUDA version is set. By doing:

```bash
ls -l $path/libcudart.so
```

On each path (`$path`) separated by `:`.
If not, simply run
```bash
ls -l $LD_LIBRARY_PATH/libcudart.so
ls -l $CONDA_PREFIX/lib/libcudart.so
```

and you can see

![Screenshot 2022-08-15 at 15.12.33.png](https://s3.amazonaws.com/moonup/production/uploads/1660569176504-62441d1d9fdefb55a0b7d12c.png)
### Recurrent bugs

If you see that the file is linked to the wrong CUDA version (here 10.2), find the correct location for `libcudart.so` (`find --name libcudart.so`) and replace the environment variable `LD_LIBRARY_PATH` with the one containing the correct `libcudart.so` file.
Sometimes you have to run a "dummy" inference pass when dealing with a multi-GPU setup. Checkout the ```test_multi_gpu_loading``` and the ```test_pipeline``` functions.
Loading

0 comments on commit 586d858

Please sign in to comment.