From ed1dd12e1b6672454c5dce4f38a0bb687149e454 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 24 Jun 2022 09:21:39 +0000 Subject: [PATCH 01/82] first commit --- src/transformers/modeling_utils.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9ee0378232efb2..15ce148e3e3e7c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -22,6 +22,7 @@ import tempfile import warnings from contextlib import contextmanager +from copy import deepcopy from dataclasses import dataclass from functools import partial from pathlib import Path @@ -117,6 +118,16 @@ def __init__(self, *args, **kwargs): def forward(self, input): return input +def replace_8bit_linear(model): + quantized_model = deepcopy(model) + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + quantized_model[name] = Linear8bit(module.in_shape, module.out_shape, device=module.device) + else: + quantized_model[name] = module + return quantized_model + + def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): try: @@ -1776,6 +1787,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) + load_in_8bit = kwargs.pop("load_in_8bit", False) if device_map is not None: if low_cpu_mem_usage is None: @@ -2061,12 +2073,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts + elif load_in_8bit: + from bitsandbytes.nn import Linear8bitLt as Linear8bit + + init_contexts = [init_empty_weights()] # Force enable init empty weights + + logger.info("Detected 8-bit loading: activating 8-bit loading for this model") + # init_contexts.append() elif low_cpu_mem_usage: init_contexts.append(init_empty_weights()) with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) + if load_in_8bit: + model = replace_8bit_linear(model) + if device_map == "auto": if model._no_split_modules is None: raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.") From b9d0da6aafc748f8de85c9b9f819a33ad5257461 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 24 Jun 2022 09:48:29 +0000 Subject: [PATCH 02/82] correct replace function --- src/transformers/modeling_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 15ce148e3e3e7c..7b88e86662ccf2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -118,14 +118,14 @@ def __init__(self, *args, **kwargs): def forward(self, input): return input -def replace_8bit_linear(model): - quantized_model = deepcopy(model) - for name, module in model.named_modules(): +def replace_8bit_linear(model): + for n, module in model.named_children(): + if len(list(module.children())) > 0: + replace_8bit_linear(module) + if isinstance(module, nn.Linear): - quantized_model[name] = Linear8bit(module.in_shape, module.out_shape, device=module.device) - else: - quantized_model[name] = module - return quantized_model + # setattr(model, n, bnb.nn.Linear8bitLt) + print(module) @@ -2074,7 +2074,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts elif load_in_8bit: - from bitsandbytes.nn import Linear8bitLt as Linear8bit + # rom bitsandbytes.nn import Linear8bitLt as Linear8bit init_contexts = [init_empty_weights()] # Force enable init empty weights From dd9a464dc054e17b41b9f776092ac2b0ad619f4b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 27 Jun 2022 18:25:30 +0000 Subject: [PATCH 03/82] add final changes - works like charm! - cannot implement tests yet - tested --- src/transformers/modeling_utils.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7b88e86662ccf2..a148644af435d6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -22,7 +22,6 @@ import tempfile import warnings from contextlib import contextmanager -from copy import deepcopy from dataclasses import dataclass from functools import partial from pathlib import Path @@ -118,15 +117,20 @@ def __init__(self, *args, **kwargs): def forward(self, input): return input + def replace_8bit_linear(model): + import bitsandbytes as bnb + for n, module in model.named_children(): if len(list(module.children())) > 0: replace_8bit_linear(module) - - if isinstance(module, nn.Linear): - # setattr(model, n, bnb.nn.Linear8bitLt) - print(module) + if isinstance(module, nn.Linear): + with init_empty_weights(): + model._modules[n] = bnb.nn.Linear8bitLt( + module.in_features, module.out_features, module.bias is not None, has_fp16_weights=True + ) + return model def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): @@ -2076,7 +2080,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P elif load_in_8bit: # rom bitsandbytes.nn import Linear8bitLt as Linear8bit - init_contexts = [init_empty_weights()] # Force enable init empty weights + init_contexts = [init_empty_weights()] # Force enable init empty weights logger.info("Detected 8-bit loading: activating 8-bit loading for this model") # init_contexts.append() @@ -2148,6 +2152,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, + load_in_8bit=load_in_8bit, ) # make sure token embedding weights are still tied if needed @@ -2187,6 +2192,7 @@ def _load_pretrained_model( offload_folder=None, offload_state_dict=False, dtype=None, + load_in_8bit=False, ): if device_map is not None and "disk" in device_map.values(): if offload_folder is None: @@ -2243,7 +2249,7 @@ def _fix_key(key): # retrieve weights on meta device and put them back on CPU. # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step - if low_cpu_mem_usage: + if low_cpu_mem_usage or (load_in_8bit and device_map is None): for key in missing_keys: if key.startswith(prefix): key = ".".join(key.split(".")[1:]) From 35e1534de74d7e5d61ac155142559218f8612fb2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 27 Jun 2022 18:39:54 +0000 Subject: [PATCH 04/82] clean up a bit --- src/transformers/modeling_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a148644af435d6..8768f145e6830c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2078,12 +2078,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts elif load_in_8bit: - # rom bitsandbytes.nn import Linear8bitLt as Linear8bit init_contexts = [init_empty_weights()] # Force enable init empty weights logger.info("Detected 8-bit loading: activating 8-bit loading for this model") - # init_contexts.append() elif low_cpu_mem_usage: init_contexts.append(init_empty_weights()) @@ -2249,7 +2247,7 @@ def _fix_key(key): # retrieve weights on meta device and put them back on CPU. # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step - if low_cpu_mem_usage or (load_in_8bit and device_map is None): + if low_cpu_mem_usage: for key in missing_keys: if key.startswith(prefix): key = ".".join(key.split(".")[1:]) From d01822b8468a6dcfed010d65c7103ebcc479da44 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 7 Jul 2022 17:55:21 +0000 Subject: [PATCH 05/82] add bitsandbytes dependencies --- src/transformers/utils/import_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index da42c73587f11a..02f987489d6e63 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -474,6 +474,8 @@ def is_protobuf_available(): def is_accelerate_available(): return importlib.util.find_spec("accelerate") is not None +def is_bitsandbytes_available(): + return importlib.util.find_spec("bitsandbytes") is not None def is_tokenizers_available(): return importlib.util.find_spec("tokenizers") is not None From 839c9cd4037d6f690fb1706ca1859ee81cc7c485 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 12 Jul 2022 09:18:41 +0000 Subject: [PATCH 06/82] working version - added import function - added bitsandbytes utils file --- src/transformers/__init__.py | 1 + src/transformers/bitsandbytes_utils.py | 137 +++++++++++++++++++++++++ src/transformers/modeling_utils.py | 35 +++---- src/transformers/utils/import_utils.py | 2 - 4 files changed, 153 insertions(+), 22 deletions(-) create mode 100644 src/transformers/bitsandbytes_utils.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a68420b127ade4..5d4de04734bda7 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -50,6 +50,7 @@ # Base objects, independent of any specific backend _import_structure = { "benchmark": [], + "bitsandbytes_utils": [], "commands": [], "configuration_utils": ["PretrainedConfig"], "convert_graph_to_onnx": [], diff --git a/src/transformers/bitsandbytes_utils.py b/src/transformers/bitsandbytes_utils.py new file mode 100644 index 00000000000000..e1e1ea7501ca1e --- /dev/null +++ b/src/transformers/bitsandbytes_utils.py @@ -0,0 +1,137 @@ +import bitsandbytes as bnb +import torch +import torch.nn as nn +from contextlib import contextmanager + +from typing import Optional, Union + + +def set_module_8bit_tensor_to_device( + module: nn.Module, tensor_name: str, device: Union[int, str, torch.device], value: Optional[torch.Tensor] = None +): + """ + Args: + A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing + `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The + function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the + class `Int8Params` from `bitsandbytes`. + module (`torch.nn.Module`): The module in which the tensor we want to move lives. param_name (`str`): The full + name of the parameter/buffer. device (`int`, `str` or `torch.device`): The device on which to set the tensor. + value (`torch.Tensor`, *optional*): The value of the tensor (useful when going from the meta device to any + other device). + """ + # Recurse if needed + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + is_buffer = tensor_name in module._buffers + old_value = getattr(module, tensor_name) + + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") + + if is_buffer: has_fp16_weights = None + else: has_fp16_weights = getattr(module._parameters[tensor_name], 'has_fp16_weights', None) + + if has_fp16_weights is not None: + param = module._parameters[tensor_name] + if param.device.type != 'cuda': + with torch.no_grad(): + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to('cpu') + else: + new_value = torch.tensor(value, device='cpu') + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device) + module._parameters[tensor_name] = new_value + else: + with torch.no_grad(): + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to(device) + else: + new_value = torch.tensor(value, device=device) + if is_buffer: + module._buffers[tensor_name] = new_value + else: + new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad) + module._parameters[tensor_name] = new_value + + +@contextmanager +def init_empty_weights_8bit(include_buffers: bool = False): + """ + Args: + A context manager under which models are initialized with all parameters on the meta device, therefore creating an + empty model. Useful when just initializing the model would blow the available RAM. + include_buffers (`bool`, *optional*, defaults to `False`): + Whether or not to also put all buffers on the meta device while initializing. + Example: + ```pyton + import torch.nn as nn + from accelerate import init_empty_weights + # Initialize a model with 100 billions parameters in no time and without using any RAM. + with init_empty_weights(): + tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) + ``` + + + + Any model created under this context manager has no weights. As such you can't do something like + `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. + + + """ + old_register_parameter = nn.Module.register_parameter + if include_buffers: + old_register_buffer = nn.Module.register_buffer + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + has_fp16_weights = getattr(param_cls, 'has_fp16_weights', None) + if has_fp16_weights is not None: + module._parameters[name] = param_cls(module._parameters[name].to(torch.device("meta")), has_fp16_weights=has_fp16_weights) + else: + module._parameters[name] = param_cls(module._parameters[name].to(torch.device("meta"))) + + def register_empty_buffer(module, name, buffer): + old_register_buffer(module, name, buffer) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(torch.device("meta")) + + try: + nn.Module.register_parameter = register_empty_parameter + if include_buffers: + nn.Module.register_buffer = register_empty_buffer + yield + finally: + nn.Module.register_parameter = old_register_parameter + if include_buffers: + nn.Module.register_buffer = old_register_buffer + + +def replace_8bit_linear(model): + import bitsandbytes as bnb + + for n, module in model.named_children(): + if len(list(module.children())) > 0: + replace_8bit_linear(module) + + if isinstance(module, nn.Linear) and n != "lm_head": + with init_empty_weights_8bit(): + model._modules[n] = bnb.nn.Linear8bitLt( + module.in_features, module.out_features, module.bias is not None, has_fp16_weights=False, threshold=6.0 + ) + return model \ No newline at end of file diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8768f145e6830c..0f5dc110814c89 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -65,6 +65,7 @@ has_file, hf_bucket_url, is_accelerate_available, + is_bitsandbytes_available, is_offline_mode, is_remote_url, logging, @@ -82,6 +83,9 @@ set_module_tensor_to_device, ) +if is_bitsandbytes_available(): + from .bitsandbytes_utils import init_empty_weights_8bit, replace_8bit_linear, set_module_8bit_tensor_to_device + logger = logging.get_logger(__name__) @@ -118,21 +122,6 @@ def forward(self, input): return input -def replace_8bit_linear(model): - import bitsandbytes as bnb - - for n, module in model.named_children(): - if len(list(module.children())) > 0: - replace_8bit_linear(module) - - if isinstance(module, nn.Linear): - with init_empty_weights(): - model._modules[n] = bnb.nn.Linear8bitLt( - module.in_features, module.out_features, module.bias is not None, has_fp16_weights=True - ) - return model - - def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): try: return next(parameter.parameters()).device @@ -505,6 +494,7 @@ def _load_state_dict_into_meta_model( state_dict_folder=None, state_dict_index=None, dtype=None, + load_in_8bit=False, ): """ This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its @@ -569,7 +559,10 @@ def _load_state_dict_into_meta_model( elif param_device == "cpu" and state_dict_index is not None: state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) else: - set_module_tensor_to_device(model, param_name, param_device, value=param) + if not load_in_8bit: + set_module_tensor_to_device(model, param_name, param_device, value=param) + else: + set_module_8bit_tensor_to_device(model, param_name, param_device, value=param) return error_msgs, offload_index, state_dict_index @@ -2078,9 +2071,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts elif load_in_8bit: - - init_contexts = [init_empty_weights()] # Force enable init empty weights - + init_contexts = [init_empty_weights_8bit()] # Force enable init empty weights logger.info("Detected 8-bit loading: activating 8-bit loading for this model") elif low_cpu_mem_usage: init_contexts.append(init_empty_weights()) @@ -2253,7 +2244,10 @@ def _fix_key(key): key = ".".join(key.split(".")[1:]) param = model_state_dict[key] if param.device == torch.device("meta"): - set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size())) + if not load_in_8bit: + set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size())) + else: + set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size())) # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. if _fast_init: @@ -2362,6 +2356,7 @@ def _find_mismatched_keys( state_dict_folder=state_dict_folder, state_dict_index=state_dict_index, dtype=dtype, + load_in_8bit=load_in_8bit, ) error_msgs += new_error_msgs else: diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 02f987489d6e63..da42c73587f11a 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -474,8 +474,6 @@ def is_protobuf_available(): def is_accelerate_available(): return importlib.util.find_spec("accelerate") is not None -def is_bitsandbytes_available(): - return importlib.util.find_spec("bitsandbytes") is not None def is_tokenizers_available(): return importlib.util.find_spec("tokenizers") is not None From 42a684574cdf32456605718f239a47c27f13c347 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 12 Jul 2022 09:48:25 +0000 Subject: [PATCH 07/82] small fix --- src/transformers/bitsandbytes_utils.py | 41 ++++++++++++++++++-------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/transformers/bitsandbytes_utils.py b/src/transformers/bitsandbytes_utils.py index e1e1ea7501ca1e..fd34629a88a9a4 100644 --- a/src/transformers/bitsandbytes_utils.py +++ b/src/transformers/bitsandbytes_utils.py @@ -1,9 +1,14 @@ -import bitsandbytes as bnb +from contextlib import contextmanager +from typing import Optional, Union + import torch import torch.nn as nn -from contextlib import contextmanager -from typing import Optional, Union +from .utils import is_bitsandbytes_available + + +if is_bitsandbytes_available(): + import bitsandbytes as bnb def set_module_8bit_tensor_to_device( @@ -38,20 +43,24 @@ class `Int8Params` from `bitsandbytes`. if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") - if is_buffer: has_fp16_weights = None - else: has_fp16_weights = getattr(module._parameters[tensor_name], 'has_fp16_weights', None) + if is_buffer: + has_fp16_weights = None + else: + has_fp16_weights = getattr(module._parameters[tensor_name], "has_fp16_weights", None) if has_fp16_weights is not None: param = module._parameters[tensor_name] - if param.device.type != 'cuda': + if param.device.type != "cuda": with torch.no_grad(): if value is None: new_value = old_value.to(device) elif isinstance(value, torch.Tensor): - new_value = value.to('cpu') + new_value = value.to("cpu") else: - new_value = torch.tensor(value, device='cpu') - new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device) + new_value = torch.tensor(value, device="cpu") + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to( + device + ) module._parameters[tensor_name] = new_value else: with torch.no_grad(): @@ -100,9 +109,11 @@ def register_empty_parameter(module, name, param): old_register_parameter(module, name, param) if param is not None: param_cls = type(module._parameters[name]) - has_fp16_weights = getattr(param_cls, 'has_fp16_weights', None) + has_fp16_weights = getattr(param_cls, "has_fp16_weights", None) if has_fp16_weights is not None: - module._parameters[name] = param_cls(module._parameters[name].to(torch.device("meta")), has_fp16_weights=has_fp16_weights) + module._parameters[name] = param_cls( + module._parameters[name].to(torch.device("meta")), has_fp16_weights=has_fp16_weights + ) else: module._parameters[name] = param_cls(module._parameters[name].to(torch.device("meta"))) @@ -132,6 +143,10 @@ def replace_8bit_linear(model): if isinstance(module, nn.Linear) and n != "lm_head": with init_empty_weights_8bit(): model._modules[n] = bnb.nn.Linear8bitLt( - module.in_features, module.out_features, module.bias is not None, has_fp16_weights=False, threshold=6.0 + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=False, + threshold=6.0, ) - return model \ No newline at end of file + return model From 97f64f8d0cc9201d1d6de4c0fc9bcdf03e388cae Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 12 Jul 2022 10:07:19 +0000 Subject: [PATCH 08/82] small fix - fix import issue --- src/transformers/bitsandbytes_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/bitsandbytes_utils.py b/src/transformers/bitsandbytes_utils.py index fd34629a88a9a4..8333562d2c4e84 100644 --- a/src/transformers/bitsandbytes_utils.py +++ b/src/transformers/bitsandbytes_utils.py @@ -1,13 +1,13 @@ from contextlib import contextmanager from typing import Optional, Union -import torch -import torch.nn as nn - from .utils import is_bitsandbytes_available if is_bitsandbytes_available(): + import torch + import torch.nn as nn + import bitsandbytes as bnb From a1fe7fcc47561aff297eb1c3ce4dd3d3a0357183 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 12 Jul 2022 12:08:06 +0000 Subject: [PATCH 09/82] fix import issues --- src/transformers/bitsandbytes_utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/bitsandbytes_utils.py b/src/transformers/bitsandbytes_utils.py index 8333562d2c4e84..7cdfe58db4ca6f 100644 --- a/src/transformers/bitsandbytes_utils.py +++ b/src/transformers/bitsandbytes_utils.py @@ -1,5 +1,4 @@ from contextlib import contextmanager -from typing import Optional, Union from .utils import is_bitsandbytes_available @@ -11,9 +10,7 @@ import bitsandbytes as bnb -def set_module_8bit_tensor_to_device( - module: nn.Module, tensor_name: str, device: Union[int, str, torch.device], value: Optional[torch.Tensor] = None -): +def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): """ Args: A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing @@ -78,7 +75,7 @@ class `Int8Params` from `bitsandbytes`. @contextmanager -def init_empty_weights_8bit(include_buffers: bool = False): +def init_empty_weights_8bit(include_buffers=False): """ Args: A context manager under which models are initialized with all parameters on the meta device, therefore creating an From 05739e38940540f96e655accfa77a7b4dd53049b Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 12 Jul 2022 15:25:28 +0200 Subject: [PATCH 10/82] Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/bitsandbytes_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/bitsandbytes_utils.py b/src/transformers/bitsandbytes_utils.py index 7cdfe58db4ca6f..1e45b381b40e80 100644 --- a/src/transformers/bitsandbytes_utils.py +++ b/src/transformers/bitsandbytes_utils.py @@ -82,7 +82,9 @@ def init_empty_weights_8bit(include_buffers=False): empty model. Useful when just initializing the model would blow the available RAM. include_buffers (`bool`, *optional*, defaults to `False`): Whether or not to also put all buffers on the meta device while initializing. + Example: + ```pyton import torch.nn as nn from accelerate import init_empty_weights From 7816ef919bab246ce09f45e87655a6e2830f5482 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 12 Jul 2022 13:51:14 +0000 Subject: [PATCH 11/82] refactor a bit - move bitsandbytes utils to utils - change comments on functions --- src/transformers/__init__.py | 4 +-- src/transformers/modeling_utils.py | 13 ++++++---- .../{ => utils}/bitsandbytes_utils.py | 25 ++++++++++++------- 3 files changed, 26 insertions(+), 16 deletions(-) rename src/transformers/{ => utils}/bitsandbytes_utils.py (88%) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f476a7ca1f6ea8..da0e4ea52bdcaa 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -51,7 +51,6 @@ # Base objects, independent of any specific backend _import_structure = { "benchmark": [], - "bitsandbytes_utils": [], "commands": [], "configuration_utils": ["PretrainedConfig"], "convert_graph_to_onnx": [], @@ -119,8 +118,8 @@ "load_tf2_weights_in_pytorch_model", ], "models": [], - # Models "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], + # Models "models.auto": [ "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", @@ -454,6 +453,7 @@ "is_vision_available", "logging", ], + "utils.bitsandbytes_utils": [], } # sentencepiece-backed objects diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3fc98c37f32f79..b2a1a74d28b432 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -84,7 +84,11 @@ ) if is_bitsandbytes_available(): - from .bitsandbytes_utils import init_empty_weights_8bit, replace_8bit_linear, set_module_8bit_tensor_to_device + from .utils.bitsandbytes_utils import ( + init_empty_weights_8bit, + replace_8bit_linear, + set_module_8bit_tensor_to_device, + ) logger = logging.get_logger(__name__) @@ -558,11 +562,10 @@ def _load_state_dict_into_meta_model( offload_index = offload_weight(param, param_name, offload_folder, offload_index) elif param_device == "cpu" and state_dict_index is not None: state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) + elif not load_in_8bit: + set_module_tensor_to_device(model, param_name, param_device, value=param) else: - if not load_in_8bit: - set_module_tensor_to_device(model, param_name, param_device, value=param) - else: - set_module_8bit_tensor_to_device(model, param_name, param_device, value=param) + set_module_8bit_tensor_to_device(model, param_name, param_device, value=param) return error_msgs, offload_index, state_dict_index diff --git a/src/transformers/bitsandbytes_utils.py b/src/transformers/utils/bitsandbytes_utils.py similarity index 88% rename from src/transformers/bitsandbytes_utils.py rename to src/transformers/utils/bitsandbytes_utils.py index 7cdfe58db4ca6f..b2e1fbff93a852 100644 --- a/src/transformers/bitsandbytes_utils.py +++ b/src/transformers/utils/bitsandbytes_utils.py @@ -1,6 +1,6 @@ from contextlib import contextmanager -from .utils import is_bitsandbytes_available +from transformers.utils import is_bitsandbytes_available if is_bitsandbytes_available(): @@ -12,15 +12,20 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): """ - Args: A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the class `Int8Params` from `bitsandbytes`. - module (`torch.nn.Module`): The module in which the tensor we want to move lives. param_name (`str`): The full - name of the parameter/buffer. device (`int`, `str` or `torch.device`): The device on which to set the tensor. - value (`torch.Tensor`, *optional*): The value of the tensor (useful when going from the meta device to any - other device). + + Args: + module (`torch.nn.Module`): + The module in which the tensor we want to move lives. + tensor_name (`str`): The full + name of the parameter/buffer. + device (`int`, `str` or `torch.device`): + The device on which to set the tensor. + value (`torch.Tensor`, *optional*): + The value of the tensor (useful when going from the meta device to any other device). """ # Recurse if needed if "." in tensor_name: @@ -79,15 +84,17 @@ def init_empty_weights_8bit(include_buffers=False): """ Args: A context manager under which models are initialized with all parameters on the meta device, therefore creating an - empty model. Useful when just initializing the model would blow the available RAM. + empty model. Useful when just initializing the model would blow the available RAM. This function is adapted + from the function `init_empty_weights` of accelerate to make it work on `Linear8bitLt` modules from `bitsandbytes` include_buffers (`bool`, *optional*, defaults to `False`): Whether or not to also put all buffers on the meta device while initializing. Example: ```pyton import torch.nn as nn - from accelerate import init_empty_weights + from transformers.utils.bitsandbytes_utils import init_empty_weights_8bit + # Initialize a model with 100 billions parameters in no time and without using any RAM. - with init_empty_weights(): + with init_empty_weights_8bit(): tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) ``` From b222b9a50c88a3253bcd5ce35dd4a83dabb2f3e5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 12 Jul 2022 13:54:41 +0000 Subject: [PATCH 12/82] reformat docstring - reformat docstring on init_empty_weights_8bit --- src/transformers/utils/bitsandbytes_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/utils/bitsandbytes_utils.py b/src/transformers/utils/bitsandbytes_utils.py index 83499a2c2b46e2..2f70acbbf859ec 100644 --- a/src/transformers/utils/bitsandbytes_utils.py +++ b/src/transformers/utils/bitsandbytes_utils.py @@ -18,13 +18,13 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): class `Int8Params` from `bitsandbytes`. Args: - module (`torch.nn.Module`): + module (`torch.nn.Module`): The module in which the tensor we want to move lives. tensor_name (`str`): The full name of the parameter/buffer. - device (`int`, `str` or `torch.device`): + device (`int`, `str` or `torch.device`): The device on which to set the tensor. - value (`torch.Tensor`, *optional*): + value (`torch.Tensor`, *optional*): The value of the tensor (useful when going from the meta device to any other device). """ # Recurse if needed @@ -84,8 +84,8 @@ def init_empty_weights_8bit(include_buffers=False): """ Args: A context manager under which models are initialized with all parameters on the meta device, therefore creating an - empty model. Useful when just initializing the model would blow the available RAM. This function is adapted - from the function `init_empty_weights` of accelerate to make it work on `Linear8bitLt` modules from `bitsandbytes` + empty model. Useful when just initializing the model would blow the available RAM. This function is adapted from + the function `init_empty_weights` of accelerate to make it work on `Linear8bitLt` modules from `bitsandbytes` include_buffers (`bool`, *optional*, defaults to `False`): Whether or not to also put all buffers on the meta device while initializing. @@ -94,7 +94,7 @@ def init_empty_weights_8bit(include_buffers=False): ```pyton import torch.nn as nn from transformers.utils.bitsandbytes_utils import init_empty_weights_8bit - + # Initialize a model with 100 billions parameters in no time and without using any RAM. with init_empty_weights_8bit(): tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) From 32f48cd88f970379180b256be9a20cf5132d5970 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 12 Jul 2022 16:36:23 +0200 Subject: [PATCH 13/82] Update src/transformers/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index da0e4ea52bdcaa..3ac01fc7d575a5 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -453,7 +453,7 @@ "is_vision_available", "logging", ], - "utils.bitsandbytes_utils": [], + "utils.bitsandbytes": [], } # sentencepiece-backed objects From e116e21b4d708c8a4bfb70566e0c7ec14416dc48 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 12 Jul 2022 14:40:19 +0000 Subject: [PATCH 14/82] revert bad formatting --- src/transformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3ac01fc7d575a5..858706b7283cc7 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -118,8 +118,8 @@ "load_tf2_weights_in_pytorch_model", ], "models": [], - "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], # Models + "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], "models.auto": [ "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", From 39c46a07648fa5565117c6ac33bfcbe69449f2c0 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 12 Jul 2022 14:45:59 +0000 Subject: [PATCH 15/82] change to bitsandbytes --- src/transformers/modeling_utils.py | 6 +----- .../utils/{bitsandbytes_utils.py => bitsandbytes.py} | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) rename src/transformers/utils/{bitsandbytes_utils.py => bitsandbytes.py} (98%) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b2a1a74d28b432..93af99764c0ad1 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -84,11 +84,7 @@ ) if is_bitsandbytes_available(): - from .utils.bitsandbytes_utils import ( - init_empty_weights_8bit, - replace_8bit_linear, - set_module_8bit_tensor_to_device, - ) + from .utils.bitsandbytes import init_empty_weights_8bit, replace_8bit_linear, set_module_8bit_tensor_to_device logger = logging.get_logger(__name__) diff --git a/src/transformers/utils/bitsandbytes_utils.py b/src/transformers/utils/bitsandbytes.py similarity index 98% rename from src/transformers/utils/bitsandbytes_utils.py rename to src/transformers/utils/bitsandbytes.py index 2f70acbbf859ec..42fe71e895fca1 100644 --- a/src/transformers/utils/bitsandbytes_utils.py +++ b/src/transformers/utils/bitsandbytes.py @@ -93,7 +93,7 @@ def init_empty_weights_8bit(include_buffers=False): ```pyton import torch.nn as nn - from transformers.utils.bitsandbytes_utils import init_empty_weights_8bit + from transformers.utils.bitsandbytes import init_empty_weights_8bit # Initialize a model with 100 billions parameters in no time and without using any RAM. with init_empty_weights_8bit(): From b92c25c3835f37667ee033a2b406846da06f16f4 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 12 Jul 2022 15:53:00 +0000 Subject: [PATCH 16/82] refactor a bit - remove init8bit since it is useless --- src/transformers/modeling_utils.py | 11 +++- src/transformers/utils/bitsandbytes.py | 70 ++------------------------ 2 files changed, 14 insertions(+), 67 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 93af99764c0ad1..ebcb157a58396c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -84,7 +84,7 @@ ) if is_bitsandbytes_available(): - from .utils.bitsandbytes import init_empty_weights_8bit, replace_8bit_linear, set_module_8bit_tensor_to_device + from .utils.bitsandbytes import replace_8bit_linear, set_module_8bit_tensor_to_device logger = logging.get_logger(__name__) @@ -1804,6 +1804,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`" ) + if load_in_8bit: + if not (is_accelerate_available() and is_bitsandbytes_available()): + raise ImportError( + "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of" + " bitsandbytes `pip install bitsandbytes`" + ) + from_pt = not (from_tf | from_flax) user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} @@ -2072,7 +2079,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts elif load_in_8bit: - init_contexts = [init_empty_weights_8bit()] # Force enable init empty weights + init_contexts = [init_empty_weights()] # Force enable init empty weights logger.info("Detected 8-bit loading: activating 8-bit loading for this model") elif low_cpu_mem_usage: init_contexts.append(init_empty_weights()) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 42fe71e895fca1..cab643806afd11 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -1,6 +1,4 @@ -from contextlib import contextmanager - -from transformers.utils import is_bitsandbytes_available +from transformers.utils import is_accelerate_available, is_bitsandbytes_available if is_bitsandbytes_available(): @@ -9,6 +7,9 @@ import bitsandbytes as bnb +if is_accelerate_available(): + from accelerate import init_empty_weights + def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): """ @@ -79,75 +80,14 @@ class `Int8Params` from `bitsandbytes`. module._parameters[tensor_name] = new_value -@contextmanager -def init_empty_weights_8bit(include_buffers=False): - """ - Args: - A context manager under which models are initialized with all parameters on the meta device, therefore creating an - empty model. Useful when just initializing the model would blow the available RAM. This function is adapted from - the function `init_empty_weights` of accelerate to make it work on `Linear8bitLt` modules from `bitsandbytes` - include_buffers (`bool`, *optional*, defaults to `False`): - Whether or not to also put all buffers on the meta device while initializing. - - Example: - - ```pyton - import torch.nn as nn - from transformers.utils.bitsandbytes import init_empty_weights_8bit - - # Initialize a model with 100 billions parameters in no time and without using any RAM. - with init_empty_weights_8bit(): - tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) - ``` - - - - Any model created under this context manager has no weights. As such you can't do something like - `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. - - - """ - old_register_parameter = nn.Module.register_parameter - if include_buffers: - old_register_buffer = nn.Module.register_buffer - - def register_empty_parameter(module, name, param): - old_register_parameter(module, name, param) - if param is not None: - param_cls = type(module._parameters[name]) - has_fp16_weights = getattr(param_cls, "has_fp16_weights", None) - if has_fp16_weights is not None: - module._parameters[name] = param_cls( - module._parameters[name].to(torch.device("meta")), has_fp16_weights=has_fp16_weights - ) - else: - module._parameters[name] = param_cls(module._parameters[name].to(torch.device("meta"))) - - def register_empty_buffer(module, name, buffer): - old_register_buffer(module, name, buffer) - if buffer is not None: - module._buffers[name] = module._buffers[name].to(torch.device("meta")) - - try: - nn.Module.register_parameter = register_empty_parameter - if include_buffers: - nn.Module.register_buffer = register_empty_buffer - yield - finally: - nn.Module.register_parameter = old_register_parameter - if include_buffers: - nn.Module.register_buffer = old_register_buffer - - def replace_8bit_linear(model): - import bitsandbytes as bnb for n, module in model.named_children(): if len(list(module.children())) > 0: replace_8bit_linear(module) if isinstance(module, nn.Linear) and n != "lm_head": - with init_empty_weights_8bit(): + with init_empty_weights(): model._modules[n] = bnb.nn.Linear8bitLt( module.in_features, module.out_features, From 3779f5d88eeed0a9ddbfb9a08947819f0df224fb Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 13 Jul 2022 16:31:12 +0000 Subject: [PATCH 17/82] more refactoring - fixed init empty weights issue - added threshold param --- src/transformers/modeling_utils.py | 6 ++++-- src/transformers/utils/bitsandbytes.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ebcb157a58396c..791dafdf901f7e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -553,7 +553,6 @@ def _load_state_dict_into_meta_model( # TODO: group all errors and raise at the end. raise ValueError(f"{param_name} doesn't have any device set.") param_device = device_map[module_name] - if param_device == "disk": offload_index = offload_weight(param, param_name, offload_folder, offload_index) elif param_device == "cpu" and state_dict_index is not None: @@ -1784,6 +1783,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) load_in_8bit = kwargs.pop("load_in_8bit", False) + int8_threshold = kwargs.pop("int8_threshold", None) if device_map is not None: if low_cpu_mem_usage is None: @@ -1810,6 +1810,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of" " bitsandbytes `pip install bitsandbytes`" ) + if not int8_threshold: + int8_threshold = 6.0 from_pt = not (from_tf | from_flax) @@ -2088,7 +2090,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls(config, *model_args, **model_kwargs) if load_in_8bit: - model = replace_8bit_linear(model) + model = replace_8bit_linear(model, threshold=int8_threshold) if device_map == "auto": if model._no_split_modules is None: diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index cab643806afd11..84430c992a00f3 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -80,11 +80,11 @@ class `Int8Params` from `bitsandbytes`. module._parameters[tensor_name] = new_value -def replace_8bit_linear(model): +def replace_8bit_linear(model, threshold=6.0): for n, module in model.named_children(): if len(list(module.children())) > 0: - replace_8bit_linear(module) + replace_8bit_linear(module, threshold) if isinstance(module, nn.Linear) and n != "lm_head": with init_empty_weights(): @@ -93,6 +93,6 @@ def replace_8bit_linear(model): module.out_features, module.bias is not None, has_fp16_weights=False, - threshold=6.0, + threshold=threshold, ) return model From b41c2509db5e934214d01a465075e78d9c72dede Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Jul 2022 09:30:35 +0000 Subject: [PATCH 18/82] small hack to make it work --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 791dafdf901f7e..3e26aa08a6f15e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2163,7 +2163,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model.eval() # Dispatch model with hooks on all devices if necessary - if device_map is not None: + if device_map is not None and not load_in_8bit: dispatch_model(model, device_map=device_map, offload_dir=offload_folder) if output_loading_info: From c91a58e5c5e5e462ef155693db4976c4f86d3b6f Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 20 Jul 2022 13:31:23 +0200 Subject: [PATCH 19/82] Update src/transformers/modeling_utils.py --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 94c52e12ae3f5b..898f7b98a252a0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1788,7 +1788,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_state_dict = kwargs.pop("offload_state_dict", False) load_in_8bit = kwargs.pop("load_in_8bit", False) int8_threshold = kwargs.pop("int8_threshold", None) - +subfolder = kwargs.pop("subfolder", "") if device_map is not None: if low_cpu_mem_usage is None: low_cpu_mem_usage = True From db16cf88fe1b30899a488e870604b8e496ce59dd Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 20 Jul 2022 13:31:42 +0200 Subject: [PATCH 20/82] Update src/transformers/modeling_utils.py --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 898f7b98a252a0..68a48ee6d3085a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1788,7 +1788,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_state_dict = kwargs.pop("offload_state_dict", False) load_in_8bit = kwargs.pop("load_in_8bit", False) int8_threshold = kwargs.pop("int8_threshold", None) -subfolder = kwargs.pop("subfolder", "") + subfolder = kwargs.pop("subfolder", "") if device_map is not None: if low_cpu_mem_usage is None: low_cpu_mem_usage = True From be6ce291d0a4214100aa5b19b8aa88cf26318926 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 21 Jul 2022 16:13:27 +0000 Subject: [PATCH 21/82] revmoe the small hack --- src/transformers/modeling_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3e26aa08a6f15e..343c57b09be638 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -75,6 +75,7 @@ if is_accelerate_available(): + import accelerate from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights from accelerate.utils import ( load_offloaded_weights, @@ -84,7 +85,7 @@ ) if is_bitsandbytes_available(): - from .utils.bitsandbytes import replace_8bit_linear, set_module_8bit_tensor_to_device + from .utils.bitsandbytes import replace_8bit_linear, set_module_8bit_tensor_to_device, replace_set_tensor_function logger = logging.get_logger(__name__) @@ -2163,9 +2164,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model.eval() # Dispatch model with hooks on all devices if necessary - if device_map is not None and not load_in_8bit: + if device_map: dispatch_model(model, device_map=device_map, offload_dir=offload_folder) + if output_loading_info: if loading_info is None: loading_info = { From 15a81e0ef1fa575edb98e2fbbb7884c83bfdf80b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 21 Jul 2022 16:14:40 +0000 Subject: [PATCH 22/82] modify utils file --- src/transformers/utils/bitsandbytes.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 84430c992a00f3..8252ae987586c1 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from transformers.utils import is_accelerate_available, is_bitsandbytes_available @@ -8,7 +9,9 @@ import bitsandbytes as bnb if is_accelerate_available(): + import accelerate from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): @@ -96,3 +99,8 @@ def replace_8bit_linear(model, threshold=6.0): threshold=threshold, ) return model + +@contextmanager +def replace_set_tensor_function(): + setattr(accelerate.utils, "set_module_tensor_to_device", set_module_8bit_tensor_to_device) + yield \ No newline at end of file From 514758dd6d39202ca45c652a28e34f4adf893dd8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 22 Jul 2022 14:49:42 +0000 Subject: [PATCH 23/82] make style + refactor a bit --- src/transformers/modeling_utils.py | 4 +--- src/transformers/utils/bitsandbytes.py | 8 -------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4bb178ca1024fb..3d57ce77792fe1 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -75,7 +75,6 @@ if is_accelerate_available(): - import accelerate from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights from accelerate.utils import ( load_offloaded_weights, @@ -85,7 +84,7 @@ ) if is_bitsandbytes_available(): - from .utils.bitsandbytes import replace_8bit_linear, set_module_8bit_tensor_to_device, replace_set_tensor_function + from .utils.bitsandbytes import replace_8bit_linear, set_module_8bit_tensor_to_device logger = logging.get_logger(__name__) @@ -2186,7 +2185,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if device_map: dispatch_model(model, device_map=device_map, offload_dir=offload_folder) - if output_loading_info: if loading_info is None: loading_info = { diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 8252ae987586c1..84430c992a00f3 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from transformers.utils import is_accelerate_available, is_bitsandbytes_available @@ -9,9 +8,7 @@ import bitsandbytes as bnb if is_accelerate_available(): - import accelerate from accelerate import init_empty_weights - from accelerate.utils import set_module_tensor_to_device def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): @@ -99,8 +96,3 @@ def replace_8bit_linear(model, threshold=6.0): threshold=threshold, ) return model - -@contextmanager -def replace_set_tensor_function(): - setattr(accelerate.utils, "set_module_tensor_to_device", set_module_8bit_tensor_to_device) - yield \ No newline at end of file From a09e0556d1b9ddd713fd60447a624147abf9c7a0 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 25 Jul 2022 10:38:06 +0000 Subject: [PATCH 24/82] create correctly device map --- src/transformers/modeling_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3d57ce77792fe1..c3ca280260cda2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2225,6 +2225,7 @@ def _load_pretrained_model( if offload_state_dict is None: offload_state_dict = True + print(device_map) # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) From 387aa1eb1987ff104f9e1137733c15fcc68ec45b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 25 Jul 2022 10:56:54 +0000 Subject: [PATCH 25/82] add correct dtype for device map creation --- src/transformers/modeling_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c3ca280260cda2..d3a1b20cf56e64 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2118,7 +2118,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Make sure tied weights are tied before creating the device map. model.tie_weights() device_map = infer_auto_device_map( - model, no_split_module_classes=no_split_modules, dtype=torch_dtype, max_memory=max_memory + model, + no_split_module_classes=no_split_modules, + dtype=torch_dtype if not load_in_8bit else torch.int8, + max_memory=max_memory, ) if from_tf: @@ -2225,7 +2228,6 @@ def _load_pretrained_model( if offload_state_dict is None: offload_state_dict = True - print(device_map) # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) From 7199292efee9fe056551f29226420a613dbd4226 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 27 Jul 2022 11:05:09 +0200 Subject: [PATCH 26/82] Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d3a1b20cf56e64..8022f5ca9c5235 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1787,7 +1787,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) load_in_8bit = kwargs.pop("load_in_8bit", False) - int8_threshold = kwargs.pop("int8_threshold", None) + int8_threshold = kwargs.pop("int8_threshold", 6.0) subfolder = kwargs.pop("subfolder", "") if device_map is not None: if low_cpu_mem_usage is None: @@ -1814,8 +1814,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of" " bitsandbytes `pip install bitsandbytes`" ) - if not int8_threshold: - int8_threshold = 6.0 from_pt = not (from_tf | from_flax) From a4c19c193b598efae80f2489298c3aeeacb50804 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 27 Jul 2022 09:11:03 +0000 Subject: [PATCH 27/82] apply suggestions - remove with torch.grad - do not rely on Python bool magic! --- src/transformers/modeling_utils.py | 2 +- src/transformers/utils/bitsandbytes.py | 27 ++++++++++++-------------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8022f5ca9c5235..5284d3c9281ed9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2183,7 +2183,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model.eval() # Dispatch model with hooks on all devices if necessary - if device_map: + if device_map is not None: dispatch_model(model, device_map=device_map, offload_dir=offload_folder) if output_loading_info: diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 84430c992a00f3..cb237025c0012a 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -54,25 +54,22 @@ class `Int8Params` from `bitsandbytes`. if has_fp16_weights is not None: param = module._parameters[tensor_name] if param.device.type != "cuda": - with torch.no_grad(): - if value is None: - new_value = old_value.to(device) - elif isinstance(value, torch.Tensor): - new_value = value.to("cpu") - else: - new_value = torch.tensor(value, device="cpu") - new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to( - device - ) - module._parameters[tensor_name] = new_value - else: - with torch.no_grad(): if value is None: new_value = old_value.to(device) elif isinstance(value, torch.Tensor): - new_value = value.to(device) + new_value = value.to("cpu") else: - new_value = torch.tensor(value, device=device) + new_value = torch.tensor(value, device="cpu") + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device) + module._parameters[tensor_name] = new_value + else: + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to(device) + else: + new_value = torch.tensor(value, device=device) + if is_buffer: module._buffers[tensor_name] = new_value else: From a5cd157b03311d99661c27ea72e6348406ffdccd Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 27 Jul 2022 09:20:00 +0000 Subject: [PATCH 28/82] add docstring - add docstring for new kwargs --- src/transformers/modeling_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5284d3c9281ed9..4d31b8d2ec181a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1697,6 +1697,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` when there is some disk offload. + load_in_8bit (`bool`, *optional*): + If `True`, will convert the loaded model into mixed-8bit quantized model. To use this feature please + install `bitsandbytes` compiled with your CUDA version by running `pip install -i + https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116). + Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are + not compiled and adapted for CPUs. + int8_threshold (`int`, *optional*): + Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as + described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper (#TODO provide + link to paper). Set to `6.0` by default as described in the paper. subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. From 9bd326b7eaf758e19f12166ec0bbba1906d53062 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 27 Jul 2022 09:29:00 +0000 Subject: [PATCH 29/82] add docstring - comment `replace_8bit_linear` function - fix weird formatting --- src/transformers/utils/bitsandbytes.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index cb237025c0012a..a7d1c5008ef1d6 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -21,8 +21,8 @@ class `Int8Params` from `bitsandbytes`. Args: module (`torch.nn.Module`): The module in which the tensor we want to move lives. - tensor_name (`str`): The full - name of the parameter/buffer. + tensor_name (`str`): + The full name of the parameter/buffer. device (`int`, `str` or `torch.device`): The device on which to set the tensor. value (`torch.Tensor`, *optional*): @@ -78,7 +78,24 @@ class `Int8Params` from `bitsandbytes`. def replace_8bit_linear(model, threshold=6.0): - + """ + A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes` + library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): + 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA + version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/ + bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116) + + The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should + be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no + CPU/GPU memory is required to run this function. + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + threshold (`float`, *optional*): + `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to + `6.0` as described by the paper. + """ for n, module in model.named_children(): if len(list(module.children())) > 0: replace_8bit_linear(module, threshold) From 6e4fee622ee15440ffc3484ba051dbc5e983408c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 28 Jul 2022 13:24:24 +0000 Subject: [PATCH 30/82] - added more documentation - added new utility function for memory footprint tracking - colab demo to add --- docs/source/en/main_classes/model.mdx | 24 +++++++++++++++++++++++- src/transformers/modeling_utils.py | 18 ++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index c59af2d2214814..d41a11360a805b 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -105,7 +105,7 @@ You can also write your own device map following the same format (a dictionary l device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1} ``` -Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`). +Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`) or use direct quantization techniques as described below. ### Model Instantiation dtype @@ -133,6 +133,28 @@ model = AutoModel.from_config(config) Due to Pytorch design, this functionality is only available for floating dtypes. +### Mixed int8 quantization đŸĒļ + +From the paper `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale`, we suport HuggingFace 🤗 integration for all models in the Hub with few lines of code. +For models trained in half-precision (aka, either `float16` or `bfloat16`). This method aims to reduce `nn.Linear` size by 2, without affecting too much quality by operating on the outliers in half-precision. Therefore the method is termed as mixed 8bit quantization. +This technique is useful and works well for billion scale models (>1B parameters) therefore we advice you to use it only for models of that scale. This method has been tested for 2-billion to 176-billion scale models and supports onyl PyTorch models. +Note also that you would require a GPU to run mixed-8bit models as the kernels has been compiled for GPUs only. Make sure that you have enough GPU RAM to store the quarter (or half if your model is natively in half precision) of the model before using this feature. + +#### Requirements + +- Install the correct version of `bitsandbytes` by running: +`pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` with `XXX` corresponding to your `CUDA` version. E.g, if you have `CUDA` 11.3 installed on your machine, run `pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113` +- Install `accelerate`: +`pip install accelerate` + +#### Running mixed-int8 models + +After carefully installing the required libraries, the way to load your mixed 8-bit model is as follows: +``` +model_name = "bigscience/bloom-2b5" +model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, torch_dtype=torch.float16) +``` +The implementation supports multi-GPU setup thanks to `accelerate` as backend. ## ModuleUtilsMixin diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4d31b8d2ec181a..6ab4caabe531be 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1570,6 +1570,24 @@ def save_pretrained( url = self._push_to_hub(repo, commit_message=commit_message) logger.info(f"Model pushed to the hub in this commit: {url}") + def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. + Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the + PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*): + Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers + are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch + norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): r""" From 9e121b73b22134763daa0b880d0a2a225fc949f5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 28 Jul 2022 13:45:38 +0000 Subject: [PATCH 31/82] few modifs - typo doc - force cast into float16 when load_in_8bit is enabled --- docs/source/en/main_classes/model.mdx | 4 ++-- src/transformers/modeling_utils.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index d41a11360a805b..aa772bd0f43752 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -137,7 +137,7 @@ Due to Pytorch design, this functionality is only available for floating dtypes. From the paper `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale`, we suport HuggingFace 🤗 integration for all models in the Hub with few lines of code. For models trained in half-precision (aka, either `float16` or `bfloat16`). This method aims to reduce `nn.Linear` size by 2, without affecting too much quality by operating on the outliers in half-precision. Therefore the method is termed as mixed 8bit quantization. -This technique is useful and works well for billion scale models (>1B parameters) therefore we advice you to use it only for models of that scale. This method has been tested for 2-billion to 176-billion scale models and supports onyl PyTorch models. +This technique is useful and works well for billion scale models (>1B parameters) therefore we advice you to use it only for models of that scale. This method has been tested for 2-billion to 176-billion scale models and supports only PyTorch models. Note also that you would require a GPU to run mixed-8bit models as the kernels has been compiled for GPUs only. Make sure that you have enough GPU RAM to store the quarter (or half if your model is natively in half precision) of the model before using this feature. #### Requirements @@ -152,7 +152,7 @@ Note also that you would require a GPU to run mixed-8bit models as the kernels h After carefully installing the required libraries, the way to load your mixed 8-bit model is as follows: ``` model_name = "bigscience/bloom-2b5" -model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, torch_dtype=torch.float16) +model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True) ``` The implementation supports multi-GPU setup thanks to `accelerate` as backend. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6ab4caabe531be..479ad33505d58d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1842,6 +1842,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of" " bitsandbytes `pip install bitsandbytes`" ) + if torch_dtype == "auto" or torch_dtype != torch.float16: + torch_dtype = ( + torch.float16 + ) # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16") from_pt = not (from_tf | from_flax) From a2ac688f1325d0fb3f695f6f3b2107d43da1f756 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 28 Jul 2022 14:01:20 +0000 Subject: [PATCH 32/82] added colab link --- docs/source/en/main_classes/model.mdx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index aa772bd0f43752..741368d8e0d7ee 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -140,6 +140,8 @@ For models trained in half-precision (aka, either `float16` or `bfloat16`). Thi This technique is useful and works well for billion scale models (>1B parameters) therefore we advice you to use it only for models of that scale. This method has been tested for 2-billion to 176-billion scale models and supports only PyTorch models. Note also that you would require a GPU to run mixed-8bit models as the kernels has been compiled for GPUs only. Make sure that you have enough GPU RAM to store the quarter (or half if your model is natively in half precision) of the model before using this feature. +Below are some notes to help you use this module, or follow this demo on Google colab: [![Open In Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing) + #### Requirements - Install the correct version of `bitsandbytes` by running: From ac370b9893f011d4af62dd66a552423de7a3e180 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 28 Jul 2022 18:06:12 +0000 Subject: [PATCH 33/82] add test architecture + docstring a bit --- src/transformers/modeling_utils.py | 3 +- tests/mixed_int8/__init__.py | 0 tests/mixed_int8/test_mixed_int8.py | 104 ++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 tests/mixed_int8/__init__.py create mode 100644 tests/mixed_int8/test_mixed_int8.py diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 479ad33505d58d..4fea35161acada 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1840,7 +1840,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if not (is_accelerate_available() and is_bitsandbytes_available()): raise ImportError( "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of" - " bitsandbytes `pip install bitsandbytes`" + " bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` with `XXX` corresponding" + "to your CUDA version (e.g. `pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113` for CUDA 11.3)" ) if torch_dtype == "auto" or torch_dtype != torch.float16: torch_dtype = ( diff --git a/tests/mixed_int8/__init__.py b/tests/mixed_int8/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py new file mode 100644 index 00000000000000..53c9ae6141e86b --- /dev/null +++ b/tests/mixed_int8/test_mixed_int8.py @@ -0,0 +1,104 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +from transformers.testing_utils import ( + require_torch, + slow, + require_bitsandbytes, + require_accelerate, + require_torch_gpu, + require_torch_multi_gpu, +) + + +@require_bitsandbytes +@require_accelerate +@require_torch +@require_torch_gpu +@slow +class MixedInt8Test(unittest.TestCase): + # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) + # Therefore here we use only bloom-1b3 to test our module + model_name = "bigscience/bloom-1b3" + + # Constant values + EXPECTED_RELATIVE_DIFFERENCE = ( + 1.540025 # This was obtained on a Quadro RTX 8000 so the number might slightly change + ) + input_text = "Hello my name is" + EXPECTED_OUTPUT = "Hello my name is John.\nI am a friend of your father.\n" + MAX_NEW_TOKENS = 10 + + # Models pipeline and tokenizer + model_fp16 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto") + model_8bit = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto") + tokenizer = AutoTokenizer.from_pretrained(model_name) + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + from bitsandbytes.nn import Int8Params + + mem_fp16 = self.model_fp16.get_memory_footprint() + mem_8bit = self.model_8bit.get_memory_footprint() + + self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE) + self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) + + def test_generate_quality(self): + r""" + Test the generation quality of the quantized model and see that we are matching the expected output. + Given that we are operating on small numbers + the testing model is relatively small, we might not get + the same output across GPUs. So we'll generate few tokens (5-10) and check their output. + """ + + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + output_sequences = self.model_8bit.generate(input_ids=encoded_input["input_ids"].cuda(), max_new_tokens=10) + + self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_pipeline(self): + r""" + The aim of this test is to verify that the mixed int8 is compatible with `pipeline` from transformers. Since + we used pipline for inference speed benchmarking we want to make sure that this feature does not break anything + on pipline. + """ + + pipe = pipeline( + model=self.model_name, + model_kwargs={"device_map": "auto", "load_in_8bit": True}, + max_new_tokens=self.MAX_NEW_TOKENS, + ) + pipeline_output = pipe(self.input_text) + self.assertEqual(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUT) + + def test_save_load(self): + r""" + The aim of this test is to verify whether if we save and load back a quantized model we retain the same performance. + If this test pass people can safely push quantized models on the Hub. + """ + pass + + @require_torch_multi_gpu + def test_multi_gpu_loading(self): + r""" + This tests that the model has been loaded and can be used correctly on a multi-GPU setup + """ + pass From 27e94864e4bfb07eebbd4f84f76ee5381e4d0078 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 28 Jul 2022 18:11:47 +0000 Subject: [PATCH 34/82] refactor a bit testing class --- tests/mixed_int8/test_mixed_int8.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 53c9ae6141e86b..c0e1c4fe6ca4cc 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -32,6 +32,8 @@ @require_torch_gpu @slow class MixedInt8Test(unittest.TestCase): + # We keep the constants inside the init function and model loading inside setUp function + # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) # Therefore here we use only bloom-1b3 to test our module model_name = "bigscience/bloom-1b3" @@ -40,14 +42,16 @@ class MixedInt8Test(unittest.TestCase): EXPECTED_RELATIVE_DIFFERENCE = ( 1.540025 # This was obtained on a Quadro RTX 8000 so the number might slightly change ) + input_text = "Hello my name is" EXPECTED_OUTPUT = "Hello my name is John.\nI am a friend of your father.\n" MAX_NEW_TOKENS = 10 - - # Models pipeline and tokenizer - model_fp16 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto") - model_8bit = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto") - tokenizer = AutoTokenizer.from_pretrained(model_name) + def setUp(self): + # Models pipeline and tokenizer + self.model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto") + self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + return super().setUp() def test_memory_footprint(self): r""" From a0db9828de665bdff6db78f3864c51e1f4f7c47f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 28 Jul 2022 22:05:57 +0200 Subject: [PATCH 35/82] make style + refactor a bit --- src/transformers/modeling_utils.py | 5 +++-- tests/mixed_int8/test_mixed_int8.py | 7 ++++--- utils/tests_fetcher.py | 1 + 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4fea35161acada..3d98f0e4468349 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1840,8 +1840,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if not (is_accelerate_available() and is_bitsandbytes_available()): raise ImportError( "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of" - " bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` with `XXX` corresponding" - "to your CUDA version (e.g. `pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113` for CUDA 11.3)" + " bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` with `XXX`" + " correspondingto your CUDA version (e.g. `pip install -i https://test.pypi.org/simple/" + " bitsandbytes-cuda113` for CUDA 11.3)" ) if torch_dtype == "auto" or torch_dtype != torch.float16: torch_dtype = ( diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index c0e1c4fe6ca4cc..fbbf0cbbc9a0fa 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -17,12 +17,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from transformers.testing_utils import ( - require_torch, - slow, - require_bitsandbytes, require_accelerate, + require_bitsandbytes, + require_torch, require_torch_gpu, require_torch_multi_gpu, + slow, ) @@ -46,6 +46,7 @@ class MixedInt8Test(unittest.TestCase): input_text = "Hello my name is" EXPECTED_OUTPUT = "Hello my name is John.\nI am a friend of your father.\n" MAX_NEW_TOKENS = 10 + def setUp(self): # Models pipeline and tokenizer self.model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto") diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 9f18bb83c7ee7f..2b33945aa37686 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -464,6 +464,7 @@ def module_to_test_file(module_fname): "tests/sagemaker/test_single_node_gpu.py", # SageMaker test "tests/sagemaker/test_multi_node_model_parallel.py", # SageMaker test "tests/sagemaker/test_multi_node_data_parallel.py", # SageMaker test + "tests/mixed_int8/test_mixed_int8.py", # SageMaker test ] From 21bd590862ac9688ad783191aeb23fff177119bf Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 29 Jul 2022 09:10:48 +0000 Subject: [PATCH 36/82] enhance checks - add more checks - start writing saving test --- src/transformers/modeling_utils.py | 10 ++++++++++ tests/mixed_int8/test_mixed_int8.py | 16 +++++++++++----- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4fea35161acada..ff608984ea47f7 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1848,6 +1848,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P torch.float16 ) # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16") + if device_map is None: + raise ValueError( + "A device map needs to be passed to run convert models into mixed-int8 format. Please run" + "`.from_pretrained` with `device_map='auto'`" + ) + if (from_tf | from_flax): + raise ValueError( + "Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make sure" + "the weights are in PyTorch format." + ) from_pt = not (from_tf | from_flax) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index c0e1c4fe6ca4cc..24aaeef0104641 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tempfile import unittest from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from transformers.testing_utils import ( - require_torch, - slow, - require_bitsandbytes, require_accelerate, + require_bitsandbytes, + require_torch, require_torch_gpu, require_torch_multi_gpu, + slow, ) @@ -30,7 +31,7 @@ @require_accelerate @require_torch @require_torch_gpu -@slow +# @slow class MixedInt8Test(unittest.TestCase): # We keep the constants inside the init function and model loading inside setUp function @@ -46,6 +47,7 @@ class MixedInt8Test(unittest.TestCase): input_text = "Hello my name is" EXPECTED_OUTPUT = "Hello my name is John.\nI am a friend of your father.\n" MAX_NEW_TOKENS = 10 + def setUp(self): # Models pipeline and tokenizer self.model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto") @@ -98,7 +100,11 @@ def test_save_load(self): The aim of this test is to verify whether if we save and load back a quantized model we retain the same performance. If this test pass people can safely push quantized models on the Hub. """ - pass + with tempfile.TemporaryDirectory() as tmpdirname: + # Save and load 8bit model + self.model_8bit.save_pretrained(tmpdirname) + loaded_model_8bit = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto") + print(loaded_model_8bit) @require_torch_multi_gpu def test_multi_gpu_loading(self): From 9b81c6768093636e1b85e922c9ec4840d5648c52 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 29 Jul 2022 09:15:20 +0000 Subject: [PATCH 37/82] clean up a bit --- src/transformers/modeling_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d1fc063e4bcd15..ecffb8f29dda8b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1845,9 +1845,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " bitsandbytes-cuda113` for CUDA 11.3)" ) if torch_dtype == "auto" or torch_dtype != torch.float16: - torch_dtype = ( - torch.float16 - ) # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + torch_dtype = torch.float16 # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16") if device_map is None: raise ValueError( From 659f4276b7021eb35ae37732520a4ec5191e634b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 29 Jul 2022 11:24:40 +0200 Subject: [PATCH 38/82] male style --- src/transformers/modeling_utils.py | 10 ++++++---- tests/mixed_int8/test_mixed_int8.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ecffb8f29dda8b..0c93837ebc2594 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1845,17 +1845,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " bitsandbytes-cuda113` for CUDA 11.3)" ) if torch_dtype == "auto" or torch_dtype != torch.float16: - torch_dtype = torch.float16 # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + torch_dtype = ( + torch.float16 + ) # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16") if device_map is None: raise ValueError( "A device map needs to be passed to run convert models into mixed-int8 format. Please run" "`.from_pretrained` with `device_map='auto'`" ) - if (from_tf | from_flax): + if from_tf | from_flax: raise ValueError( - "Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make sure" - "the weights are in PyTorch format." + "Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make" + " surethe weights are in PyTorch format." ) from_pt = not (from_tf | from_flax) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 24aaeef0104641..35652cd264e844 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -31,7 +31,7 @@ @require_accelerate @require_torch @require_torch_gpu -# @slow +@slow class MixedInt8Test(unittest.TestCase): # We keep the constants inside the init function and model loading inside setUp function From 147683ee5b3da97eab6d223d65889ff6c43b68a8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 29 Jul 2022 12:47:45 +0200 Subject: [PATCH 39/82] add more details on doc --- docs/source/en/main_classes/model.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index 741368d8e0d7ee..9cfa922f077740 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -144,6 +144,7 @@ Below are some notes to help you use this module, or follow this demo on Google #### Requirements +- Make sure you run that on a NVIDIA T4 or A100 GPU that supports 8-bit tensor cores. Note that previous generations of NVIDIA GPUs do not support 8-bit tensor cores. - Install the correct version of `bitsandbytes` by running: `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` with `XXX` corresponding to your `CUDA` version. E.g, if you have `CUDA` 11.3 installed on your machine, run `pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113` - Install `accelerate`: From c2e19180ee51d470cae9a70c26eece6108ba126c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 1 Aug 2022 08:11:11 +0000 Subject: [PATCH 40/82] add more tests - still needs to fix 2 tests --- tests/mixed_int8/test_mixed_int8.py | 50 ++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 35652cd264e844..97d8fdf13ba097 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -16,6 +16,8 @@ import tempfile import unittest +import torch + from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from transformers.testing_utils import ( require_accelerate, @@ -49,11 +51,17 @@ class MixedInt8Test(unittest.TestCase): MAX_NEW_TOKENS = 10 def setUp(self): - # Models pipeline and tokenizer + # Models and tokenizer self.model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto") self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) - return super().setUp() + + def tearDown(self): + # This needs to be done because after each test we need to free the GPU memories otherwise + # we get some unexpected behaviors (CUDA illegal memory access errors) + # This is fine since the model is loaded from the cache + self.model_8bit = None + self.model_fp16 = None def test_memory_footprint(self): r""" @@ -101,14 +109,46 @@ def test_save_load(self): If this test pass people can safely push quantized models on the Hub. """ with tempfile.TemporaryDirectory() as tmpdirname: + # create a dummy tensor + input_ids = torch.LongTensor([[1, 2, 3, 4]]).to(0) + # Save and load 8bit model self.model_8bit.save_pretrained(tmpdirname) loaded_model_8bit = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto") - print(loaded_model_8bit) + + # Get both logits + logits_loaded = loaded_model_8bit(input_ids).logits + logits_native = self.model_8bit(input_ids).logits + + # TODO: @younesbelkada understand why the test does not pass + @require_torch_multi_gpu def test_multi_gpu_loading(self): r""" - This tests that the model has been loaded and can be used correctly on a multi-GPU setup + This tests that the model has been loaded and can be used correctly on a multi-GPU setup. + Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB of total, 3GB should suffice """ - pass + memory_mapping = {0:"1GB", 1:"2GB"} + model_parallel = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto") + + # TODO: @younesbelkada correct this function + def get_list_devices(model): + list_devices = [] + for n, module in model.named_children(): + if len(list(module.children())) > 0: + device = get_list_devices(module) + if device not in list_devices: + list_devices.append(device) + else: + return module.parameters().device.index + return list_devices + + list_devices = get_list_devices(model_parallel) + self.assertEqual(len(list_devices), 2) + + # Do also a quality test and check everything is correct + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + output_sequences = model_parallel.generate(input_ids=encoded_input["input_ids"].cuda(), max_new_tokens=10) + self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + From f6eb9456a8459cec5774bfa22b377c7e0ffc859f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 1 Aug 2022 08:12:09 +0000 Subject: [PATCH 41/82] replace by "or" - could not fix it from GitHub GUI Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0c93837ebc2594..f175b0925f7da3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1854,7 +1854,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "A device map needs to be passed to run convert models into mixed-int8 format. Please run" "`.from_pretrained` with `device_map='auto'`" ) - if from_tf | from_flax: + if from_tf or from_flax: raise ValueError( "Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make" " surethe weights are in PyTorch format." From ceff43e3ef0f0f2f6120de64a55232d0864102e2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 1 Aug 2022 09:39:47 +0000 Subject: [PATCH 42/82] refactor a bit testing code + add readme --- tests/mixed_int8/test_mixed_int8.py | 51 +++++++++++++---------------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 97d8fdf13ba097..14cb8e554d9e3d 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -52,17 +52,8 @@ class MixedInt8Test(unittest.TestCase): def setUp(self): # Models and tokenizer - self.model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto") - self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) - def tearDown(self): - # This needs to be done because after each test we need to free the GPU memories otherwise - # we get some unexpected behaviors (CUDA illegal memory access errors) - # This is fine since the model is loaded from the cache - self.model_8bit = None - self.model_fp16 = None - def test_memory_footprint(self): r""" A simple test to check if the model conversion has been done correctly by checking on the @@ -70,11 +61,14 @@ def test_memory_footprint(self): """ from bitsandbytes.nn import Int8Params - mem_fp16 = self.model_fp16.get_memory_footprint() - mem_8bit = self.model_8bit.get_memory_footprint() + model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto") + model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + + mem_fp16 = model_fp16.get_memory_footprint() + mem_8bit = model_8bit.get_memory_footprint() self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE) - self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) + self.assertTrue(model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) def test_generate_quality(self): r""" @@ -82,9 +76,11 @@ def test_generate_quality(self): Given that we are operating on small numbers + the testing model is relatively small, we might not get the same output across GPUs. So we'll generate few tokens (5-10) and check their output. """ + model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = self.model_8bit.generate(input_ids=encoded_input["input_ids"].cuda(), max_new_tokens=10) + output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].cuda(), max_new_tokens=10) self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) @@ -94,7 +90,6 @@ def test_pipeline(self): we used pipline for inference speed benchmarking we want to make sure that this feature does not break anything on pipline. """ - pipe = pipeline( model=self.model_name, model_kwargs={"device_map": "auto", "load_in_8bit": True}, @@ -108,17 +103,19 @@ def test_save_load(self): The aim of this test is to verify whether if we save and load back a quantized model we retain the same performance. If this test pass people can safely push quantized models on the Hub. """ + model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + with tempfile.TemporaryDirectory() as tmpdirname: # create a dummy tensor input_ids = torch.LongTensor([[1, 2, 3, 4]]).to(0) # Save and load 8bit model - self.model_8bit.save_pretrained(tmpdirname) + model_8bit.save_pretrained(tmpdirname) loaded_model_8bit = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto") # Get both logits logits_loaded = loaded_model_8bit(input_ids).logits - logits_native = self.model_8bit(input_ids).logits + logits_native = model_8bit(input_ids).logits # TODO: @younesbelkada understand why the test does not pass @@ -132,23 +129,21 @@ def test_multi_gpu_loading(self): memory_mapping = {0:"1GB", 1:"2GB"} model_parallel = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto") - # TODO: @younesbelkada correct this function def get_list_devices(model): list_devices = [] - for n, module in model.named_children(): + for _, module in model.named_children(): if len(list(module.children())) > 0: - device = get_list_devices(module) - if device not in list_devices: - list_devices.append(device) + list_devices.extend(get_list_devices(module)) else: - return module.parameters().device.index + # Do a try except since we can encounter Dropout modules that does not + # have any device set + try: + list_devices.append(next(module.parameters()).device.index) + except: + continue return list_devices list_devices = get_list_devices(model_parallel) - self.assertEqual(len(list_devices), 2) - - # Do also a quality test and check everything is correct - encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_parallel.generate(input_ids=encoded_input["input_ids"].cuda(), max_new_tokens=10) - self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + # Check that we have dispatched the model into 2 separate devices + self.assertTrue((1 in list_devices) and (0 in list_devices)) From 55cec55ed575708762b2c33141f459a6680aef27 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 1 Aug 2022 09:47:00 +0000 Subject: [PATCH 43/82] make style --- tests/mixed_int8/README.md | 13 +++++++++++++ tests/mixed_int8/test_mixed_int8.py | 19 +++++++++---------- 2 files changed, 22 insertions(+), 10 deletions(-) create mode 100644 tests/mixed_int8/README.md diff --git a/tests/mixed_int8/README.md b/tests/mixed_int8/README.md new file mode 100644 index 00000000000000..9f0da85a847153 --- /dev/null +++ b/tests/mixed_int8/README.md @@ -0,0 +1,13 @@ +# Testing mixed int8 quantization + +## Hardware requirements + +I am using a setup of 2 GPUs that are NVIDIA-Tesla T4 15GB + +## Virutal envs + +```conda create --name int8-testing python==3.8``` +```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit``` +```pip install -e ".[dev]"``` +```pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda114``` +```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1``` \ No newline at end of file diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 14cb8e554d9e3d..dafe80db61d839 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -62,7 +62,7 @@ def test_memory_footprint(self): from bitsandbytes.nn import Int8Params model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto") - model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") mem_fp16 = model_fp16.get_memory_footprint() mem_8bit = model_8bit.get_memory_footprint() @@ -77,7 +77,6 @@ def test_generate_quality(self): the same output across GPUs. So we'll generate few tokens (5-10) and check their output. """ model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") - encoded_input = self.tokenizer(self.input_text, return_tensors="pt") output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].cuda(), max_new_tokens=10) @@ -104,11 +103,11 @@ def test_save_load(self): If this test pass people can safely push quantized models on the Hub. """ model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") - + with tempfile.TemporaryDirectory() as tmpdirname: # create a dummy tensor input_ids = torch.LongTensor([[1, 2, 3, 4]]).to(0) - + # Save and load 8bit model model_8bit.save_pretrained(tmpdirname) loaded_model_8bit = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto") @@ -118,7 +117,6 @@ def test_save_load(self): logits_native = model_8bit(input_ids).logits # TODO: @younesbelkada understand why the test does not pass - @require_torch_multi_gpu def test_multi_gpu_loading(self): @@ -126,8 +124,10 @@ def test_multi_gpu_loading(self): This tests that the model has been loaded and can be used correctly on a multi-GPU setup. Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB of total, 3GB should suffice """ - memory_mapping = {0:"1GB", 1:"2GB"} - model_parallel = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto") + memory_mapping = {0: "1GB", 1: "2GB"} + model_parallel = AutoModelForCausalLM.from_pretrained( + self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto" + ) def get_list_devices(model): list_devices = [] @@ -135,15 +135,14 @@ def get_list_devices(model): if len(list(module.children())) > 0: list_devices.extend(get_list_devices(module)) else: - # Do a try except since we can encounter Dropout modules that does not + # Do a try except since we can encounter Dropout modules that does not # have any device set try: list_devices.append(next(module.parameters()).device.index) except: continue return list_devices - + list_devices = get_list_devices(model_parallel) # Check that we have dispatched the model into 2 separate devices self.assertTrue((1 in list_devices) and (0 in list_devices)) - From 67bf4fbadfc110242d535ce3ab5397d7c0d9f555 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 1 Aug 2022 10:25:36 +0000 Subject: [PATCH 44/82] fix import issue --- tests/mixed_int8/test_mixed_int8.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index dafe80db61d839..d5a93e6af1ab49 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -16,8 +16,6 @@ import tempfile import unittest -import torch - from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from transformers.testing_utils import ( require_accelerate, @@ -27,6 +25,11 @@ require_torch_multi_gpu, slow, ) +from transformers.utils.import_utils import is_torch_available + + +if is_torch_available(): + import torch @require_bitsandbytes From 56e91470e3a90a968c87c437d94903bdf207ce05 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 1 Aug 2022 15:29:16 +0200 Subject: [PATCH 45/82] Update src/transformers/modeling_utils.py Co-authored-by: Michael Benayoun --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f175b0925f7da3..651a2277588923 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1857,7 +1857,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if from_tf or from_flax: raise ValueError( "Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make" - " surethe weights are in PyTorch format." + " sure the weights are in PyTorch format." ) from_pt = not (from_tf | from_flax) From 5a03a86dd2e49c1438a2af060b629e0bf88e0230 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 1 Aug 2022 13:30:26 +0000 Subject: [PATCH 46/82] add few comments --- tests/mixed_int8/test_mixed_int8.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index d5a93e6af1ab49..b68bc169e2ff31 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -36,7 +36,7 @@ @require_accelerate @require_torch @require_torch_gpu -@slow +#@slow class MixedInt8Test(unittest.TestCase): # We keep the constants inside the init function and model loading inside setUp function @@ -120,6 +120,7 @@ def test_save_load(self): logits_native = model_8bit(input_ids).logits # TODO: @younesbelkada understand why the test does not pass + # This won't work since in the model state dict the quantization statistics are not saved @require_torch_multi_gpu def test_multi_gpu_loading(self): @@ -142,7 +143,7 @@ def get_list_devices(model): # have any device set try: list_devices.append(next(module.parameters()).device.index) - except: + except BaseException: continue return list_devices From 7abb914e726e4f4d3845d4faed2c0085b5e73e13 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 1 Aug 2022 16:05:09 +0200 Subject: [PATCH 47/82] add more doctring + make style --- src/transformers/modeling_utils.py | 3 ++- tests/mixed_int8/test_mixed_int8.py | 31 +---------------------------- 2 files changed, 3 insertions(+), 31 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 651a2277588923..de7471e754502d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1724,7 +1724,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P int8_threshold (`int`, *optional*): Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper (#TODO provide - link to paper). Set to `6.0` by default as described in the paper. + link to paper). Any hidden states value that is above this threshold will be considered an outlier and + the operation on those values will be done in fp16. Set to `6.0` by default as described in the paper. subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index b68bc169e2ff31..a3e667fecbbdd1 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import tempfile import unittest from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline @@ -25,18 +23,13 @@ require_torch_multi_gpu, slow, ) -from transformers.utils.import_utils import is_torch_available - - -if is_torch_available(): - import torch @require_bitsandbytes @require_accelerate @require_torch @require_torch_gpu -#@slow +@slow class MixedInt8Test(unittest.TestCase): # We keep the constants inside the init function and model loading inside setUp function @@ -100,28 +93,6 @@ def test_pipeline(self): pipeline_output = pipe(self.input_text) self.assertEqual(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUT) - def test_save_load(self): - r""" - The aim of this test is to verify whether if we save and load back a quantized model we retain the same performance. - If this test pass people can safely push quantized models on the Hub. - """ - model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") - - with tempfile.TemporaryDirectory() as tmpdirname: - # create a dummy tensor - input_ids = torch.LongTensor([[1, 2, 3, 4]]).to(0) - - # Save and load 8bit model - model_8bit.save_pretrained(tmpdirname) - loaded_model_8bit = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto") - - # Get both logits - logits_loaded = loaded_model_8bit(input_ids).logits - logits_native = model_8bit(input_ids).logits - - # TODO: @younesbelkada understand why the test does not pass - # This won't work since in the model state dict the quantization statistics are not saved - @require_torch_multi_gpu def test_multi_gpu_loading(self): r""" From 961e57ea85d943b0316f9f69e43dff6e7f0915a4 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 1 Aug 2022 16:07:53 +0200 Subject: [PATCH 48/82] more docstring --- src/transformers/modeling_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index de7471e754502d..bc937d34561a6d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1725,7 +1725,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper (#TODO provide link to paper). Any hidden states value that is above this threshold will be considered an outlier and - the operation on those values will be done in fp16. Set to `6.0` by default as described in the paper. + the operation on those values will be done in fp16. A lower value means more outliers will be detected, + therefore more operations in fp16. Set to `6.0` by default as described in the paper. subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. From 1326a42795033410dae6c5a8a07b81f12ee7a41c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 1 Aug 2022 15:13:36 +0000 Subject: [PATCH 49/82] raise error when loaded in 8bit --- src/transformers/utils/bitsandbytes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index a7d1c5008ef1d6..64a1ce9441cb9c 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -58,6 +58,8 @@ class `Int8Params` from `bitsandbytes`. new_value = old_value.to(device) elif isinstance(value, torch.Tensor): new_value = value.to("cpu") + if value.dtype == torch.int8: + raise ValueError(f"You cannot load weights that are saved in int8 using `load_in_8bit=True`, make sure you are using `load_in_8bit=True` on float32/float16/bfloat16 weights.") else: new_value = torch.tensor(value, device="cpu") new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device) From 9a0051b23d0692ba6dbf7a5ad1953c7586198be3 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 1 Aug 2022 17:22:25 +0200 Subject: [PATCH 50/82] make style --- src/transformers/utils/bitsandbytes.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 64a1ce9441cb9c..9cb833e3f05a8d 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -59,7 +59,10 @@ class `Int8Params` from `bitsandbytes`. elif isinstance(value, torch.Tensor): new_value = value.to("cpu") if value.dtype == torch.int8: - raise ValueError(f"You cannot load weights that are saved in int8 using `load_in_8bit=True`, make sure you are using `load_in_8bit=True` on float32/float16/bfloat16 weights.") + raise ValueError( + "You cannot load weights that are saved in int8 using `load_in_8bit=True`, make sure you are", + " using `load_in_8bit=True` on float32/float16/bfloat16 weights.", + ) else: new_value = torch.tensor(value, device="cpu") new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device) From 59f9a5aaf4563f49325e1c0cab46595e76ea9f5a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 2 Aug 2022 11:05:58 +0200 Subject: [PATCH 51/82] add warning if loaded on CPU --- src/transformers/modeling_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bc937d34561a6d..1af28bdcd3d034 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2169,6 +2169,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P max_memory=max_memory, ) + if load_in_8bit: + if "cpu" in device_map.values() or "disk" in device_map.values(): + raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!") + if from_tf: if resolved_archive_file.endswith(".index"): # Load from a TensorFlow 1.X checkpoint - provided by original authors From 418690b26763e5b3250436689486afcca379336d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 2 Aug 2022 11:07:29 +0200 Subject: [PATCH 52/82] add small sanity check --- tests/mixed_int8/test_mixed_int8.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index a3e667fecbbdd1..4902cd8a97356b 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -121,3 +121,7 @@ def get_list_devices(model): list_devices = get_list_devices(model_parallel) # Check that we have dispatched the model into 2 separate devices self.assertTrue((1 in list_devices) and (0 in list_devices)) + + # Check that inference pass works on the model + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + _ = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) From 6d93424992ffdbfc57a8e944344a4d80633a97f1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 2 Aug 2022 11:24:04 +0200 Subject: [PATCH 53/82] fix small comment --- utils/tests_fetcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 2b33945aa37686..1016e3bd040f80 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -464,7 +464,7 @@ def module_to_test_file(module_fname): "tests/sagemaker/test_single_node_gpu.py", # SageMaker test "tests/sagemaker/test_multi_node_model_parallel.py", # SageMaker test "tests/sagemaker/test_multi_node_data_parallel.py", # SageMaker test - "tests/mixed_int8/test_mixed_int8.py", # SageMaker test + "tests/mixed_int8/test_mixed_int8.py", # Mixed-int8 bitsandbytes test ] From d428d8da9b4dea6214bde1e7d3c1c6a6a0844103 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 2 Aug 2022 16:36:30 +0200 Subject: [PATCH 54/82] add bitsandbytes on dockerfile --- docker/transformers-all-latest-gpu/Dockerfile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index d82c9f7c777c7e..3e8b0ff89fe1c3 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -45,6 +45,9 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0" RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate +# Add bitsandbytes for mixed int8 testing +RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113 + # When installing in editable mode, `transformers` is not recognized as a package. # this line must be added in order for python to be aware of transformers. RUN cd transformers && python3 setup.py develop From 0324f4b1885e51cc93f3f25754b1469f7538bbc7 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 2 Aug 2022 23:55:41 +0200 Subject: [PATCH 55/82] Improve documentation - improve documentation from comments --- docs/source/en/main_classes/model.mdx | 12 +++++++++++- src/transformers/modeling_utils.py | 12 ++++++++---- src/transformers/utils/bitsandbytes.py | 5 ++++- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index 9cfa922f077740..4498edad595775 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -138,6 +138,10 @@ Due to Pytorch design, this functionality is only available for floating dtypes. From the paper `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale`, we suport HuggingFace 🤗 integration for all models in the Hub with few lines of code. For models trained in half-precision (aka, either `float16` or `bfloat16`). This method aims to reduce `nn.Linear` size by 2, without affecting too much quality by operating on the outliers in half-precision. Therefore the method is termed as mixed 8bit quantization. This technique is useful and works well for billion scale models (>1B parameters) therefore we advice you to use it only for models of that scale. This method has been tested for 2-billion to 176-billion scale models and supports only PyTorch models. + +Int8 mixed-precision matrix decomposition works by separating a matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16 (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no predictive degradation is possible for very large models (>=176B parameters). +Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning). + Note also that you would require a GPU to run mixed-8bit models as the kernels has been compiled for GPUs only. Make sure that you have enough GPU RAM to store the quarter (or half if your model is natively in half precision) of the model before using this feature. Below are some notes to help you use this module, or follow this demo on Google colab: [![Open In Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing) @@ -157,7 +161,13 @@ After carefully installing the required libraries, the way to load your mixed 8- model_name = "bigscience/bloom-2b5" model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True) ``` -The implementation supports multi-GPU setup thanks to `accelerate` as backend. +The implementation supports multi-GPU setup thanks to `accelerate` as backend. If you want to control the GPU memory you want to allocate for each GPU, you can use the `max_memory` argument as follows: +(If allocating `1GB` into GPU-0 and `2GB` into GPU-1, you can use `max_memory={0:"1GB", 1:"2GB"}`) +``` +max_memory_mapping={0:"2GB", 1:"3GB"} +model_name = "bigscience/bloom-2b5" +model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping) +``` ## ModuleUtilsMixin diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1af28bdcd3d034..9355a64cbb7536 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1723,10 +1723,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P not compiled and adapted for CPUs. int8_threshold (`int`, *optional*): Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as - described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper (#TODO provide - link to paper). Any hidden states value that is above this threshold will be considered an outlier and - the operation on those values will be done in fp16. A lower value means more outliers will be detected, - therefore more operations in fp16. Set to `6.0` by default as described in the paper. + described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden + states value that is above this threshold will be considered an outlier and the operation on those + values will be done in fp16. Values are usually normally distributed, that is, most values are in the + range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently + distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 + quantization works well for values of magnitude ~5, but beyond that, there is a significant performance + penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models + (small models, fine-tuning). subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 9cb833e3f05a8d..c795c3010880cd 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -92,7 +92,10 @@ def replace_8bit_linear(model, threshold=6.0): The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no - CPU/GPU memory is required to run this function. + CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a + matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16 + (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no + predictive degradation is possible for very large models (>=176B parameters). Parameters: model (`torch.nn.Module`): From 70ad8cbe77a1273549c719c9839ed0fdcddca65c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 3 Aug 2022 16:16:49 +0200 Subject: [PATCH 56/82] add few comments --- tests/mixed_int8/README.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/mixed_int8/README.md b/tests/mixed_int8/README.md index 9f0da85a847153..707fae87cd1fcf 100644 --- a/tests/mixed_int8/README.md +++ b/tests/mixed_int8/README.md @@ -10,4 +10,16 @@ I am using a setup of 2 GPUs that are NVIDIA-Tesla T4 15GB ```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit``` ```pip install -e ".[dev]"``` ```pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda114``` -```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1``` \ No newline at end of file +```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1``` + +## Trouble shooting + +### Check driver settings: + +``` +nvcc --version +``` + +``` +ls -l $CONDA_PREFIX/lib/libcudart.so +``` \ No newline at end of file From 1eedb902c14cbdbc8e7e33a89cf449ae028ab9b7 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 3 Aug 2022 15:07:08 +0000 Subject: [PATCH 57/82] slow tests pass on the VM but not on the CI VM --- tests/mixed_int8/README.md | 11 ++++++++++- tests/mixed_int8/test_mixed_int8.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/mixed_int8/README.md b/tests/mixed_int8/README.md index 9f0da85a847153..35b7bed3b15a0a 100644 --- a/tests/mixed_int8/README.md +++ b/tests/mixed_int8/README.md @@ -10,4 +10,13 @@ I am using a setup of 2 GPUs that are NVIDIA-Tesla T4 15GB ```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit``` ```pip install -e ".[dev]"``` ```pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda114``` -```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1``` \ No newline at end of file +```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1``` + +## Trobleshooting + +```conda create --name int8-testing python==3.8``` +```pip install -i https://test.pypi.org/simple/ bitsandbytes``` +```conda install pytorch torchvision torchaudio -c pytorch``` +```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit``` +```pip install -e ".[dev]"``` +```pip install git+https://github.com/huggingface/accelerate.git@b52b793ea8bac108ba61192eead3cf11ca02433d``` \ No newline at end of file diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 4902cd8a97356b..fb6ce090eebe11 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -25,6 +25,7 @@ ) + @require_bitsandbytes @require_accelerate @require_torch @@ -66,6 +67,9 @@ def test_memory_footprint(self): self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE) self.assertTrue(model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) + del model_8bit + del model_fp16 + def test_generate_quality(self): r""" Test the generation quality of the quantized model and see that we are matching the expected output. @@ -79,6 +83,8 @@ def test_generate_quality(self): self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + del model_8bit + def test_pipeline(self): r""" The aim of this test is to verify that the mixed int8 is compatible with `pipeline` from transformers. Since @@ -86,10 +92,15 @@ def test_pipeline(self): on pipline. """ pipe = pipeline( + "text-generation", model=self.model_name, model_kwargs={"device_map": "auto", "load_in_8bit": True}, max_new_tokens=self.MAX_NEW_TOKENS, ) + # Needs a first forward pass to get the statistics + _ = pipe(self.input_text) + + # Real second forward pass pipeline_output = pipe(self.input_text) self.assertEqual(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUT) @@ -124,4 +135,10 @@ def get_list_devices(model): # Check that inference pass works on the model encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + + # First dummy batch to get the statistics _ = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + + # Second real batch + output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) \ No newline at end of file From 8100d034ca305c1feb11ae93bbcbbcbba5c4bc53 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 3 Aug 2022 15:09:28 +0000 Subject: [PATCH 58/82] Fix merge conflict --- tests/mixed_int8/README.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/mixed_int8/README.md b/tests/mixed_int8/README.md index edd6ed2f7c2794..34cecee6bf5b4a 100644 --- a/tests/mixed_int8/README.md +++ b/tests/mixed_int8/README.md @@ -12,7 +12,7 @@ I am using a setup of 2 GPUs that are NVIDIA-Tesla T4 15GB ```pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda114``` ```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1``` -<<<<<<< HEAD + ## Trobleshooting ```conda create --name int8-testing python==3.8``` @@ -21,8 +21,6 @@ I am using a setup of 2 GPUs that are NVIDIA-Tesla T4 15GB ```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit``` ```pip install -e ".[dev]"``` ```pip install git+https://github.com/huggingface/accelerate.git@b52b793ea8bac108ba61192eead3cf11ca02433d``` -======= -## Trouble shooting ### Check driver settings: @@ -33,4 +31,3 @@ nvcc --version ``` ls -l $CONDA_PREFIX/lib/libcudart.so ``` ->>>>>>> 70ad8cbe77a1273549c719c9839ed0fdcddca65c From c9589f65c508d674151d92b0520d862c9e5702c1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 3 Aug 2022 16:00:11 +0000 Subject: [PATCH 59/82] make style --- tests/mixed_int8/test_mixed_int8.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index fb6ce090eebe11..953d7c92e6c63d 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -25,7 +25,6 @@ ) - @require_bitsandbytes @require_accelerate @require_torch @@ -141,4 +140,4 @@ def get_list_devices(model): # Second real batch output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) - self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) \ No newline at end of file + self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) From eb9a26d3357e3236bacc9388135560bb2839e067 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 4 Aug 2022 10:51:24 +0000 Subject: [PATCH 60/82] another test should pass on a multi gpu setup --- tests/mixed_int8/README.md | 4 ++ tests/mixed_int8/test_mixed_int8.py | 71 +++++++++++++++++++++-------- 2 files changed, 57 insertions(+), 18 deletions(-) diff --git a/tests/mixed_int8/README.md b/tests/mixed_int8/README.md index 34cecee6bf5b4a..4960d2fd218320 100644 --- a/tests/mixed_int8/README.md +++ b/tests/mixed_int8/README.md @@ -31,3 +31,7 @@ nvcc --version ``` ls -l $CONDA_PREFIX/lib/libcudart.so ``` + +### Recurrent bugs + +Sometimes you have to run a "dummy" inference pass when dealing with a multi-GPU setup. Checkout the ```test_multi_gpu_loading``` and the ```test_pipeline``` functions. \ No newline at end of file diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 953d7c92e6c63d..f4d4d71710b5f7 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -12,8 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import gc import unittest +import torch + from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from transformers.testing_utils import ( require_accelerate, @@ -30,7 +33,7 @@ @require_torch @require_torch_gpu @slow -class MixedInt8Test(unittest.TestCase): +class BaseMixedInt8Test(unittest.TestCase): # We keep the constants inside the init function and model loading inside setUp function # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) @@ -50,6 +53,26 @@ def setUp(self): # Models and tokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + +class MixedInt8Test(BaseMixedInt8Test): + def setUp(self): + super().setUp() + + # Models and tokenizer + self.model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto") + self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + + def tearDown(self): + r""" + TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to + avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 + """ + del self.model_fp16 + del self.model_8bit + + gc.collect() + torch.cuda.empty_cache() + def test_memory_footprint(self): r""" A simple test to check if the model conversion has been done correctly by checking on the @@ -57,17 +80,11 @@ def test_memory_footprint(self): """ from bitsandbytes.nn import Int8Params - model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto") - model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") - - mem_fp16 = model_fp16.get_memory_footprint() - mem_8bit = model_8bit.get_memory_footprint() + mem_fp16 = self.model_fp16.get_memory_footprint() + mem_8bit = self.model_8bit.get_memory_footprint() self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE) - self.assertTrue(model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) - - del model_8bit - del model_fp16 + self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) def test_generate_quality(self): r""" @@ -75,14 +92,25 @@ def test_generate_quality(self): Given that we are operating on small numbers + the testing model is relatively small, we might not get the same output across GPUs. So we'll generate few tokens (5-10) and check their output. """ - model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") - encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].cuda(), max_new_tokens=10) + output_sequences = self.model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - del model_8bit + +class MixedInt8TestPipeline(BaseMixedInt8Test): + def setUp(self): + super().setUp() + + def tearDown(self): + r""" + TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to + avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 + """ + del self.pipe + + gc.collect() + torch.cuda.empty_cache() def test_pipeline(self): r""" @@ -90,25 +118,32 @@ def test_pipeline(self): we used pipline for inference speed benchmarking we want to make sure that this feature does not break anything on pipline. """ - pipe = pipeline( + # self._clear_cuda_cache() + self.pipe = pipeline( "text-generation", model=self.model_name, model_kwargs={"device_map": "auto", "load_in_8bit": True}, max_new_tokens=self.MAX_NEW_TOKENS, ) # Needs a first forward pass to get the statistics - _ = pipe(self.input_text) + _ = self.pipe(self.input_text) # Real second forward pass - pipeline_output = pipe(self.input_text) + pipeline_output = self.pipe(self.input_text) self.assertEqual(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUT) - @require_torch_multi_gpu + +@require_torch_multi_gpu +class MixedInt8TestMultiGpu(BaseMixedInt8Test): + def setUp(self): + super().setUp() + def test_multi_gpu_loading(self): r""" This tests that the model has been loaded and can be used correctly on a multi-GPU setup. Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB of total, 3GB should suffice """ + memory_mapping = {0: "1GB", 1: "2GB"} model_parallel = AutoModelForCausalLM.from_pretrained( self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto" From 838e2a93d5c11291058c71751b6977def954cedb Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 4 Aug 2022 11:42:15 +0000 Subject: [PATCH 61/82] fix bad import in testing file --- tests/mixed_int8/test_mixed_int8.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index f4d4d71710b5f7..b16042d7c9af2c 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -15,10 +15,9 @@ import gc import unittest -import torch - from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from transformers.testing_utils import ( + is_torch_available, require_accelerate, require_bitsandbytes, require_torch, @@ -28,6 +27,10 @@ ) +if is_torch_available(): + import torch + + @require_bitsandbytes @require_accelerate @require_torch From c4a1e9b6f0d838b486870fe0d6f6dd97194705c8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 4 Aug 2022 14:52:56 +0000 Subject: [PATCH 62/82] Fix slow tests - remove dummy batches - no more CUDA illegal memory errors --- tests/mixed_int8/test_mixed_int8.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index b16042d7c9af2c..19dd05cbf5acd4 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -128,8 +128,6 @@ def test_pipeline(self): model_kwargs={"device_map": "auto", "load_in_8bit": True}, max_new_tokens=self.MAX_NEW_TOKENS, ) - # Needs a first forward pass to get the statistics - _ = self.pipe(self.input_text) # Real second forward pass pipeline_output = self.pipe(self.input_text) @@ -173,9 +171,6 @@ def get_list_devices(model): # Check that inference pass works on the model encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - # First dummy batch to get the statistics - _ = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) - # Second real batch output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) From 31fce94e8a3983dfa65222311b340460ccff05f7 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 4 Aug 2022 16:58:21 +0000 Subject: [PATCH 63/82] odify dockerfile --- docker/transformers-all-latest-gpu/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index 3e8b0ff89fe1c3..890ec8a441270c 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -46,7 +46,7 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0" RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate # Add bitsandbytes for mixed int8 testing -RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113 +RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes==0.31.2 # When installing in editable mode, `transformers` is not recognized as a package. # this line must be added in order for python to be aware of transformers. From 3e4a2a4f26263f3380772ea29557220f3a5e1d6c Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 5 Aug 2022 12:53:51 +0200 Subject: [PATCH 64/82] Update docs/source/en/main_classes/model.mdx --- docs/source/en/main_classes/model.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index 4498edad595775..a17de927ab0236 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -164,7 +164,7 @@ model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", The implementation supports multi-GPU setup thanks to `accelerate` as backend. If you want to control the GPU memory you want to allocate for each GPU, you can use the `max_memory` argument as follows: (If allocating `1GB` into GPU-0 and `2GB` into GPU-1, you can use `max_memory={0:"1GB", 1:"2GB"}`) ``` -max_memory_mapping={0:"2GB", 1:"3GB"} +max_memory_mapping={0:"1GB", 1:"2GB"} model_name = "bigscience/bloom-2b5" model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping) ``` From fdf37b3f9918eb117896db46c7a50b6f3b34ab28 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 5 Aug 2022 17:57:59 +0200 Subject: [PATCH 65/82] Update Dockerfile --- docker/transformers-all-latest-gpu/Dockerfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index 890ec8a441270c..9b8f94aaffb252 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -47,6 +47,8 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/acc # Add bitsandbytes for mixed int8 testing RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes==0.31.2 +# add correct LD_LIB path +ENV LD_LIBRARY_PATH /usr/local/cuda/lib64 # When installing in editable mode, `transformers` is not recognized as a package. # this line must be added in order for python to be aware of transformers. From 53e0b2ecf9be35c09fbe7c1c93626b82d0b2e139 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Sun, 7 Aug 2022 10:40:50 +0200 Subject: [PATCH 66/82] Update model.mdx --- docs/source/en/main_classes/model.mdx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index a17de927ab0236..6c182ee879c8c9 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -133,12 +133,14 @@ model = AutoModel.from_config(config) Due to Pytorch design, this functionality is only available for floating dtypes. -### Mixed int8 quantization đŸĒļ +### Mixed int8 quantization From the paper `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale`, we suport HuggingFace 🤗 integration for all models in the Hub with few lines of code. For models trained in half-precision (aka, either `float16` or `bfloat16`). This method aims to reduce `nn.Linear` size by 2, without affecting too much quality by operating on the outliers in half-precision. Therefore the method is termed as mixed 8bit quantization. This technique is useful and works well for billion scale models (>1B parameters) therefore we advice you to use it only for models of that scale. This method has been tested for 2-billion to 176-billion scale models and supports only PyTorch models. +![HFxbitsandbytes.png](https://s3.amazonaws.com/moonup/production/uploads/1659861207959-62441d1d9fdefb55a0b7d12c.png) + Int8 mixed-precision matrix decomposition works by separating a matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16 (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no predictive degradation is possible for very large models (>=176B parameters). Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning). From 91364c480e42c73821562e14f6039c7826c66c48 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 8 Aug 2022 07:46:00 +0200 Subject: [PATCH 67/82] Update Dockerfile --- docker/transformers-all-latest-gpu/Dockerfile | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index 9b8f94aaffb252..4dd5c1e8e03ea3 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -46,9 +46,7 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0" RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate # Add bitsandbytes for mixed int8 testing -RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes==0.31.2 -# add correct LD_LIB path -ENV LD_LIBRARY_PATH /usr/local/cuda/lib64 +RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes==0.31.5 # When installing in editable mode, `transformers` is not recognized as a package. # this line must be added in order for python to be aware of transformers. From 5ea8976d0bcf9cd205aa5bcd9eceb1cfe42e42c5 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 8 Aug 2022 10:32:41 +0200 Subject: [PATCH 68/82] Apply suggestions from code review --- docs/source/en/main_classes/model.mdx | 2 +- tests/mixed_int8/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index 6c182ee879c8c9..eb9a3050f95687 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -152,7 +152,7 @@ Below are some notes to help you use this module, or follow this demo on Google - Make sure you run that on a NVIDIA T4 or A100 GPU that supports 8-bit tensor cores. Note that previous generations of NVIDIA GPUs do not support 8-bit tensor cores. - Install the correct version of `bitsandbytes` by running: -`pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` with `XXX` corresponding to your `CUDA` version. E.g, if you have `CUDA` 11.3 installed on your machine, run `pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113` +`pip install -i https://test.pypi.org/simple/ bitsandbytes` - Install `accelerate`: `pip install accelerate` diff --git a/tests/mixed_int8/README.md b/tests/mixed_int8/README.md index 4960d2fd218320..c0173bed7a6b7a 100644 --- a/tests/mixed_int8/README.md +++ b/tests/mixed_int8/README.md @@ -9,7 +9,7 @@ I am using a setup of 2 GPUs that are NVIDIA-Tesla T4 15GB ```conda create --name int8-testing python==3.8``` ```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit``` ```pip install -e ".[dev]"``` -```pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda114``` +```pip install -i https://test.pypi.org/simple/ bitsandbytes``` ```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1``` From 8b72d0890dac7322705649efb47c07c5fa82dcae Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 8 Aug 2022 21:50:23 +0000 Subject: [PATCH 69/82] few modifications - lm head can stay on disk/cpu - change model name so that test pass --- src/transformers/modeling_utils.py | 5 ++++- tests/mixed_int8/test_mixed_int8.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0109d0ef5302dd..15486737d38392 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2216,8 +2216,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) if load_in_8bit: - if "cpu" in device_map.values() or "disk" in device_map.values(): + # The LM head can stay on disk / CPU + device_map_without_lm_head = {key: device_map[key] for key in device_map.keys() if key != "lm_head"} + if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!") + del device_map_without_lm_head if from_tf: if resolved_archive_file.endswith(".index"): diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 19dd05cbf5acd4..933c80ea0ec596 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -41,7 +41,7 @@ class BaseMixedInt8Test(unittest.TestCase): # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) # Therefore here we use only bloom-1b3 to test our module - model_name = "bigscience/bloom-1b3" + model_name = "bigscience/bloom-1b7" # Constant values EXPECTED_RELATIVE_DIFFERENCE = ( From c6c139f28dde632a00934e05074ee57a59dfc660 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 8 Aug 2022 22:03:48 +0000 Subject: [PATCH 70/82] change test value - change test value to the correct output - torch bmm changed to baddmm in bloom modeling when merging --- tests/mixed_int8/test_mixed_int8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 933c80ea0ec596..129954fb72b891 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -49,7 +49,7 @@ class BaseMixedInt8Test(unittest.TestCase): ) input_text = "Hello my name is" - EXPECTED_OUTPUT = "Hello my name is John.\nI am a friend of your father.\n" + EXPECTED_OUTPUT = "Hello my name is John.\nI am a friend of the family.\n" MAX_NEW_TOKENS = 10 def setUp(self): From 0d5bc2b75c55bdd1eb8fedc8f72cd26878e5cfec Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 9 Aug 2022 14:00:34 +0000 Subject: [PATCH 71/82] modify installation guidelines --- src/transformers/modeling_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5fab363d88a6c9..afa6eda221bb3f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1868,9 +1868,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if not (is_accelerate_available() and is_bitsandbytes_available()): raise ImportError( "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of" - " bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` with `XXX`" - " correspondingto your CUDA version (e.g. `pip install -i https://test.pypi.org/simple/" - " bitsandbytes-cuda113` for CUDA 11.3)" + " bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or" + " pip install bitsandbytes` " ) if torch_dtype == "auto" or torch_dtype != torch.float16: torch_dtype = ( From bc8f332ad5043c71162b5893b1728f0f162bdeba Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 9 Aug 2022 16:12:22 +0200 Subject: [PATCH 72/82] Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/en/main_classes/model.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index eb9a3050f95687..e2ea796306db7f 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -159,13 +159,13 @@ Below are some notes to help you use this module, or follow this demo on Google #### Running mixed-int8 models After carefully installing the required libraries, the way to load your mixed 8-bit model is as follows: -``` +```py model_name = "bigscience/bloom-2b5" model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True) ``` The implementation supports multi-GPU setup thanks to `accelerate` as backend. If you want to control the GPU memory you want to allocate for each GPU, you can use the `max_memory` argument as follows: (If allocating `1GB` into GPU-0 and `2GB` into GPU-1, you can use `max_memory={0:"1GB", 1:"2GB"}`) -``` +```py max_memory_mapping={0:"1GB", 1:"2GB"} model_name = "bigscience/bloom-2b5" model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping) From 5adcadc011bf8d0b985c893c3638827a10a37130 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 9 Aug 2022 16:13:36 +0200 Subject: [PATCH 73/82] Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index afa6eda221bb3f..1fe3b35be715d5 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1590,7 +1590,7 @@ def get_memory_footprint(self, return_buffers=True): PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 Arguments: - return_buffers (`bool`, *optional*): + return_buffers (`bool`, *optional*, defaults to `True`): Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 @@ -1730,13 +1730,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` when there is some disk offload. - load_in_8bit (`bool`, *optional*): + load_in_8bit (`bool`, *optional*, defaults to `False`): If `True`, will convert the loaded model into mixed-8bit quantized model. To use this feature please install `bitsandbytes` compiled with your CUDA version by running `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116). Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are not compiled and adapted for CPUs. - int8_threshold (`int`, *optional*): + int8_threshold (`float`, *optional*, defaults to 6): Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden states value that is above this threshold will be considered an outlier and the operation on those From bb00e7a8b514043e91f78eb8dfd3e145a64364ee Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 9 Aug 2022 16:15:20 +0200 Subject: [PATCH 74/82] Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_utils.py | 5 ++--- src/transformers/utils/bitsandbytes.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1fe3b35be715d5..cebf0cbd85860a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1872,9 +1872,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " pip install bitsandbytes` " ) if torch_dtype == "auto" or torch_dtype != torch.float16: - torch_dtype = ( - torch.float16 - ) # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + torch_dtype = torch.float16 logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16") if device_map is None: raise ValueError( diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index c795c3010880cd..8e16b31c5b70c4 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -100,7 +100,7 @@ def replace_8bit_linear(model, threshold=6.0): Parameters: model (`torch.nn.Module`): Input model or `torch.nn.Module` as the function is run recursively. - threshold (`float`, *optional*): + threshold (`float`, *optional*, defaults to 6.0): `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to `6.0` as described by the paper. """ From 630b4f757557f5b2c9b0524071003f1fadb826d3 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 9 Aug 2022 14:18:26 +0000 Subject: [PATCH 75/82] replace `n`by `name` --- src/transformers/utils/bitsandbytes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 8e16b31c5b70c4..307bcd3a98478a 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -104,13 +104,13 @@ def replace_8bit_linear(model, threshold=6.0): `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to `6.0` as described by the paper. """ - for n, module in model.named_children(): + for name, module in model.named_children(): if len(list(module.children())) > 0: replace_8bit_linear(module, threshold) - if isinstance(module, nn.Linear) and n != "lm_head": + if isinstance(module, nn.Linear) and name != "lm_head": with init_empty_weights(): - model._modules[n] = bnb.nn.Linear8bitLt( + model._modules[name] = bnb.nn.Linear8bitLt( module.in_features, module.out_features, module.bias is not None, From a92564137ec21e665c8e241251d5a2c565113b20 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 9 Aug 2022 14:22:15 +0000 Subject: [PATCH 76/82] merge `load_in_8bit` and `low_cpu_mem_usage` --- src/transformers/modeling_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index cebf0cbd85860a..3f215a8b1c69c0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2125,16 +2125,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts - elif load_in_8bit: - init_contexts = [init_empty_weights()] # Force enable init empty weights - logger.info("Detected 8-bit loading: activating 8-bit loading for this model") - elif low_cpu_mem_usage: + elif load_in_8bit or low_cpu_mem_usage: init_contexts.append(init_empty_weights()) with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) if load_in_8bit: + logger.info("Detected 8-bit loading: activating 8-bit loading for this model") model = replace_8bit_linear(model, threshold=int8_threshold) if isinstance(device_map, str): From 279b8c45bcf3893e1b4695d32ff2e1083c6bda9d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 9 Aug 2022 15:17:58 +0000 Subject: [PATCH 77/82] first try - keep the lm head in full precision --- src/transformers/modeling_utils.py | 11 ++++++--- src/transformers/utils/bitsandbytes.py | 34 +++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3f215a8b1c69c0..83fad202ab4d0c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -85,7 +85,7 @@ get_balanced_memory = None if is_bitsandbytes_available(): - from .utils.bitsandbytes import replace_8bit_linear, set_module_8bit_tensor_to_device + from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device logger = logging.get_logger(__name__) @@ -2133,7 +2133,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if load_in_8bit: logger.info("Detected 8-bit loading: activating 8-bit loading for this model") - model = replace_8bit_linear(model, threshold=int8_threshold) + + # We never convert lm_head or any last modules for numerical stability reasons + modules_to_not_convert = get_key_to_not_convert(model) + model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert) if isinstance(device_map, str): if model._no_split_modules is None: @@ -2165,7 +2168,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if load_in_8bit: # The LM head can stay on disk / CPU - device_map_without_lm_head = {key: device_map[key] for key in device_map.keys() if key != "lm_head"} + device_map_without_lm_head = { + key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert + } if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!") del device_map_without_lm_head diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 307bcd3a98478a..6c34f337bdabe2 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -82,7 +82,7 @@ class `Int8Params` from `bitsandbytes`. module._parameters[tensor_name] = new_value -def replace_8bit_linear(model, threshold=6.0): +def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"): """ A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes` library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): @@ -103,12 +103,15 @@ def replace_8bit_linear(model, threshold=6.0): threshold (`float`, *optional*, defaults to 6.0): `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to `6.0` as described by the paper. + modules_to_not_convert (`str`, *optional*, defaults to `lm_head`): + Name of the module to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision + for numerical stability reasons. """ for name, module in model.named_children(): if len(list(module.children())) > 0: - replace_8bit_linear(module, threshold) + replace_8bit_linear(module, threshold, modules_to_not_convert) - if isinstance(module, nn.Linear) and name != "lm_head": + if isinstance(module, nn.Linear) and name != modules_to_not_convert: with init_empty_weights(): model._modules[name] = bnb.nn.Linear8bitLt( module.in_features, @@ -118,3 +121,28 @@ def replace_8bit_linear(model, threshold=6.0): threshold=threshold, ) return model + + +def get_key_to_not_convert(model): + r""" + An utility function to get the key of the module to keep in full precision if any + For example for CausalLM modules we may want to keep the lm_head in full precision + for numerical stability reasons. + + Parameters: + model (`torch.nn.Module`): + Input model + """ + + # Compute the number of parameters between the base model and the full model + n_params_base = sum(p.numel() for p in model.base_model.parameters() if p.requires_grad) + n_params_full = sum(p.numel() for p in model.parameters() if p.requires_grad) + + # if they have the same number of parameters they are the same model - no attached head + if n_params_base == n_params_full: + return "" + + # otherwise they have an attached head + list_modules = list(model.named_parameters()) + last_name = list_modules[-1][0] + return last_name.split(".")[0] From e49a2ea29262a57f90c2db8c26f5655af4ad3aea Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 9 Aug 2022 15:20:40 +0000 Subject: [PATCH 78/82] better check - check the attribute `base_model_prefix` instead of computing the number of parameters --- src/transformers/utils/bitsandbytes.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 6c34f337bdabe2..2b509d2233a850 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -133,13 +133,7 @@ def get_key_to_not_convert(model): model (`torch.nn.Module`): Input model """ - - # Compute the number of parameters between the base model and the full model - n_params_base = sum(p.numel() for p in model.base_model.parameters() if p.requires_grad) - n_params_full = sum(p.numel() for p in model.parameters() if p.requires_grad) - - # if they have the same number of parameters they are the same model - no attached head - if n_params_base == n_params_full: + if not hasattr(model, model.base_model_prefix): return "" # otherwise they have an attached head From 5718d78ae2f7bb32c5afab6829e2545c9f4b06c1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 9 Aug 2022 15:37:42 +0000 Subject: [PATCH 79/82] added more tests --- tests/mixed_int8/test_mixed_int8.py | 38 ++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 129954fb72b891..4f5370e2fea23f 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -15,7 +15,7 @@ import gc import unittest -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +from transformers import AutoModel, AutoModelForSequenceClassification, AutoModelForCausalLM, AutoTokenizer, pipeline from transformers.testing_utils import ( is_torch_available, require_accelerate, @@ -101,6 +101,42 @@ def test_generate_quality(self): self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) +class MixedInt8ModelClassesTest(BaseMixedInt8Test): + def setUp(self): + super().setUp() + # model_name + self.model_name = "bigscience/bloom-560m" + # Models and tokenizer + self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + self.sequence_model = AutoModelForSequenceClassification.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + + def tearDown(self): + r""" + TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to + avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 + """ + del self.base_model + del self.sequence_model + del self.model_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_correct_head_class(self): + r""" + A simple test to check if the last modules for some classes (AutoModelForCausalLM or SequenceClassification) + are kept in their native class. + """ + from bitsandbytes.nn import Int8Params + + # last param of a base model should be a linear8bit module + self.assertTrue(self.base_model.h[-1].mlp.dense_4h_to_h.weight.__class__ == Int8Params) + + # Other heads should be nn.Parameter + self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter) + self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter) + class MixedInt8TestPipeline(BaseMixedInt8Test): def setUp(self): super().setUp() From a40667a0bf5709cd27d58e43e337328ec5e17240 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 9 Aug 2022 17:44:51 +0200 Subject: [PATCH 80/82] Update src/transformers/utils/bitsandbytes.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/utils/bitsandbytes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 2b509d2233a850..22d52e6e04d208 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -133,6 +133,7 @@ def get_key_to_not_convert(model): model (`torch.nn.Module`): Input model """ + # Ignore this for base models (BertModel, GPT2Model, etc.) if not hasattr(model, model.base_model_prefix): return "" From 61faa282b68fd0bbfe509d0bce685f78d17be81b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 9 Aug 2022 17:48:42 +0200 Subject: [PATCH 81/82] Merge branch 'integration-8bit' of https://github.com/younesbelkada/transformers into integration-8bit --- docs/source/en/main_classes/model.mdx | 6 ++++-- src/transformers/utils/bitsandbytes.py | 5 ++--- tests/mixed_int8/test_mixed_int8.py | 13 ++++++++----- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index e2ea796306db7f..31cbceecf01348 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -166,9 +166,11 @@ model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", The implementation supports multi-GPU setup thanks to `accelerate` as backend. If you want to control the GPU memory you want to allocate for each GPU, you can use the `max_memory` argument as follows: (If allocating `1GB` into GPU-0 and `2GB` into GPU-1, you can use `max_memory={0:"1GB", 1:"2GB"}`) ```py -max_memory_mapping={0:"1GB", 1:"2GB"} +max_memory_mapping = {0: "1GB", 1: "2GB"} model_name = "bigscience/bloom-2b5" -model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping) +model_8bit = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping +) ``` diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 22d52e6e04d208..61b018e0bfe69d 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -125,9 +125,8 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"): def get_key_to_not_convert(model): r""" - An utility function to get the key of the module to keep in full precision if any - For example for CausalLM modules we may want to keep the lm_head in full precision - for numerical stability reasons. + An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. Parameters: model (`torch.nn.Module`): diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 4f5370e2fea23f..0cd7ca16411c19 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -15,7 +15,7 @@ import gc import unittest -from transformers import AutoModel, AutoModelForSequenceClassification, AutoModelForCausalLM, AutoTokenizer, pipeline +from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline from transformers.testing_utils import ( is_torch_available, require_accelerate, @@ -108,7 +108,9 @@ def setUp(self): self.model_name = "bigscience/bloom-560m" # Models and tokenizer self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") - self.sequence_model = AutoModelForSequenceClassification.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") + self.sequence_model = AutoModelForSequenceClassification.from_pretrained( + self.model_name, load_in_8bit=True, device_map="auto" + ) self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") def tearDown(self): @@ -122,11 +124,11 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - + def test_correct_head_class(self): r""" - A simple test to check if the last modules for some classes (AutoModelForCausalLM or SequenceClassification) - are kept in their native class. + A simple test to check if the last modules for some classes (AutoModelForCausalLM or SequenceClassification) + are kept in their native class. """ from bitsandbytes.nn import Int8Params @@ -137,6 +139,7 @@ def test_correct_head_class(self): self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter) self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter) + class MixedInt8TestPipeline(BaseMixedInt8Test): def setUp(self): super().setUp() From 36fbbd2095b521ba53f906ee2a9c67d9dd97a629 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 10 Aug 2022 07:15:13 +0200 Subject: [PATCH 82/82] improve documentation - fix typos for installation - change title in the documentation --- docs/source/en/main_classes/model.mdx | 8 ++++---- src/transformers/utils/bitsandbytes.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/en/main_classes/model.mdx b/docs/source/en/main_classes/model.mdx index 31cbceecf01348..10f81e55d74506 100644 --- a/docs/source/en/main_classes/model.mdx +++ b/docs/source/en/main_classes/model.mdx @@ -133,10 +133,10 @@ model = AutoModel.from_config(config) Due to Pytorch design, this functionality is only available for floating dtypes. -### Mixed int8 quantization +### `bitsandbytes` integration for Int8 mixed-precision matrix decomposition From the paper `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale`, we suport HuggingFace 🤗 integration for all models in the Hub with few lines of code. -For models trained in half-precision (aka, either `float16` or `bfloat16`). This method aims to reduce `nn.Linear` size by 2, without affecting too much quality by operating on the outliers in half-precision. Therefore the method is termed as mixed 8bit quantization. +For models trained in half-precision (aka, either `float16` or `bfloat16`) or full precision. This method aims to reduce `nn.Linear` size by 2 (if trained in half precision) or by 4 if trained in full precision, without affecting too much quality by operating on the outliers in half-precision. This technique is useful and works well for billion scale models (>1B parameters) therefore we advice you to use it only for models of that scale. This method has been tested for 2-billion to 176-billion scale models and supports only PyTorch models. ![HFxbitsandbytes.png](https://s3.amazonaws.com/moonup/production/uploads/1659861207959-62441d1d9fdefb55a0b7d12c.png) @@ -150,7 +150,7 @@ Below are some notes to help you use this module, or follow this demo on Google #### Requirements -- Make sure you run that on a NVIDIA T4 or A100 GPU that supports 8-bit tensor cores. Note that previous generations of NVIDIA GPUs do not support 8-bit tensor cores. +- Make sure you run that on a NVIDIA GPU that supports 8-bit tensor cores (Turing or Ampere GPUs - e.g. T4, RTX20s RTX30s, A40-A100). Note that previous generations of NVIDIA GPUs do not support 8-bit tensor cores. - Install the correct version of `bitsandbytes` by running: `pip install -i https://test.pypi.org/simple/ bitsandbytes` - Install `accelerate`: @@ -167,7 +167,7 @@ The implementation supports multi-GPU setup thanks to `accelerate` as backend. I (If allocating `1GB` into GPU-0 and `2GB` into GPU-1, you can use `max_memory={0:"1GB", 1:"2GB"}`) ```py max_memory_mapping = {0: "1GB", 1: "2GB"} -model_name = "bigscience/bloom-2b5" +model_name = "bigscience/bloom-3b" model_8bit = AutoModelForCausalLM.from_pretrained( model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping ) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 61b018e0bfe69d..ee4e52d421fd09 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -88,7 +88,7 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"): library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/ - bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116) + bitsandbytes` The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no