From 7444876962d3848710c0887996f1d36ba7a88e49 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 10 Aug 2022 09:13:36 +0200 Subject: [PATCH] `bitsandbytes` - `Linear8bitLt` integration into `transformers` models (#17901) * first commit * correct replace function * add final changes - works like charm! - cannot implement tests yet - tested * clean up a bit * add bitsandbytes dependencies * working version - added import function - added bitsandbytes utils file * small fix * small fix - fix import issue * fix import issues * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * refactor a bit - move bitsandbytes utils to utils - change comments on functions * reformat docstring - reformat docstring on init_empty_weights_8bit * Update src/transformers/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * revert bad formatting * change to bitsandbytes * refactor a bit - remove init8bit since it is useless * more refactoring - fixed init empty weights issue - added threshold param * small hack to make it work * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py * revmoe the small hack * modify utils file * make style + refactor a bit * create correctly device map * add correct dtype for device map creation * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * apply suggestions - remove with torch.grad - do not rely on Python bool magic! * add docstring - add docstring for new kwargs * add docstring - comment `replace_8bit_linear` function - fix weird formatting * - added more documentation - added new utility function for memory footprint tracking - colab demo to add * few modifs - typo doc - force cast into float16 when load_in_8bit is enabled * added colab link * add test architecture + docstring a bit * refactor a bit testing class * make style + refactor a bit * enhance checks - add more checks - start writing saving test * clean up a bit * male style * add more details on doc * add more tests - still needs to fix 2 tests * replace by "or" - could not fix it from GitHub GUI Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * refactor a bit testing code + add readme * make style * fix import issue * Update src/transformers/modeling_utils.py Co-authored-by: Michael Benayoun * add few comments * add more doctring + make style * more docstring * raise error when loaded in 8bit * make style * add warning if loaded on CPU * add small sanity check * fix small comment * add bitsandbytes on dockerfile * Improve documentation - improve documentation from comments * add few comments * slow tests pass on the VM but not on the CI VM * Fix merge conflict * make style * another test should pass on a multi gpu setup * fix bad import in testing file * Fix slow tests - remove dummy batches - no more CUDA illegal memory errors * odify dockerfile * Update docs/source/en/main_classes/model.mdx * Update Dockerfile * Update model.mdx * Update Dockerfile * Apply suggestions from code review * few modifications - lm head can stay on disk/cpu - change model name so that test pass * change test value - change test value to the correct output - torch bmm changed to baddmm in bloom modeling when merging * modify installation guidelines * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply 
suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * replace `n`by `name` * merge `load_in_8bit` and `low_cpu_mem_usage` * first try - keep the lm head in full precision * better check - check the attribute `base_model_prefix` instead of computing the number of parameters * added more tests * Update src/transformers/utils/bitsandbytes.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Merge branch 'integration-8bit' of https://github.com/younesbelkada/transformers into integration-8bit * improve documentation - fix typos for installation - change title in the documentation Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Michael Benayoun --- docker/transformers-all-latest-gpu/Dockerfile | 3 + docs/source/en/main_classes/model.mdx | 41 +++- src/transformers/__init__.py | 1 + src/transformers/modeling_utils.py | 102 ++++++++- src/transformers/utils/bitsandbytes.py | 142 ++++++++++++ tests/mixed_int8/README.md | 37 +++ tests/mixed_int8/__init__.py | 0 tests/mixed_int8/test_mixed_int8.py | 215 ++++++++++++++++++ utils/tests_fetcher.py | 1 + 9 files changed, 534 insertions(+), 8 deletions(-) create mode 100644 src/transformers/utils/bitsandbytes.py create mode 100644 tests/mixed_int8/README.md create mode 100644 tests/mixed_int8/__init__.py create mode 100644 tests/mixed_int8/test_mixed_int8.py diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index e97a91f4246fb4..b0a55ba8be946b 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -45,6 +45,9 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0" RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate +# Add bitsandbytes for mixed int8 testing +RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes==0.31.5 + RUN python3 -m pip install --no-cache-dir decord # When installing in editable mode, `transformers` is not recognized as a package. diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index c59af2d2214814..10f81e55d74506 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -105,7 +105,7 @@ You can also write your own device map following the same format (a dictionary l device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1} ``` -Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`). +Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`) or use direct quantization techniques as described below. ### Model Instantiation dtype @@ -133,6 +133,45 @@ model = AutoModel.from_config(config) Due to Pytorch design, this functionality is only available for floating dtypes. +### `bitsandbytes` integration for Int8 mixed-precision matrix decomposition + +From the paper `LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale`, we support Hugging Face 🤗 integration for all models in the Hub with a few lines of code. +Models trained in half precision (i.e. `float16` or `bfloat16`) as well as in full precision are supported.
This method reduces the size of `nn.Linear` layers by a factor of 2 (for models trained in half precision) or 4 (for models trained in full precision), with little loss of quality, by handling the outliers in half precision. +This technique works well for billion-scale models (>1B parameters), so we advise you to use it only for models of that scale. It has been tested on models from 2 billion to 176 billion parameters and supports PyTorch models only. + +![HFxbitsandbytes.png](https://s3.amazonaws.com/moonup/production/uploads/1659861207959-62441d1d9fdefb55a0b7d12c.png) + +Int8 mixed-precision matrix decomposition works by separating a matrix multiplication into two streams: (1) a systematic feature outlier stream matrix multiplied in fp16 (0.01% of values), and (2) a regular stream of int8 matrix multiplication (99.9% of values). With this method, int8 inference with no predictive degradation is possible for very large models (>=176B parameters). +Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but there are some exceptional systematic outliers that are distributed very differently for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of magnitude ~5, but beyond that there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning). + +Note also that you need a GPU to run mixed-8bit models, as the kernels have been compiled for GPUs only. Make sure you have enough GPU RAM to store a quarter (or half, if your model is natively in half precision) of the model before using this feature. + +Below are some notes to help you use this module, or follow the demo on Google Colab: [![Open In Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing) + +#### Requirements + +- Make sure you run on an NVIDIA GPU that supports 8-bit tensor cores (Turing or Ampere GPUs, e.g. T4, RTX 20 and RTX 30 series, A40-A100). Note that previous generations of NVIDIA GPUs do not support 8-bit tensor cores. +- Install the correct version of `bitsandbytes` by running: +`pip install -i https://test.pypi.org/simple/ bitsandbytes` +- Install `accelerate`: +`pip install accelerate` + +#### Running mixed-int8 models + +After installing the required libraries, load your mixed 8-bit model as follows: +```py +model_name = "bigscience/bloom-2b5" +model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True) +```
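To get a rough idea of the memory savings, you can compare the footprint of the regular checkpoint with its 8-bit counterpart using the `get_memory_footprint` helper added in this integration. This is only a minimal sketch: it assumes you have enough GPU memory to hold both copies, and the exact ratio depends on the model architecture.

```py
from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"

# Reference model in its native precision and its mixed-int8 counterpart
model_native = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)

# `get_memory_footprint` returns the size of the parameters (and buffers) in bytes
print(model_native.get_memory_footprint() / model_8bit.get_memory_footprint())
```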
+The implementation supports multi-GPU setups thanks to `accelerate` as the backend. If you want to control how much GPU memory to allocate to each GPU, you can use the `max_memory` argument as follows: +(For example, to allocate `1GB` to GPU 0 and `2GB` to GPU 1, pass `max_memory={0: "1GB", 1: "2GB"}`.) +```py +max_memory_mapping = {0: "1GB", 1: "2GB"} +model_name = "bigscience/bloom-3b" +model_8bit = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping +) +``` ## ModuleUtilsMixin diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0a97952b18b85e..be2be2727f0146 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -462,6 +462,7 @@ "is_vision_available", "logging", ], + "utils.bitsandbytes": [], } # sentencepiece-backed objects diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 78c012ec095fdb..1d895baecfedac 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -61,6 +61,7 @@ copy_func, has_file, is_accelerate_available, + is_bitsandbytes_available, is_offline_mode, logging, replace_return_docstrings, @@ -83,6 +84,9 @@ else: get_balanced_memory = None +if is_bitsandbytes_available(): + from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device + logger = logging.get_logger(__name__) @@ -501,6 +505,7 @@ def _load_state_dict_into_meta_model( state_dict_folder=None, state_dict_index=None, dtype=None, + load_in_8bit=False, ): """ This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its @@ -561,13 +566,14 @@ def _load_state_dict_into_meta_model( # TODO: group all errors and raise at the end. raise ValueError(f"{param_name} doesn't have any device set.") param_device = device_map[module_name] - if param_device == "disk": offload_index = offload_weight(param, param_name, offload_folder, offload_index) elif param_device == "cpu" and state_dict_index is not None: state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) - else: + elif not load_in_8bit: set_module_tensor_to_device(model, param_name, param_device, value=param) + else: + set_module_8bit_tensor_to_device(model, param_name, param_device, value=param) return error_msgs, offload_index, state_dict_index @@ -1578,6 +1584,24 @@ def save_pretrained( save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token ) + def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of the current model, in bytes. Useful to benchmark the memory footprint of the + model and to design some tests. Solution inspired by this PyTorch discussion: + https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*, defaults to `True`): + Whether to include the size of the buffer tensors in the computation of the memory footprint. Buffers + are tensors that do not require gradients and are not registered as parameters, e.g. the mean and std in batch + norm layers.
Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): r""" @@ -1707,6 +1731,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` when there is some disk offload. + load_in_8bit (`bool`, *optional*, defaults to `False`): + If `True`, will convert the loaded model into a mixed-8bit quantized model. To use this feature please + install `bitsandbytes` compiled with your CUDA version by running `pip install -i + https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116). + Also make sure that you have enough GPU RAM to store half of the model size, since the 8bit modules are + not compiled or adapted for CPUs. + int8_threshold (`float`, *optional*, defaults to 6): + Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as + described in the `LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden + states value that is above this threshold will be considered an outlier and the operation on those + values will be done in fp16. Values are usually normally distributed, that is, most values are in the + range [-3.5, 3.5], but there are some exceptional systematic outliers that are distributed very + differently for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 + quantization works well for values of magnitude ~5, but beyond that, there is a significant performance + penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models + (small models, fine-tuning). subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. @@ -1796,7 +1836,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) - offload_state_dict = kwargs.pop("offload_state_dict", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + load_in_8bit = kwargs.pop("load_in_8bit", False) + int8_threshold = kwargs.pop("int8_threshold", 6.0) subfolder = kwargs.pop("subfolder", "") if trust_remote_code is True: logger.warning( "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" " ignored."
) - if device_map is not None: if low_cpu_mem_usage is None: low_cpu_mem_usage = True @@ -1824,6 +1865,28 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`" ) + if load_in_8bit: + if not (is_accelerate_available() and is_bitsandbytes_available()): + raise ImportError( + "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of" + " bitsandbytes: `pip install -i https://test.pypi.org/simple/ bitsandbytes` or" + " `pip install bitsandbytes`" + ) + if torch_dtype == "auto" or torch_dtype != torch.float16: + # We force the `dtype` to be float16, as this is a requirement from `bitsandbytes` + torch_dtype = torch.float16 + logger.info("Loading the model in mixed int8 - forcing the weights to be cast to float16") + if device_map is None: + raise ValueError( + "A device map needs to be passed to convert models into mixed-int8 format. Please run" + " `.from_pretrained` with `device_map='auto'`" + ) + if from_tf or from_flax: + raise ValueError( + "Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + from_pt = not (from_tf | from_flax) user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} @@ -2063,12 +2126,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts - elif low_cpu_mem_usage: + elif load_in_8bit or low_cpu_mem_usage: init_contexts.append(init_empty_weights()) with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) + if load_in_8bit: + logger.info("Detected 8-bit loading: activating 8-bit loading for this model") + + # We never convert the lm_head or any last module, for numerical stability reasons + modules_to_not_convert = get_key_to_not_convert(model) + model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert) + if isinstance(device_map, str): if model._no_split_modules is None: raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.") @@ -2091,9 +2161,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Make sure tied weights are tied before creating the device map.
model.tie_weights() device_map = infer_auto_device_map( - model, no_split_module_classes=no_split_modules, dtype=torch_dtype, max_memory=max_memory + model, + no_split_module_classes=no_split_modules, + dtype=torch_dtype if not load_in_8bit else torch.int8, + max_memory=max_memory, ) + if load_in_8bit: + # The LM head can stay on disk / CPU + device_map_without_lm_head = { + key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert + } + if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): + raise ValueError("8-bit operations with `bitsandbytes` are not supported on CPU!") + del device_map_without_lm_head + if from_tf: if resolved_archive_file.endswith(".index"): # Load from a TensorFlow 1.X checkpoint - provided by original authors @@ -2145,6 +2227,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, + load_in_8bit=load_in_8bit, ) # make sure token embedding weights are still tied if needed @@ -2185,6 +2268,7 @@ def _load_pretrained_model( offload_folder=None, offload_state_dict=None, dtype=None, + load_in_8bit=False, ): if device_map is not None and "disk" in device_map.values(): if offload_folder is None: @@ -2250,7 +2334,10 @@ def _fix_key(key): key = ".".join(key.split(".")[1:]) param = model_state_dict[key] if param.device == torch.device("meta"): - set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size())) + if not load_in_8bit: + set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size())) + else: + set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size())) # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights. if _fast_init: @@ -2359,6 +2446,7 @@ def _find_mismatched_keys( state_dict_folder=state_dict_folder, state_dict_index=state_dict_index, dtype=dtype, + load_in_8bit=load_in_8bit, ) error_msgs += new_error_msgs else: diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py new file mode 100644 index 00000000000000..ee4e52d421fd09 --- /dev/null +++ b/src/transformers/utils/bitsandbytes.py @@ -0,0 +1,142 @@ +from transformers.utils import is_accelerate_available, is_bitsandbytes_available + + +if is_bitsandbytes_available(): + import torch + import torch.nn as nn + + import bitsandbytes as bnb + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): + """ + A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing + `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The + function is adapted from the `set_module_tensor_to_device` function of accelerate to support the + class `Int8Params` from `bitsandbytes`. + + Args: + module (`torch.nn.Module`): + The module in which the tensor we want to move lives. + tensor_name (`str`): + The full name of the parameter/buffer. + device (`int`, `str` or `torch.device`): + The device on which to set the tensor. + value (`torch.Tensor`, *optional*): + The value of the tensor (useful when going from the meta device to any other device). + """ + # Recurse if needed + if "."
in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + is_buffer = tensor_name in module._buffers + old_value = getattr(module, tensor_name) + + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {device}.") + + if is_buffer: + has_fp16_weights = None + else: + has_fp16_weights = getattr(module._parameters[tensor_name], "has_fp16_weights", None) + + if has_fp16_weights is not None: + param = module._parameters[tensor_name] + if param.device.type != "cuda": + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to("cpu") + if value.dtype == torch.int8: + raise ValueError( + "You cannot load weights that are saved in int8 using `load_in_8bit=True`, make sure you are" + " using `load_in_8bit=True` on float32/float16/bfloat16 weights." + ) + else: + new_value = torch.tensor(value, device="cpu") + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device) + module._parameters[tensor_name] = new_value + else: + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to(device) + else: + new_value = torch.tensor(value, device=device) + + if is_buffer: + module._buffers[tensor_name] = new_value + else: + new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad) + module._parameters[tensor_name] = new_value + + +def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"): + """ + A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bitLt` modules from the `bitsandbytes` + library. This enables running your models using mixed int8 precision as described by the paper `LLM.int8(): + 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA + version of your hardware is installed before running this function: `pip install -i https://test.pypi.org/simple/ + bitsandbytes` + + The function is run recursively and replaces all `torch.nn.Linear` modules except for the `lm_head`, which should + be kept as a regular `torch.nn.Linear` module. The replacement is done under the `init_empty_weights` context manager, + so no CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a + matrix multiplication into two streams: (1) a systematic feature outlier stream matrix multiplied in fp16 + (0.01% of values), and (2) a regular stream of int8 matrix multiplication (99.9% of values). With this method, int8 inference with no + predictive degradation is possible for very large models (>=176B parameters). + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + threshold (`float`, *optional*, defaults to 6.0): + `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to + `6.0` as described by the paper.
+ modules_to_not_convert (`str`, *optional*, defaults to `lm_head`): + Name of the module not to convert to `Linear8bitLt`. In practice we keep the `lm_head` in full precision + for numerical stability reasons. + """ + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_8bit_linear(module, threshold, modules_to_not_convert) + + if isinstance(module, nn.Linear) and name != modules_to_not_convert: + with init_empty_weights(): + model._modules[name] = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=False, + threshold=threshold, + ) + return model + + +def get_key_to_not_convert(model): + r""" + A utility function to get the key of the module to keep in full precision, if any. For example, for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. + + Parameters: + model (`torch.nn.Module`): + Input model + """ + # Ignore this for base models (BertModel, GPT2Model, etc.) + if not hasattr(model, model.base_model_prefix): + return "" + + # otherwise they have an attached head + list_modules = list(model.named_parameters()) + last_name = list_modules[-1][0] + return last_name.split(".")[0] diff --git a/tests/mixed_int8/README.md b/tests/mixed_int8/README.md new file mode 100644 index 00000000000000..c0173bed7a6b7a --- /dev/null +++ b/tests/mixed_int8/README.md @@ -0,0 +1,37 @@ +# Testing mixed int8 quantization + +## Hardware requirements + +I am using a setup of 2 NVIDIA Tesla T4 15GB GPUs. + +## Virtual envs + +```conda create --name int8-testing python==3.8``` +```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit``` +```pip install -e ".[dev]"``` +```pip install -i https://test.pypi.org/simple/ bitsandbytes``` +```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1``` + + +## Troubleshooting + +```conda create --name int8-testing python==3.8``` +```pip install -i https://test.pypi.org/simple/ bitsandbytes``` +```conda install pytorch torchvision torchaudio -c pytorch``` +```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit``` +```pip install -e ".[dev]"``` +```pip install git+https://github.com/huggingface/accelerate.git@b52b793ea8bac108ba61192eead3cf11ca02433d``` + +### Check driver settings: + +``` +nvcc --version +``` + +``` +ls -l $CONDA_PREFIX/lib/libcudart.so +``` + +### Recurrent bugs + +Sometimes you have to run a "dummy" inference pass when dealing with a multi-GPU setup. Check out the ```test_multi_gpu_loading``` and the ```test_pipeline``` functions. \ No newline at end of file diff --git a/tests/mixed_int8/__init__.py b/tests/mixed_int8/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py new file mode 100644 index 00000000000000..0cd7ca16411c19 --- /dev/null +++ b/tests/mixed_int8/test_mixed_int8.py @@ -0,0 +1,215 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import unittest + +from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline +from transformers.testing_utils import ( + is_torch_available, + require_accelerate, + require_bitsandbytes, + require_torch, + require_torch_gpu, + require_torch_multi_gpu, + slow, +) + + +if is_torch_available(): + import torch + + +@require_bitsandbytes +@require_accelerate +@require_torch +@require_torch_gpu +@slow +class BaseMixedInt8Test(unittest.TestCase): + # We keep the constants as class attributes and do the model loading inside the setUp function + + # We need to test on relatively large models (i.e. >1B parameters), otherwise the quantization may not work as expected + # Therefore here we use only bloom-1b7 to test our module + model_name = "bigscience/bloom-1b7" + + # Constant values + EXPECTED_RELATIVE_DIFFERENCE = ( + 1.540025 # This was obtained on a Quadro RTX 8000 so the number might slightly change + ) + + input_text = "Hello my name is" + EXPECTED_OUTPUT = "Hello my name is John.\nI am a friend of the family.\n" + MAX_NEW_TOKENS = 10 + + def setUp(self): + # Models and tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + +class MixedInt8Test(BaseMixedInt8Test): + def setUp(self): + super().setUp() + + # Models and tokenizer + self.model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto") + self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + + def tearDown(self): + r""" + TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to + avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 + """ + del self.model_fp16 + del self.model_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_memory_footprint(self): + r""" + A simple test to check whether the model conversion has been done correctly, by checking the + memory footprint of the converted model and the class type of the linear layers of the converted model + """ + from bitsandbytes.nn import Int8Params + + mem_fp16 = self.model_fp16.get_memory_footprint() + mem_8bit = self.model_8bit.get_memory_footprint() + + self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE) + self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) + + def test_generate_quality(self): + r""" + Test the generation quality of the quantized model and see that we are matching the expected output. + Given that we are operating on small numbers + the testing model is relatively small, we might not get + the same output across GPUs. So we'll generate a few tokens (5-10) and check their output.
+ """ + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + output_sequences = self.model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + + self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + +class MixedInt8ModelClassesTest(BaseMixedInt8Test): + def setUp(self): + super().setUp() + # model_name + self.model_name = "bigscience/bloom-560m" + # Models and tokenizer + self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + self.sequence_model = AutoModelForSequenceClassification.from_pretrained( + self.model_name, load_in_8bit=True, device_map="auto" + ) + self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + + def tearDown(self): + r""" + TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to + avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 + """ + del self.base_model + del self.sequence_model + del self.model_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_correct_head_class(self): + r""" + A simple test to check if the last modules for some classes (AutoModelForCausalLM or SequenceClassification) + are kept in their native class. + """ + from bitsandbytes.nn import Int8Params + + # last param of a base model should be a linear8bit module + self.assertTrue(self.base_model.h[-1].mlp.dense_4h_to_h.weight.__class__ == Int8Params) + + # Other heads should be nn.Parameter + self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter) + self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter) + + +class MixedInt8TestPipeline(BaseMixedInt8Test): + def setUp(self): + super().setUp() + + def tearDown(self): + r""" + TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to + avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 + """ + del self.pipe + + gc.collect() + torch.cuda.empty_cache() + + def test_pipeline(self): + r""" + The aim of this test is to verify that the mixed int8 is compatible with `pipeline` from transformers. Since + we used pipline for inference speed benchmarking we want to make sure that this feature does not break anything + on pipline. + """ + # self._clear_cuda_cache() + self.pipe = pipeline( + "text-generation", + model=self.model_name, + model_kwargs={"device_map": "auto", "load_in_8bit": True}, + max_new_tokens=self.MAX_NEW_TOKENS, + ) + + # Real second forward pass + pipeline_output = self.pipe(self.input_text) + self.assertEqual(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUT) + + +@require_torch_multi_gpu +class MixedInt8TestMultiGpu(BaseMixedInt8Test): + def setUp(self): + super().setUp() + + def test_multi_gpu_loading(self): + r""" + This tests that the model has been loaded and can be used correctly on a multi-GPU setup. + Let's just try to load a model on 2 GPUs and see if it works. 
The model we test has ~2GB in total, so 3GB across the two GPUs should suffice + """ + + memory_mapping = {0: "1GB", 1: "2GB"} + model_parallel = AutoModelForCausalLM.from_pretrained( + self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto" + ) + + def get_list_devices(model): + list_devices = [] + for _, module in model.named_children(): + if len(list(module.children())) > 0: + list_devices.extend(get_list_devices(module)) + else: + # Do a try/except since we can encounter Dropout modules that do not + # have any device set + try: + list_devices.append(next(module.parameters()).device.index) + except BaseException: + continue + return list_devices + + list_devices = get_list_devices(model_parallel) + # Check that we have dispatched the model onto 2 separate devices + self.assertTrue((1 in list_devices) and (0 in list_devices)) + + # Check that an inference pass works on the model + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + + # Second real batch + output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 329d248de3c089..ba122f43f805db 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -465,6 +465,7 @@ def module_to_test_file(module_fname): "tests/sagemaker/test_single_node_gpu.py", # SageMaker test "tests/sagemaker/test_multi_node_model_parallel.py", # SageMaker test "tests/sagemaker/test_multi_node_data_parallel.py", # SageMaker test + "tests/mixed_int8/test_mixed_int8.py", # Mixed-int8 bitsandbytes test ]
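As a quick manual sanity check of the integration (mirroring `test_memory_footprint` and `test_correct_head_class` above), you can verify that converted linear layers carry `Int8Params` weights while the head keeps a regular parameter. This is only a minimal sketch, assuming a CUDA GPU with enough free memory and the `bigscience/bloom-560m` checkpoint used in the tests:

```py
import torch
from bitsandbytes.nn import Int8Params
from transformers import AutoModelForCausalLM

model_8bit = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m", device_map="auto", load_in_8bit=True
)

# Converted linear layers hold `Int8Params` weights...
assert isinstance(model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight, Int8Params)
# ...while the lm_head is kept in full precision (a plain `nn.Parameter`) for numerical stability
assert type(model_8bit.lm_head.weight) is torch.nn.Parameter
```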