Add support for bitsandbytes #15622

Merged: 18 commits, Apr 19, 2022

Changes from 8 commits
4 changes: 4 additions & 0 deletions src/transformers/file_utils.py
@@ -476,6 +476,10 @@ def is_apex_available():
     return importlib.util.find_spec("apex") is not None


+def is_bnb_available():
+    return importlib.util.find_spec("bitsandbytes") is not None
+
+
 def is_faiss_available():
     return _faiss_available

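A quick usage sketch (not part of the diff): the new helper mirrors the existing is_apex_available() pattern, so callers can guard optional bitsandbytes imports on it.

# Illustrative only; assumes this PR's file_utils.py is importable.
from transformers.file_utils import is_bnb_available

if is_bnb_available():
    import bitsandbytes as bnb  # safe: find_spec already confirmed the package exists
    print("bitsandbytes is available:", bnb.optim.Adam8bit)
else:
    print("bitsandbytes is not installed")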
19 changes: 19 additions & 0 deletions src/transformers/trainer.py
@@ -847,6 +847,17 @@ def create_optimizer(self):
                 )
             else:
                 self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+                if optimizer_cls.__name__ == "Adam8bit":
+                    from torch.nn import Embedding
+
+                    import bitsandbytes
+
+                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                    for module in self.model.modules():
+                        if isinstance(module, Embedding):
+                            manager.register_module_override(module, "weight", {"optim_bits": 32})
+                            logger.info(f"Registering bitsandbytes override for {module}")

         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(self.optimizer)
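The block above is the substantive part of the change: Adam8bit keeps 8-bit optimizer state for most parameters, while embedding weights are registered with GlobalOptimManager so their optimizer state stays in 32-bit. A standalone sketch of the same pattern (illustrative only; assumes bitsandbytes is installed, and the toy model is hypothetical):

import torch
import bitsandbytes as bnb

# Toy model standing in for self.model inside Trainer.create_optimizer.
model = torch.nn.Sequential(
    torch.nn.Embedding(1000, 64),
    torch.nn.Linear(64, 2),
)

# 8-bit Adam for all parameters by default.
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# Force full-precision (32-bit) optimizer state for embedding weights only,
# registered before the first optimizer step, mirroring the order used above.
manager = bnb.optim.GlobalOptimManager.get_instance()
for module in model.modules():
    if isinstance(module, torch.nn.Embedding):
        manager.register_module_override(module, "weight", {"optim_bits": 32})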
@@ -897,6 +908,14 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
                 optimizer_kwargs.update(adam_kwargs)
             except ImportError:
                 raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+        elif args.optim == OptimizerNames.ADAMW_BNB:
+            try:
+                from bitsandbytes.optim import Adam8bit
+
+                optimizer_cls = Adam8bit
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs
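Illustrative usage of the new branch (not part of the diff): with bitsandbytes installed, the new ADAMW_BNB option resolves to bnb's Adam8bit; without it, the ValueError above is raised. This mirrors how the tests further down call the static method.

from transformers import Trainer, TrainingArguments
from transformers.training_args import OptimizerNames

args = TrainingArguments(output_dir="tmp_trainer", optim=OptimizerNames.ADAMW_BNB)
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
print(optimizer_cls)     # bitsandbytes.optim.Adam8bit when bnb is installed
print(optimizer_kwargs)  # learning rate and Adam hyperparameters taken from args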
1 change: 1 addition & 0 deletions src/transformers/training_args.py
@@ -80,6 +80,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_TORCH_XLA = "adamw_torch_xla"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
+    ADAMW_BNB = "adamw_bnb"


 @dataclass
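Since OptimizerNames is an ExplicitEnum, the string accepted on the command line (or passed to TrainingArguments(optim=...)) maps directly onto the new member. A quick illustrative check:

from transformers.training_args import OptimizerNames

assert OptimizerNames("adamw_bnb") is OptimizerNames.ADAMW_BNB
assert OptimizerNames.ADAMW_BNB.value == "adamw_bnb"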
13 changes: 12 additions & 1 deletion tests/extended/test_trainer_ext.py
@@ -20,7 +20,7 @@
 from unittest.mock import patch

 from parameterized import parameterized
-from transformers.file_utils import is_apex_available
+from transformers.file_utils import is_apex_available, is_bnb_available
 from transformers.integrations import is_fairscale_available
 from transformers.testing_utils import (
     CaptureStderr,
@@ -71,6 +71,17 @@ def require_apex(test_case):
     return test_case


+# a candidate for testing_utils
+def require_bnb(test_case):
+    """
+    Decorator for bits and bytes (bnb) dependency
+    """
+    if not is_bnb_available():
+        return unittest.skip("test requires bnb")(test_case)
+    else:
+        return test_case
+
+
 @require_torch
 class TestTrainerExt(TestCasePlus):
     def run_seq2seq_quick(
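An illustrative way the new decorator would be applied (hypothetical test class; assumes TestCasePlus and require_bnb are in scope as in the file above):

@require_torch
class BnbExtTest(TestCasePlus):
    @require_bnb
    def test_adam8bit_imports(self):
        # Skipped automatically when bitsandbytes is not installed.
        from bitsandbytes.optim import Adam8bit

        self.assertTrue(callable(Adam8bit))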
43 changes: 40 additions & 3 deletions tests/test_trainer.py
@@ -38,7 +38,7 @@
     is_torch_available,
     logging,
 )
-from transformers.file_utils import WEIGHTS_NAME, is_apex_available
+from transformers.file_utils import WEIGHTS_NAME, is_apex_available, is_bnb_available
 from transformers.testing_utils import (
     ENDPOINT_STAGING,
     PASS,
@@ -1762,6 +1762,16 @@ def hp_name(trial):
                 default_adam_kwargs,
             )
         )
+    if is_bnb_available():
+        import bitsandbytes as bnb
+
+        optim_test_params.append(
+            (
+                OptimizerNames.ADAMW_BNB,
+                bnb.optim.Adam8bit,
+                default_adam_kwargs,
+            )
+        )


 @require_torch
@@ -1787,8 +1797,8 @@ def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):

     def test_fused_adam(self):
         # Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
-        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
-        # class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
+        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the
+        # class given, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
         # the test to run without requiring an apex installation.
         mock = Mock()
         modules = {
@@ -1812,6 +1822,33 @@ def test_fused_adam_no_apex(self):
         with self.assertRaises(ValueError):
             Trainer.get_optimizer_cls_and_kwargs(args)

+    def test_bnb_adam8bit(self):
+        # Pretend that Bits and Bytes is installed and mock bnb.optim.Adam8bit exists.
+        # Trainer.get_optimizer_cls_and_kwargs does not use Adam8bit. It only has to return the
+        # class given, so mocking bnb.optim.Adam8bit should be fine for testing and allow
+        # the test to run without requiring a bnb installation.
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                OptimizerNames.ADAMW_BNB,
+                default_adam_kwargs,
+                mock.optim.Adam8bit,
+            )
+
+    def test_bnb_adam8bit_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
+
+        # Pretend that bnb does not exist, even if it is installed. Setting the
+        # "bitsandbytes.optim" entry to None makes `from bitsandbytes.optim import Adam8bit` fail.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+

 @require_torch
 @require_wandb
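For readers unfamiliar with the sys.modules trick used in the negative test: setting a module's entry to None makes any subsequent import of it raise ImportError, which is what pushes get_optimizer_cls_and_kwargs into the ValueError branch above. A standalone illustration (works whether or not bitsandbytes is installed):

from unittest.mock import patch

with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
    try:
        from bitsandbytes.optim import Adam8bit  # noqa: F401
    except ImportError:
        print("import blocked, as the no-bnb test expects")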