bitsandbytes - Linear8bitLt integration into transformers models #17901

Merged
91 commits merged on Aug 10, 2022

Changes from 49 commits

Commits
ed1dd12
first commit
younesbelkada Jun 24, 2022
b9d0da6
correct replace function
younesbelkada Jun 24, 2022
dd9a464
add final changes
younesbelkada Jun 27, 2022
35e1534
clean up a bit
younesbelkada Jun 27, 2022
d01822b
add bitsandbytes dependencies
younesbelkada Jul 7, 2022
839c9cd
working version
younesbelkada Jul 12, 2022
93a5ac6
Merge branch 'main' of https://github.com/huggingface/transformers in…
younesbelkada Jul 12, 2022
42a6845
small fix
younesbelkada Jul 12, 2022
97f64f8
small fix
younesbelkada Jul 12, 2022
a1fe7fc
fix import issues
younesbelkada Jul 12, 2022
05739e3
Apply suggestions from code review
younesbelkada Jul 12, 2022
7816ef9
refactor a bit
younesbelkada Jul 12, 2022
1155549
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Jul 12, 2022
b222b9a
reformat docstring
younesbelkada Jul 12, 2022
32f48cd
Update src/transformers/__init__.py
younesbelkada Jul 12, 2022
e116e21
revert bad formatting
younesbelkada Jul 12, 2022
39c46a0
change to bitsandbytes
younesbelkada Jul 12, 2022
b92c25c
refactor a bit
younesbelkada Jul 12, 2022
3779f5d
more refactoring
younesbelkada Jul 13, 2022
b41c250
small hack to make it work
younesbelkada Jul 19, 2022
311dcbf
Merge branch 'main' into integration-8bit
younesbelkada Jul 20, 2022
c91a58e
Update src/transformers/modeling_utils.py
younesbelkada Jul 20, 2022
db16cf8
Update src/transformers/modeling_utils.py
younesbelkada Jul 20, 2022
be6ce29
revmoe the small hack
younesbelkada Jul 21, 2022
848d64d
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Jul 21, 2022
15a81e0
modify utils file
younesbelkada Jul 21, 2022
514758d
make style + refactor a bit
younesbelkada Jul 22, 2022
a09e055
create correctly device map
younesbelkada Jul 25, 2022
387aa1e
add correct dtype for device map creation
younesbelkada Jul 25, 2022
7199292
Apply suggestions from code review
younesbelkada Jul 27, 2022
a4c19c1
apply suggestions
younesbelkada Jul 27, 2022
a5cd157
add docstring
younesbelkada Jul 27, 2022
9bd326b
add docstring
younesbelkada Jul 27, 2022
6e4fee6
- added more documentation
younesbelkada Jul 28, 2022
9e121b7
few modifs
younesbelkada Jul 28, 2022
a2ac688
added colab link
younesbelkada Jul 28, 2022
ac370b9
add test architecture + docstring a bit
younesbelkada Jul 28, 2022
27e9486
refactor a bit testing class
younesbelkada Jul 28, 2022
a0db982
make style + refactor a bit
younesbelkada Jul 28, 2022
21bd590
enhance checks
younesbelkada Jul 29, 2022
f1fcf77
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Jul 29, 2022
9b81c67
clean up a bit
younesbelkada Jul 29, 2022
659f427
male style
younesbelkada Jul 29, 2022
147683e
add more details on doc
younesbelkada Jul 29, 2022
c2e1918
add more tests
younesbelkada Aug 1, 2022
f6eb945
replace by "or"
younesbelkada Aug 1, 2022
ceff43e
refactor a bit testing code + add readme
younesbelkada Aug 1, 2022
55cec55
make style
younesbelkada Aug 1, 2022
67bf4fb
fix import issue
younesbelkada Aug 1, 2022
56e9147
Update src/transformers/modeling_utils.py
younesbelkada Aug 1, 2022
5a03a86
add few comments
younesbelkada Aug 1, 2022
7abb914
add more doctring + make style
younesbelkada Aug 1, 2022
961e57e
more docstring
younesbelkada Aug 1, 2022
1326a42
raise error when loaded in 8bit
younesbelkada Aug 1, 2022
9a0051b
make style
younesbelkada Aug 1, 2022
59f9a5a
add warning if loaded on CPU
younesbelkada Aug 2, 2022
418690b
add small sanity check
younesbelkada Aug 2, 2022
6d93424
fix small comment
younesbelkada Aug 2, 2022
d428d8d
add bitsandbytes on dockerfile
younesbelkada Aug 2, 2022
0324f4b
Improve documentation
younesbelkada Aug 2, 2022
af229c4
Merge branch 'main' into integration-8bit
younesbelkada Aug 2, 2022
70ad8cb
add few comments
younesbelkada Aug 3, 2022
1eedb90
slow tests pass on the VM but not on the CI VM
younesbelkada Aug 3, 2022
163ef77
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Aug 3, 2022
8100d03
Fix merge conflict
younesbelkada Aug 3, 2022
c9589f6
make style
younesbelkada Aug 3, 2022
eb9a26d
another test should pass on a multi gpu setup
younesbelkada Aug 4, 2022
838e2a9
fix bad import in testing file
younesbelkada Aug 4, 2022
c4a1e9b
Fix slow tests
younesbelkada Aug 4, 2022
31fce94
odify dockerfile
younesbelkada Aug 4, 2022
3e4a2a4
Update docs/source/en/main_classes/model.mdx
younesbelkada Aug 5, 2022
fdf37b3
Update Dockerfile
younesbelkada Aug 5, 2022
53e0b2e
Update model.mdx
younesbelkada Aug 7, 2022
91364c4
Update Dockerfile
younesbelkada Aug 8, 2022
5ea8976
Apply suggestions from code review
younesbelkada Aug 8, 2022
8b72d08
few modifications
younesbelkada Aug 8, 2022
32a4863
Merge branch 'main' into integration-8bit
younesbelkada Aug 8, 2022
c6c139f
change test value
younesbelkada Aug 8, 2022
3d3224f
Merge remote-tracking branch 'upstream/main' into integration-8bit
younesbelkada Aug 9, 2022
0d5bc2b
modify installation guidelines
younesbelkada Aug 9, 2022
bc8f332
Apply suggestions from code review
younesbelkada Aug 9, 2022
5adcadc
Apply suggestions from code review
younesbelkada Aug 9, 2022
bb00e7a
Apply suggestions from code review
younesbelkada Aug 9, 2022
630b4f7
replace `n`by `name`
younesbelkada Aug 9, 2022
a925641
merge `load_in_8bit` and `low_cpu_mem_usage`
younesbelkada Aug 9, 2022
279b8c4
first try - keep the lm head in full precision
younesbelkada Aug 9, 2022
e49a2ea
better check
younesbelkada Aug 9, 2022
5718d78
added more tests
younesbelkada Aug 9, 2022
a40667a
Update src/transformers/utils/bitsandbytes.py
younesbelkada Aug 9, 2022
61faa28
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Aug 9, 2022
36fbbd2
improve documentation
younesbelkada Aug 10, 2022
27 changes: 26 additions & 1 deletion docs/source/en/main_classes/model.mdx
@@ -105,7 +105,7 @@ You can also write your own device map following the same format (a dictionary l
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
```

Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`).
Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`) or use direct quantization techniques as described below.

### Model Instantiation dtype

@@ -133,6 +133,31 @@ model = AutoModel.from_config(config)

Due to Pytorch design, this functionality is only available for floating dtypes.

### Mixed int8 quantization 🪶

From the paper `GPT3.int8(): 8-bit Matrix Multiplication for Transformers at Scale`, we support Hugging Face 🤗 integration for all models in the Hub with a few lines of code.
The method works for models trained in half-precision (either `float16` or `bfloat16`) and aims to reduce the size of `nn.Linear` layers by a factor of 2 without degrading quality too much, by handling the outliers in half-precision; the method is therefore termed mixed 8-bit quantization.
This technique is useful and works well for billion-scale models (>1B parameters), so we advise you to use it only for models of that scale. It has been tested on models from 2 billion to 176 billion parameters and supports PyTorch models only.
Note also that you need a GPU to run mixed-8bit models, as the kernels have been compiled for GPUs only. Make sure that you have enough GPU RAM to store a quarter of the model (or half, if your model is natively in half-precision) before using this feature.

Below are some notes to help you use this module, or follow the demo on Google Colab: [![Open In Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing)

#### Requirements

- Make sure you run this on an NVIDIA T4 or A100 GPU that supports 8-bit tensor cores. Note that previous generations of NVIDIA GPUs do not support 8-bit tensor cores.
- Install the correct version of `bitsandbytes` by running:
`pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` with `XXX` corresponding to your `CUDA` version. E.g., if you have `CUDA` 11.3 installed on your machine, run `pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113` (see the environment check sketch after this list).
Contributor comment: This also needs to be updated to just `pip install bitsandbytes`.

- Install `accelerate`:
`pip install accelerate`
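
As a quick sanity check of the requirements above, one might run something like the following sketch before loading any model. The compute-capability cutoff of 7.5 (T4/Turing or newer) is an assumption based on the GPU requirement above, not something stated by the library:
```
# Rough environment check before using `load_in_8bit=True` (a sketch, not part of the library).
import torch
import bitsandbytes  # fails here if the wrong wheel was installed for your CUDA version
import accelerate

major, minor = torch.cuda.get_device_capability()
# Assumption: 8-bit tensor cores require compute capability >= 7.5 (T4/Turing or newer).
assert (major, minor) >= (7, 5), "This GPU does not support 8-bit tensor cores"
print(f"OK: compute capability {major}.{minor}, bitsandbytes and accelerate importable")
```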

#### Running mixed-int8 models

After carefully installing the required libraries, the way to load your mixed 8-bit model is as follows:
```
model_name = "bigscience/bloom-2b5"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
```
The implementation supports multi-GPU setups thanks to `accelerate` as the backend.
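
For instance, a minimal generation sketch with the quantized model loaded above (this assumes the `bigscience/bloom-2b5` tokenizer and that the first shard of the model sits on GPU 0):
```
# Sketch: run generation with the mixed-8bit model.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bloom-2b5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(0)
outputs = model_8bit.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```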


## ModuleUtilsMixin
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -454,6 +454,7 @@
"is_vision_available",
"logging",
],
"utils.bitsandbytes": [],
}

# sentencepiece-backed objects
86 changes: 80 additions & 6 deletions src/transformers/modeling_utils.py
@@ -65,6 +65,7 @@
has_file,
hf_bucket_url,
is_accelerate_available,
is_bitsandbytes_available,
is_offline_mode,
is_remote_url,
logging,
@@ -82,6 +83,9 @@
set_module_tensor_to_device,
)

if is_bitsandbytes_available():
from .utils.bitsandbytes import replace_8bit_linear, set_module_8bit_tensor_to_device

logger = logging.get_logger(__name__)


@@ -490,6 +494,7 @@ def _load_state_dict_into_meta_model(
state_dict_folder=None,
state_dict_index=None,
dtype=None,
load_in_8bit=False,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -548,13 +553,14 @@
# TODO: group all errors and raise at the end.
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]

if param_device == "disk":
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
else:
elif not load_in_8bit:
set_module_tensor_to_device(model, param_name, param_device, value=param)
else:
set_module_8bit_tensor_to_device(model, param_name, param_device, value=param)

return error_msgs, offload_index, state_dict_index

@@ -1564,6 +1570,24 @@ def save_pretrained(
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Model pushed to the hub in this commit: {url}")

def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
Useful to benchmark the memory footprint of the current model and to design tests. Solution inspired by this
PyTorch discussion: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

Arguments:
return_buffers (`bool`, *optional*):
Whether to include the size of the buffer tensors in the computation of the memory footprint. Buffers
are tensors that do not require gradients and are not registered as parameters (e.g. the mean and std in
batch norm layers). Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
"""
mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
if return_buffers:
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
mem = mem + mem_bufs
return mem
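
As a rough illustration of how this method can be used (a sketch; `bigscience/bloom-2b5` is just an example checkpoint and both variants must fit in GPU memory):
```
# Sketch: compare the footprint of an fp16 model vs. its mixed-int8 counterpart.
import torch
from transformers import AutoModelForCausalLM

name = "bigscience/bloom-2b5"
model_fp16 = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16, device_map="auto")
model_int8 = AutoModelForCausalLM.from_pretrained(name, device_map="auto", load_in_8bit=True)

# The int8 model should be noticeably smaller (linear weights stored in 1 byte instead of 2).
print(model_fp16.get_memory_footprint(), model_int8.get_memory_footprint())
```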

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
r"""
@@ -1691,6 +1715,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
`True` when there is some disk offload.
load_in_8bit (`bool`, *optional*):
If `True`, will convert the loaded model into a mixed-8bit quantized model. To use this feature please
install `bitsandbytes` compiled for your CUDA version by running `pip install -i
https://test.pypi.org/simple/ bitsandbytes-cudaXXX`, where XXX is your CUDA version (e.g. 11.6 = 116).
Also make sure that you have enough GPU RAM to store half of the model size, since the 8-bit modules are
not compiled and adapted for CPUs.
int8_threshold (`float`, *optional*):
Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection, as
described in the `GPT3.int8(): 8-bit Matrix Multiplication for Transformers at Scale` paper (#TODO provide
link to paper). Set to `6.0` by default, as described in the paper.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
@@ -1779,9 +1813,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
int8_threshold = kwargs.pop("int8_threshold", 6.0)
subfolder = kwargs.pop("subfolder", "")

if device_map is not None:
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
@@ -1801,6 +1836,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
)

if load_in_8bit:
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
"Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of"
" bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` with `XXX`"
" correspondingto your CUDA version (e.g. `pip install -i https://test.pypi.org/simple/"
" bitsandbytes-cuda113` for CUDA 11.3)"
)
if torch_dtype == "auto" or torch_dtype != torch.float16:
torch_dtype = (
torch.float16
) # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16")
if device_map is None:
raise ValueError(
"A device map needs to be passed to run convert models into mixed-int8 format. Please run"
"`.from_pretrained` with `device_map='auto'`"
)
if from_tf or from_flax:
raise ValueError(
"Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make"
" surethe weights are in PyTorch format."
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
)

from_pt = not (from_tf | from_flax)

user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -2083,20 +2142,29 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
elif load_in_8bit:
init_contexts = [init_empty_weights()] # Force enable init empty weights
logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
elif low_cpu_mem_usage:
init_contexts.append(init_empty_weights())

with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)

if load_in_8bit:
model = replace_8bit_linear(model, threshold=int8_threshold)

if device_map == "auto":
if model._no_split_modules is None:
raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.")
no_split_modules = model._no_split_modules
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(
model, no_split_module_classes=no_split_modules, dtype=torch_dtype, max_memory=max_memory
model,
no_split_module_classes=no_split_modules,
dtype=torch_dtype if not load_in_8bit else torch.int8,
max_memory=max_memory,
)
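
For reference, a standalone sketch of what the `dtype=torch.int8` sizing means for `infer_auto_device_map` (an illustration only; the config name is an example):
```
# Sketch: size a device map as if every weight took 1 byte (int8) rather than the checkpoint dtype.
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("bigscience/bloom-2b5")  # example config
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(config)

device_map = infer_auto_device_map(
    empty_model,
    no_split_module_classes=empty_model._no_split_modules,
    dtype=torch.int8,  # weights counted as 1 byte each when balancing across devices
)
print(device_map)
```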

if from_tf:
@@ -2150,6 +2218,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
load_in_8bit=load_in_8bit,
)

# make sure token embedding weights are still tied if needed
@@ -2190,6 +2259,7 @@ def _load_pretrained_model(
offload_folder=None,
offload_state_dict=None,
dtype=None,
load_in_8bit=False,
):
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -2255,7 +2325,10 @@ def _fix_key(key):
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]
if param.device == torch.device("meta"):
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
if not load_in_8bit:
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
else:
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))

# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
@@ -2364,6 +2437,7 @@ def _find_mismatched_keys(
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
load_in_8bit=load_in_8bit,
)
error_msgs += new_error_msgs
else:
112 changes: 112 additions & 0 deletions src/transformers/utils/bitsandbytes.py
@@ -0,0 +1,112 @@
from transformers.utils import is_accelerate_available, is_bitsandbytes_available


if is_bitsandbytes_available():
import torch
import torch.nn as nn

import bitsandbytes as bnb

if is_accelerate_available():
from accelerate import init_empty_weights


def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
"""
A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
function is adapted from the `set_module_tensor_to_device` function of `accelerate`, extended to support the
`Int8Params` class from `bitsandbytes`.

Args:
module (`torch.nn.Module`):
The module in which the tensor we want to move lives.
tensor_name (`str`):
The full name of the parameter/buffer.
device (`int`, `str` or `torch.device`):
The device on which to set the tensor.
value (`torch.Tensor`, *optional*):
The value of the tensor (useful when going from the meta device to any other device).
"""
# Recurse if needed
if "." in tensor_name:
splits = tensor_name.split(".")
for split in splits[:-1]:
new_module = getattr(module, split)
if new_module is None:
raise ValueError(f"{module} has no attribute {split}.")
module = new_module
tensor_name = splits[-1]

if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
is_buffer = tensor_name in module._buffers
old_value = getattr(module, tensor_name)
Member comment on lines +31 to +44: Maybe I am wrong, but I think you can make it simpler:

```
if tensor_name in [t[0] for t in module.named_parameters()]:
    old_value = module.get_parameter(tensor_name)
elif tensor_name in [t[0] for t in module.named_buffers()]:
    old_value = module.get_buffer(tensor_name)
else:
    raise ValueError ...
```


if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")

if is_buffer:
has_fp16_weights = None
else:
has_fp16_weights = getattr(module._parameters[tensor_name], "has_fp16_weights", None)

if has_fp16_weights is not None:
param = module._parameters[tensor_name]
if param.device.type != "cuda":
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to("cpu")
else:
new_value = torch.tensor(value, device="cpu")
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device)
module._parameters[tensor_name] = new_value
else:
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to(device)
else:
new_value = torch.tensor(value, device=device)

if is_buffer:
module._buffers[tensor_name] = new_value
else:
new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value
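
A tiny sketch of what this helper does for an 8-bit linear layer created on the meta device (it assumes a CUDA device and a working `bitsandbytes` install; the shapes and the dummy weight are arbitrary):
```
# Sketch: materialize a meta-device weight of a Linear8bitLt module directly on the GPU.
import torch
import bitsandbytes as bnb
from accelerate import init_empty_weights
from transformers.utils.bitsandbytes import set_module_8bit_tensor_to_device

with init_empty_weights():
    layer = bnb.nn.Linear8bitLt(4, 4, bias=False, has_fp16_weights=False)

fp16_weight = torch.randn(4, 4, dtype=torch.float16)  # stand-in for a checkpoint tensor
set_module_8bit_tensor_to_device(layer, "weight", 0, value=fp16_weight)
print(type(layer.weight), layer.weight.device)  # Int8Params, quantized on cuda:0
```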


def replace_8bit_linear(model, threshold=6.0):
"""
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bitLt` modules from the `bitsandbytes`
library. This enables running your models using mixed int8 precision, as described by the paper `GPT3.int8():
8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled for the CUDA
version of your hardware is installed before running this function: `pip install -i https://test.pypi.org/simple/
bitsandbytes-cudaXXX`, with `XXX` being your CUDA version (e.g., 11.6 = 116).

The function is run recursively and replaces all `torch.nn.Linear` modules except for `lm_head`, which should
be kept as a `torch.nn.Linear` module. The replacement is done under the `init_empty_weights` context manager, so no
CPU/GPU memory is required to run this function.
Contributor comment: The new release now works with just a pip install, and it is no longer needed to compile bitsandbytes. For the full release (done in a couple of days) the library will also reside directly on PyPI, so best to change this already to: `pip install bitsandbytes`.

It might be useful to add one line about how it works before linking to the GPT3.int8() paper:

Int8 mixed-precision matrix decomposition works by separating a matrix multiplication into two streams: (1) a systematic feature-outlier stream matrix-multiplied in fp16 (0.01%), and (2) a regular stream of int8 matrix multiplications (99.9%). With this method, int8 inference with no predictive degradation is possible for very large models (>=176B parameters).


Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
threshold (`float`, *optional*):
`int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
`6.0` by default, as described in the paper.
"""
for n, module in model.named_children():
Member comment: I think this could be better written as a torch.fx transformation, but this way works for more models since not all models are currently traceable for torch.fx.
if len(list(module.children())) > 0:
replace_8bit_linear(module, threshold)

if isinstance(module, nn.Linear) and n != "lm_head":
with init_empty_weights():
model._modules[n] = bnb.nn.Linear8bitLt(
module.in_features,
module.out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=threshold,
)
return model
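
And a short sketch of the replacement in isolation (the checkpoint name is only an example; any PyTorch causal LM works):
```
# Sketch: swap every nn.Linear except `lm_head` for an (empty) 8-bit counterpart.
from transformers import AutoModelForCausalLM
from transformers.utils.bitsandbytes import replace_8bit_linear

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")  # example checkpoint
model = replace_8bit_linear(model, threshold=6.0)

# The swapped layers are Linear8bitLt modules created under `init_empty_weights`,
# so their weights still live on the meta device and must be filled with
# `set_module_8bit_tensor_to_device` before running inference.
print({n for n, m in model.named_modules() if m.__class__.__name__ == "Linear8bitLt"})
```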
13 changes: 13 additions & 0 deletions tests/mixed_int8/README.md
@@ -0,0 +1,13 @@
# Testing mixed int8 quantization

## Hardware requirements

I am using a setup of 2 NVIDIA Tesla T4 15GB GPUs.

## Virtual envs

```
conda create --name int8-testing python==3.8
git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit
pip install -e ".[dev]"
pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda114
```
Contributor comment: I just changed the way bitsandbytes is installed. This line no longer uses the cuda suffix: `pip install -i https://test.pypi.org/simple/ bitsandbytes`
```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1```
Empty file added tests/mixed_int8/__init__.py