
bitsandbytes - Linear8bitLt integration into transformers models #17901

Merged on Aug 10, 2022 (91 commits)
Changes from 4 commits
Commits (91)
ed1dd12
first commit
younesbelkada Jun 24, 2022
b9d0da6
correct replace function
younesbelkada Jun 24, 2022
dd9a464
add final changes
younesbelkada Jun 27, 2022
35e1534
clean up a bit
younesbelkada Jun 27, 2022
d01822b
add bitsandbytes dependencies
younesbelkada Jul 7, 2022
839c9cd
working version
younesbelkada Jul 12, 2022
93a5ac6
Merge branch 'main' of https://github.com/huggingface/transformers in…
younesbelkada Jul 12, 2022
42a6845
small fix
younesbelkada Jul 12, 2022
97f64f8
small fix
younesbelkada Jul 12, 2022
a1fe7fc
fix import issues
younesbelkada Jul 12, 2022
05739e3
Apply suggestions from code review
younesbelkada Jul 12, 2022
7816ef9
refactor a bit
younesbelkada Jul 12, 2022
1155549
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Jul 12, 2022
b222b9a
reformat docstring
younesbelkada Jul 12, 2022
32f48cd
Update src/transformers/__init__.py
younesbelkada Jul 12, 2022
e116e21
revert bad formatting
younesbelkada Jul 12, 2022
39c46a0
change to bitsandbytes
younesbelkada Jul 12, 2022
b92c25c
refactor a bit
younesbelkada Jul 12, 2022
3779f5d
more refactoring
younesbelkada Jul 13, 2022
b41c250
small hack to make it work
younesbelkada Jul 19, 2022
311dcbf
Merge branch 'main' into integration-8bit
younesbelkada Jul 20, 2022
c91a58e
Update src/transformers/modeling_utils.py
younesbelkada Jul 20, 2022
db16cf8
Update src/transformers/modeling_utils.py
younesbelkada Jul 20, 2022
be6ce29
revmoe the small hack
younesbelkada Jul 21, 2022
848d64d
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Jul 21, 2022
15a81e0
modify utils file
younesbelkada Jul 21, 2022
514758d
make style + refactor a bit
younesbelkada Jul 22, 2022
a09e055
create correctly device map
younesbelkada Jul 25, 2022
387aa1e
add correct dtype for device map creation
younesbelkada Jul 25, 2022
7199292
Apply suggestions from code review
younesbelkada Jul 27, 2022
a4c19c1
apply suggestions
younesbelkada Jul 27, 2022
a5cd157
add docstring
younesbelkada Jul 27, 2022
9bd326b
add docstring
younesbelkada Jul 27, 2022
6e4fee6
- added more documentation
younesbelkada Jul 28, 2022
9e121b7
few modifs
younesbelkada Jul 28, 2022
a2ac688
added colab link
younesbelkada Jul 28, 2022
ac370b9
add test architecture + docstring a bit
younesbelkada Jul 28, 2022
27e9486
refactor a bit testing class
younesbelkada Jul 28, 2022
a0db982
make style + refactor a bit
younesbelkada Jul 28, 2022
21bd590
enhance checks
younesbelkada Jul 29, 2022
f1fcf77
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Jul 29, 2022
9b81c67
clean up a bit
younesbelkada Jul 29, 2022
659f427
male style
younesbelkada Jul 29, 2022
147683e
add more details on doc
younesbelkada Jul 29, 2022
c2e1918
add more tests
younesbelkada Aug 1, 2022
f6eb945
replace by "or"
younesbelkada Aug 1, 2022
ceff43e
refactor a bit testing code + add readme
younesbelkada Aug 1, 2022
55cec55
make style
younesbelkada Aug 1, 2022
67bf4fb
fix import issue
younesbelkada Aug 1, 2022
56e9147
Update src/transformers/modeling_utils.py
younesbelkada Aug 1, 2022
5a03a86
add few comments
younesbelkada Aug 1, 2022
7abb914
add more doctring + make style
younesbelkada Aug 1, 2022
961e57e
more docstring
younesbelkada Aug 1, 2022
1326a42
raise error when loaded in 8bit
younesbelkada Aug 1, 2022
9a0051b
make style
younesbelkada Aug 1, 2022
59f9a5a
add warning if loaded on CPU
younesbelkada Aug 2, 2022
418690b
add small sanity check
younesbelkada Aug 2, 2022
6d93424
fix small comment
younesbelkada Aug 2, 2022
d428d8d
add bitsandbytes on dockerfile
younesbelkada Aug 2, 2022
0324f4b
Improve documentation
younesbelkada Aug 2, 2022
af229c4
Merge branch 'main' into integration-8bit
younesbelkada Aug 2, 2022
70ad8cb
add few comments
younesbelkada Aug 3, 2022
1eedb90
slow tests pass on the VM but not on the CI VM
younesbelkada Aug 3, 2022
163ef77
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Aug 3, 2022
8100d03
Fix merge conflict
younesbelkada Aug 3, 2022
c9589f6
make style
younesbelkada Aug 3, 2022
eb9a26d
another test should pass on a multi gpu setup
younesbelkada Aug 4, 2022
838e2a9
fix bad import in testing file
younesbelkada Aug 4, 2022
c4a1e9b
Fix slow tests
younesbelkada Aug 4, 2022
31fce94
odify dockerfile
younesbelkada Aug 4, 2022
3e4a2a4
Update docs/source/en/main_classes/model.mdx
younesbelkada Aug 5, 2022
fdf37b3
Update Dockerfile
younesbelkada Aug 5, 2022
53e0b2e
Update model.mdx
younesbelkada Aug 7, 2022
91364c4
Update Dockerfile
younesbelkada Aug 8, 2022
5ea8976
Apply suggestions from code review
younesbelkada Aug 8, 2022
8b72d08
few modifications
younesbelkada Aug 8, 2022
32a4863
Merge branch 'main' into integration-8bit
younesbelkada Aug 8, 2022
c6c139f
change test value
younesbelkada Aug 8, 2022
3d3224f
Merge remote-tracking branch 'upstream/main' into integration-8bit
younesbelkada Aug 9, 2022
0d5bc2b
modify installation guidelines
younesbelkada Aug 9, 2022
bc8f332
Apply suggestions from code review
younesbelkada Aug 9, 2022
5adcadc
Apply suggestions from code review
younesbelkada Aug 9, 2022
bb00e7a
Apply suggestions from code review
younesbelkada Aug 9, 2022
630b4f7
replace `n`by `name`
younesbelkada Aug 9, 2022
a925641
merge `load_in_8bit` and `low_cpu_mem_usage`
younesbelkada Aug 9, 2022
279b8c4
first try - keep the lm head in full precision
younesbelkada Aug 9, 2022
e49a2ea
better check
younesbelkada Aug 9, 2022
5718d78
added more tests
younesbelkada Aug 9, 2022
a40667a
Update src/transformers/utils/bitsandbytes.py
younesbelkada Aug 9, 2022
61faa28
Merge branch 'integration-8bit' of https://github.com/younesbelkada/t…
younesbelkada Aug 9, 2022
36fbbd2
improve documentation
younesbelkada Aug 10, 2022
26 changes: 26 additions & 0 deletions src/transformers/modeling_utils.py
@@ -118,6 +118,21 @@ def forward(self, input):
        return input


def replace_8bit_linear(model):
    # Lazy import so transformers does not hard-depend on bitsandbytes.
    import bitsandbytes as bnb

    # Recursively walk the module tree and swap every nn.Linear for a bnb.nn.Linear8bitLt
    # of the same shape. The new layers are created under init_empty_weights(), so their
    # parameters live on the meta device until the checkpoint is loaded.
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_8bit_linear(module)

        if isinstance(module, nn.Linear):
            with init_empty_weights():
                model._modules[n] = bnb.nn.Linear8bitLt(
                    module.in_features, module.out_features, module.bias is not None, has_fp16_weights=True
                )
    return model


def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    try:
        return next(parameter.parameters()).device
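Note (illustration, not part of the diff): a minimal sketch of what a single swap in `replace_8bit_linear` above produces, assuming bitsandbytes and accelerate are installed; it simply mirrors the body of the loop for one layer.

import torch.nn as nn
import bitsandbytes as bnb
from accelerate import init_empty_weights  # same helper modeling_utils relies on

lin = nn.Linear(768, 3072)

# Same call as in the loop above: identical shape and bias flag, weights created on the
# meta device (no memory allocated yet), has_fp16_weights=True as in this revision of the PR.
with init_empty_weights():
    int8_lin = bnb.nn.Linear8bitLt(
        lin.in_features, lin.out_features, lin.bias is not None, has_fp16_weights=True
    )

print(type(int8_lin).__name__)  # Linear8bitLt
print(int8_lin.weight.device)   # meta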
@@ -1776,6 +1791,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
        # New kwarg introduced by this PR: swap nn.Linear layers for bitsandbytes Linear8bitLt.
        load_in_8bit = kwargs.pop("load_in_8bit", False)

        if device_map is not None:
            if low_cpu_mem_usage is None:
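Note (illustration, not part of the diff): a hedged sketch of how the new kwarg is meant to be used once this PR lands; the checkpoint name is illustrative, and bitsandbytes, accelerate and a CUDA GPU are assumed.

from transformers import AutoModelForCausalLM

# `load_in_8bit=True` activates the 8-bit path added in this diff; `device_map="auto"`
# lets accelerate place the quantized weights on the available GPU(s).
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7",  # illustrative checkpoint
    device_map="auto",
    load_in_8bit=True,
)
# Memory use is roughly halved compared to loading the same checkpoint in fp16.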
@@ -2061,12 +2077,20 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
elif load_in_8bit:
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved

init_contexts = [init_empty_weights()] # Force enable init empty weights

logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
elif low_cpu_mem_usage:
init_contexts.append(init_empty_weights())

with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)

if load_in_8bit:
model = replace_8bit_linear(model)

if device_map == "auto":
if model._no_split_modules is None:
raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.")
@@ -2126,6 +2150,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                offload_folder=offload_folder,
                offload_state_dict=offload_state_dict,
                dtype=torch_dtype,
                load_in_8bit=load_in_8bit,
            )

        # make sure token embedding weights are still tied if needed
@@ -2165,6 +2190,7 @@ def _load_pretrained_model(
        offload_folder=None,
        offload_state_dict=False,
        dtype=None,
        load_in_8bit=False,
    ):
        if device_map is not None and "disk" in device_map.values():
            if offload_folder is None: