bitsandbytes - Linear8bitLt integration into transformers models #17901

Merged: 91 commits, Aug 10, 2022

Commits:
ed1dd12
first commit
younesbelkada Jun 24, 2022
b9d0da6
correct replace function
younesbelkada Jun 24, 2022
dd9a464
add final changes
younesbelkada Jun 27, 2022
35e1534
clean up a bit
younesbelkada Jun 27, 2022
d01822b
add bitsandbytes dependencies
younesbelkada Jul 7, 2022
839c9cd
working version
younesbelkada Jul 12, 2022
93a5ac6
Merge branch 'main' of https://github.com/huggingface/transformers in…
younesbelkada Jul 12, 2022
42a6845
small fix
younesbelkada Jul 12, 2022
97f64f8
small fix
younesbelkada Jul 12, 2022
a1fe7fc
fix import issues
younesbelkada Jul 12, 2022
05739e3
Apply suggestions from code review
younesbelkada Jul 12, 2022
7816ef9
refactor a bit
younesbelkada Jul 12, 2022
1155549
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Jul 12, 2022
b222b9a
reformat docstring
younesbelkada Jul 12, 2022
32f48cd
Update src/transformers/__init__.py
younesbelkada Jul 12, 2022
e116e21
revert bad formatting
younesbelkada Jul 12, 2022
39c46a0
change to bitsandbytes
younesbelkada Jul 12, 2022
b92c25c
refactor a bit
younesbelkada Jul 12, 2022
3779f5d
more refactoring
younesbelkada Jul 13, 2022
b41c250
small hack to make it work
younesbelkada Jul 19, 2022
311dcbf
Merge branch 'main' into integration-8bit
younesbelkada Jul 20, 2022
c91a58e
Update src/transformers/modeling_utils.py
younesbelkada Jul 20, 2022
db16cf8
Update src/transformers/modeling_utils.py
younesbelkada Jul 20, 2022
be6ce29
remove the small hack
younesbelkada Jul 21, 2022
848d64d
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Jul 21, 2022
15a81e0
modify utils file
younesbelkada Jul 21, 2022
514758d
make style + refactor a bit
younesbelkada Jul 22, 2022
a09e055
create correctly device map
younesbelkada Jul 25, 2022
387aa1e
add correct dtype for device map creation
younesbelkada Jul 25, 2022
7199292
Apply suggestions from code review
younesbelkada Jul 27, 2022
a4c19c1
apply suggestions
younesbelkada Jul 27, 2022
a5cd157
add docstring
younesbelkada Jul 27, 2022
9bd326b
add docstring
younesbelkada Jul 27, 2022
6e4fee6
- added more documentation
younesbelkada Jul 28, 2022
9e121b7
few modifs
younesbelkada Jul 28, 2022
a2ac688
added colab link
younesbelkada Jul 28, 2022
ac370b9
add test architecture + docstring a bit
younesbelkada Jul 28, 2022
27e9486
refactor a bit testing class
younesbelkada Jul 28, 2022
a0db982
make style + refactor a bit
younesbelkada Jul 28, 2022
21bd590
enhance checks
younesbelkada Jul 29, 2022
f1fcf77
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Jul 29, 2022
9b81c67
clean up a bit
younesbelkada Jul 29, 2022
659f427
make style
younesbelkada Jul 29, 2022
147683e
add more details on doc
younesbelkada Jul 29, 2022
c2e1918
add more tests
younesbelkada Aug 1, 2022
f6eb945
replace by "or"
younesbelkada Aug 1, 2022
ceff43e
refactor a bit testing code + add readme
younesbelkada Aug 1, 2022
55cec55
make style
younesbelkada Aug 1, 2022
67bf4fb
fix import issue
younesbelkada Aug 1, 2022
56e9147
Update src/transformers/modeling_utils.py
younesbelkada Aug 1, 2022
5a03a86
add few comments
younesbelkada Aug 1, 2022
7abb914
add more doctring + make style
younesbelkada Aug 1, 2022
961e57e
more docstring
younesbelkada Aug 1, 2022
1326a42
raise error when loaded in 8bit
younesbelkada Aug 1, 2022
9a0051b
make style
younesbelkada Aug 1, 2022
59f9a5a
add warning if loaded on CPU
younesbelkada Aug 2, 2022
418690b
add small sanity check
younesbelkada Aug 2, 2022
6d93424
fix small comment
younesbelkada Aug 2, 2022
d428d8d
add bitsandbytes on dockerfile
younesbelkada Aug 2, 2022
0324f4b
Improve documentation
younesbelkada Aug 2, 2022
af229c4
Merge branch 'main' into integration-8bit
younesbelkada Aug 2, 2022
70ad8cb
add few comments
younesbelkada Aug 3, 2022
1eedb90
slow tests pass on the VM but not on the CI VM
younesbelkada Aug 3, 2022
163ef77
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Aug 3, 2022
8100d03
Fix merge conflict
younesbelkada Aug 3, 2022
c9589f6
make style
younesbelkada Aug 3, 2022
eb9a26d
another test should pass on a multi gpu setup
younesbelkada Aug 4, 2022
838e2a9
fix bad import in testing file
younesbelkada Aug 4, 2022
c4a1e9b
Fix slow tests
younesbelkada Aug 4, 2022
31fce94
modify dockerfile
younesbelkada Aug 4, 2022
3e4a2a4
Update docs/source/en/main_classes/model.mdx
younesbelkada Aug 5, 2022
fdf37b3
Update Dockerfile
younesbelkada Aug 5, 2022
53e0b2e
Update model.mdx
younesbelkada Aug 7, 2022
91364c4
Update Dockerfile
younesbelkada Aug 8, 2022
5ea8976
Apply suggestions from code review
younesbelkada Aug 8, 2022
8b72d08
few modifications
younesbelkada Aug 8, 2022
32a4863
Merge branch 'main' into integration-8bit
younesbelkada Aug 8, 2022
c6c139f
change test value
younesbelkada Aug 8, 2022
3d3224f
Merge remote-tracking branch 'upstream/main' into integration-8bit
younesbelkada Aug 9, 2022
0d5bc2b
modify installation guidelines
younesbelkada Aug 9, 2022
bc8f332
Apply suggestions from code review
younesbelkada Aug 9, 2022
5adcadc
Apply suggestions from code review
younesbelkada Aug 9, 2022
bb00e7a
Apply suggestions from code review
younesbelkada Aug 9, 2022
630b4f7
replace `n` by `name`
younesbelkada Aug 9, 2022
a925641
merge `load_in_8bit` and `low_cpu_mem_usage`
younesbelkada Aug 9, 2022
279b8c4
first try - keep the lm head in full precision
younesbelkada Aug 9, 2022
e49a2ea
better check
younesbelkada Aug 9, 2022
5718d78
added more tests
younesbelkada Aug 9, 2022
a40667a
Update src/transformers/utils/bitsandbytes.py
younesbelkada Aug 9, 2022
61faa28
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Aug 9, 2022
36fbbd2
improve documentation
younesbelkada Aug 10, 2022
3 changes: 3 additions & 0 deletions docker/transformers-all-latest-gpu/Dockerfile
@@ -45,6 +45,9 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0"

RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate

# Add bitsandbytes for mixed int8 testing
RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes==0.31.5

RUN python3 -m pip install --no-cache-dir decord

# When installing in editable mode, `transformers` is not recognized as a package.
41 changes: 40 additions & 1 deletion docs/source/en/main_classes/model.mdx
@@ -105,7 +105,7 @@ You can also write your own device map following the same format (a dictionary l
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
```

Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`).
Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`) or use direct quantization techniques as described below.

### Model Instantiation dtype

@@ -133,6 +133,45 @@ model = AutoModel.from_config(config)

Due to Pytorch design, this functionality is only available for floating dtypes.

### `bitsandbytes` integration for Int8 mixed-precision matrix decomposition

From the paper `LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale`, we support Hugging Face 🤗 integration for all models in the Hub with a few lines of code.
The method works for models trained in half precision (`float16` or `bfloat16`) as well as in full precision. It aims to reduce the size of the `nn.Linear` layers by a factor of 2 (if trained in half precision) or 4 (if trained in full precision), without affecting quality too much, by handling the outliers in half precision.
This technique works well for billion-scale models (>1B parameters), so we advise you to use it only at that scale. It has been tested on models from 2 billion up to 176 billion parameters and supports PyTorch models only.

![HFxbitsandbytes.png](https://s3.amazonaws.com/moonup/production/uploads/1659861207959-62441d1d9fdefb55a0b7d12c.png)

Int8 mixed-precision matrix decomposition works by separating a matrix multiplication into two streams: (1) a systematic feature-outlier stream multiplied in fp16 (roughly 0.01% of the values), and (2) a regular stream of int8 matrix multiplication (the remaining 99.9%). With this method, int8 inference with no predictive degradation is possible for very large models (>=176B parameters).
Values are usually normally distributed, that is, most values lie in the range [-3.5, 3.5], but large models have some exceptional, systematic outliers that are distributed very differently. These outliers often fall in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of magnitude ~5, but beyond that there is a significant penalty in predictive performance. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning).
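
The following is a minimal, illustrative sketch of the decomposition idea (not the real implementation: the actual kernels live in `bitsandbytes` (`Linear8bitLt`) and use vector-wise scaling constants with fused CUDA code; the function below is a hypothetical illustration):
```py
import torch

def mixed_int8_matmul_sketch(x, weight, threshold=6.0):
    # x: (batch, in_features), weight: (in_features, out_features), both float tensors.
    # Feature columns whose magnitude exceeds the threshold anywhere are treated as outliers.
    outlier_cols = (x.abs() > threshold).any(dim=0)

    # (1) outlier stream: a tiny fraction of the features, kept in higher precision
    # (fp16 on GPU in the real kernels; plain float here so the sketch runs anywhere)
    out_outliers = x[:, outlier_cols] @ weight[outlier_cols, :]

    # (2) regular stream: quantize the remaining features to int8 with simple absmax
    # scaling, multiply, then dequantize (the int8 matmul is emulated in float here)
    x_reg, w_reg = x[:, ~outlier_cols], weight[~outlier_cols, :]
    sx = x_reg.abs().max().clamp(min=1e-8) / 127.0
    sw = w_reg.abs().max().clamp(min=1e-8) / 127.0
    x_q = (x_reg / sx).round().clamp(-127, 127)
    w_q = (w_reg / sw).round().clamp(-127, 127)
    out_int8 = (x_q @ w_q) * sx * sw

    return out_outliers + out_int8
```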

Note also that you need a GPU to run mixed-8bit models, as the kernels have been compiled for GPUs only. Make sure you have enough GPU RAM to store a quarter of the model (or half of it, if the model weights are natively in half precision) before using this feature. For example, a 176B-parameter model needs roughly 176GB of GPU RAM in int8, compared to about 352GB in fp16.

Below are some notes to help you use this module, or follow the demo on Google Colab: [![Open In Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing)

#### Requirements

- Make sure you run this on an NVIDIA GPU that supports 8-bit tensor cores (Turing or Ampere GPUs, e.g. T4, RTX 20 series, RTX 30 series, A40-A100). Older generations of NVIDIA GPUs do not support 8-bit tensor cores; a quick check is sketched after this list.
- Install the correct version of `bitsandbytes` by running:
`pip install -i https://test.pypi.org/simple/ bitsandbytes`
- Install `accelerate`:
`pip install accelerate`
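
As a quick sanity check (a sketch, not part of the library), you can verify with PyTorch that your GPU's compute capability is at least 7.5 (Turing), which is required for 8-bit tensor cores:
```py
import torch

if not torch.cuda.is_available():
    raise RuntimeError("Mixed-int8 models require a CUDA GPU.")

# 8-bit tensor cores require compute capability >= 7.5 (Turing, Ampere, ...)
major, minor = torch.cuda.get_device_capability()
if (major, minor) < (7, 5):
    print(f"Compute capability {major}.{minor}: 8-bit tensor cores are not supported on this GPU.")
else:
    print(f"Compute capability {major}.{minor}: this GPU should support mixed-int8 inference.")
```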

#### Running mixed-int8 models

After carefully installing the required libraries, you can load a mixed 8-bit model as follows:
```py
from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
```
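
To see what you gain, you can compare memory footprints with the `get_memory_footprint` method added in this PR (it returns the size in bytes). A rough sketch; loading both variants at once needs enough combined memory:
```py
import torch
from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"
model_fp16 = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)

print(f"fp16 footprint: {model_fp16.get_memory_footprint() / 1e9:.2f} GB")
print(f"int8 footprint: {model_8bit.get_memory_footprint() / 1e9:.2f} GB")
```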
The implementation supports multi-GPU setups thanks to `accelerate` as the backend. If you want to control how much GPU memory to allocate on each GPU, use the `max_memory` argument. For example, to allocate `1GB` on GPU 0 and `2GB` on GPU 1, pass `max_memory={0: "1GB", 1: "2GB"}`:
```py
from transformers import AutoModelForCausalLM

max_memory_mapping = {0: "1GB", 1: "2GB"}
model_name = "bigscience/bloom-3b"
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping
)
```
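
You can also tune the outlier threshold through the `int8_threshold` argument introduced in this PR (it defaults to 6.0); a lower value may help for smaller or less stable models. A quick sketch:
```py
from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", load_in_8bit=True, int8_threshold=4.0
)
```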


## ModuleUtilsMixin
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -462,6 +462,7 @@
"is_vision_available",
"logging",
],
"utils.bitsandbytes": [],
}

# sentencepiece-backed objects
102 changes: 95 additions & 7 deletions src/transformers/modeling_utils.py
@@ -61,6 +61,7 @@
copy_func,
has_file,
is_accelerate_available,
is_bitsandbytes_available,
is_offline_mode,
logging,
replace_return_docstrings,
@@ -83,6 +84,9 @@
else:
get_balanced_memory = None

if is_bitsandbytes_available():
from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device

logger = logging.get_logger(__name__)


@@ -500,6 +504,7 @@ def _load_state_dict_into_meta_model(
state_dict_folder=None,
state_dict_index=None,
dtype=None,
load_in_8bit=False,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -560,13 +565,14 @@
# TODO: group all errors and raise at the end.
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]

if param_device == "disk":
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
else:
elif not load_in_8bit:
set_module_tensor_to_device(model, param_name, param_device, value=param)
else:
set_module_8bit_tensor_to_device(model, param_name, param_device, value=param)

return error_msgs, offload_index, state_dict_index

@@ -1577,6 +1583,24 @@ def save_pretrained(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
)

def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired by the
PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

Arguments:
return_buffers (`bool`, *optional*, defaults to `True`):
Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
are tensors that do not require gradients and are not registered as parameters, e.g. the mean and std in
batch norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
"""
mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
if return_buffers:
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
mem = mem + mem_bufs
return mem

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
r"""
@@ -1706,6 +1730,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
`True` when there is some disk offload.
load_in_8bit (`bool`, *optional*, defaults to `False`):
If `True`, will convert the loaded model into mixed-8bit quantized model. To use this feature please
install `bitsandbytes` compiled with your CUDA version by running `pip install -i
https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
Also make sure that you have enough GPU RAM to store half of the model size, since the 8-bit modules are
not compiled or adapted for CPUs.
int8_threshold (`float`, *optional*, defaults to 6):
Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
described in the `LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
states value that is above this threshold will be considered an outlier and the operation on those
values will be done in fp16. Values are usually normally distributed, that is, most values are in the
range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently
distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8
quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
(small models, fine-tuning).
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
@@ -1795,15 +1835,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
int8_threshold = kwargs.pop("int8_threshold", 6.0)
subfolder = kwargs.pop("subfolder", "")

if trust_remote_code is True:
logger.warning(
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
" ignored."
)

if device_map is not None:
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
@@ -1823,6 +1864,28 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
)

if load_in_8bit:
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
"Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of"
" bitsandbytes: `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
" `pip install bitsandbytes`"
)
if torch_dtype == "auto" or torch_dtype != torch.float16:
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
torch_dtype = torch.float16
logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16")
if device_map is None:
raise ValueError(
"A device map needs to be passed to convert models into mixed-int8 format. Please run"
" `.from_pretrained` with `device_map='auto'`"
)
if from_tf or from_flax:
raise ValueError(
"Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make"
" sure the weights are in PyTorch format."
)

from_pt = not (from_tf | from_flax)

user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -2062,12 +2125,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
elif low_cpu_mem_usage:
elif load_in_8bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())

with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)

if load_in_8bit:
logger.info("Detected 8-bit loading: activating 8-bit loading for this model")

# We never convert lm_head or any last modules for numerical stability reasons
modules_to_not_convert = get_key_to_not_convert(model)
model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)

if isinstance(device_map, str):
if model._no_split_modules is None:
raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
@@ -2090,9 +2160,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(
model, no_split_module_classes=no_split_modules, dtype=torch_dtype, max_memory=max_memory
model,
no_split_module_classes=no_split_modules,
dtype=torch_dtype if not load_in_8bit else torch.int8,
max_memory=max_memory,
)

if load_in_8bit:
# The LM head can stay on disk / CPU
device_map_without_lm_head = {
key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
}
if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
del device_map_without_lm_head

if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
@@ -2144,6 +2226,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
load_in_8bit=load_in_8bit,
)

# make sure token embedding weights are still tied if needed
@@ -2184,6 +2267,7 @@ def _load_pretrained_model(
offload_folder=None,
offload_state_dict=None,
dtype=None,
load_in_8bit=False,
):
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -2249,7 +2333,10 @@ def _fix_key(key):
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]
if param.device == torch.device("meta"):
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
if not load_in_8bit:
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
else:
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))

# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
@@ -2358,6 +2445,7 @@ def _find_mismatched_keys(
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
load_in_8bit=load_in_8bit,
)
error_msgs += new_error_msgs
else: