Skip to content

Commit

Permalink
Add support for bitsandbytes (#15622)
Browse files Browse the repository at this point in the history
* Add initial BNB integration

* fixup! Add initial BNB integration

* Add bnb test decorator

* Update Adamw8bit option name

* Use the full bnb package name

* Overide bnb for all embedding layers

* Fix package name

* Formatting

* Remove unnecessary import

* Update src/transformers/trainer.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename AdamwBNB optimizer option

* Add training test checking that bnb memory utilization is lower

* fix merge

* fix merge; fix + extend new test

* cleanup

* expand bnb

* move all require_* candidates to testing_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
  • Loading branch information
3 people authored Apr 19, 2022
1 parent e6d23a4 commit 3104036
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 29 deletions.
40 changes: 39 additions & 1 deletion src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,16 @@
from transformers import logging as transformers_logging

from .deepspeed import is_deepspeed_available
from .integrations import is_optuna_available, is_ray_available, is_sigopt_available, is_wandb_available
from .integrations import (
is_fairscale_available,
is_optuna_available,
is_ray_available,
is_sigopt_available,
is_wandb_available,
)
from .utils import (
is_apex_available,
is_bitsandbytes_available,
is_detectron2_available,
is_faiss_available,
is_flax_available,
Expand Down Expand Up @@ -638,6 +646,36 @@ def require_deepspeed(test_case):
return test_case


def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
if not is_fairscale_available():
return unittest.skip("test requires fairscale")(test_case)
else:
return test_case


def require_apex(test_case):
"""
Decorator marking a test that requires apex
"""
if not is_apex_available():
return unittest.skip("test requires apex")(test_case)
else:
return test_case


def require_bitsandbytes(test_case):
"""
Decorator for bits and bytes (bnb) dependency
"""
if not is_bitsandbytes_available():
return unittest.skip("test requires bnb")(test_case)
else:
return test_case


def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer
Expand Down
17 changes: 17 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,15 @@ def create_optimizer(self):
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

for module in self.model.modules():
if isinstance(module, nn.Embedding):
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")

if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
Expand Down Expand Up @@ -917,6 +926,14 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_kwargs.update(adam_kwargs)
except ImportError:
raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
elif args.optim == OptimizerNames.ADAMW_BNB:
try:
from bitsandbytes.optim import Adam8bit

optimizer_cls = Adam8bit
optimizer_kwargs.update(adam_kwargs)
except ImportError:
raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
Expand Down
1 change: 1 addition & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class OptimizerNames(ExplicitEnum):
ADAMW_TORCH_XLA = "adamw_torch_xla"
ADAMW_APEX_FUSED = "adamw_apex_fused"
ADAFACTOR = "adafactor"
ADAMW_BNB = "adamw_bnb_8bit"


@dataclass
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
DummyObject,
_LazyModule,
is_apex_available,
is_bitsandbytes_available,
is_coloredlogs_available,
is_datasets_available,
is_detectron2_available,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,10 @@ def is_apex_available():
return importlib.util.find_spec("apex") is not None


def is_bitsandbytes_available():
return importlib.util.find_spec("bitsandbytes") is not None


def is_faiss_available():
return _faiss_available

Expand Down
115 changes: 90 additions & 25 deletions tests/extended/test_trainer_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@
import re
import sys
import unittest
from typing import Tuple
from unittest.mock import patch

from parameterized import parameterized
from transformers.integrations import is_fairscale_available
from transformers import AutoModel
from transformers.testing_utils import (
CaptureStderr,
ExtendSysPath,
TestCasePlus,
execute_subprocess_async,
get_gpu_count,
get_torch_dist_unique_port,
require_apex,
require_bitsandbytes,
require_fairscale,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
Expand All @@ -36,7 +40,6 @@
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from transformers.utils import is_apex_available


bindir = os.path.abspath(os.path.dirname(__file__))
Expand All @@ -49,28 +52,6 @@
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
if not is_fairscale_available():
return unittest.skip("test requires fairscale")(test_case)
else:
return test_case


# a candidate for testing_utils
def require_apex(test_case):
"""
Decorator marking a test that requires apex
"""
if not is_apex_available():
return unittest.skip("test requires apex")(test_case)
else:
return test_case


@require_torch
class TestTrainerExt(TestCasePlus):
def run_seq2seq_quick(
Expand Down Expand Up @@ -193,7 +174,7 @@ def test_trainer_log_level_replica(self, experiment_id):
self.assertEqual(n_matches, data["n_matches"])

@slow
def test_run_seq2seq_slow(self):
def test_run_seq2seq(self):
output_dir = self.run_trainer(
eval_steps=2,
max_len=128,
Expand All @@ -218,6 +199,88 @@ def test_run_seq2seq_slow(self):
assert "generated_predictions.txt" in contents
assert "predict_results.json" in contents

@slow
@require_bitsandbytes
def test_run_seq2seq_bnb(self):
from transformers.training_args import OptimizerNames

def train_and_return_metrics(optim: str) -> Tuple[int, float]:
from pathlib import Path

extra_args = (
f"--skip_memory_metrics 0 --optim {optim} --do_eval False --do_predict "
"False --adafactor False --log_level debug"
)

output_dir = self.run_trainer(
eval_steps=2,
max_len=128,
model_name=MARIAN_MODEL,
learning_rate=3e-4,
num_train_epochs=1,
distributed=True, # force run in a new process
extra_args_str=extra_args,
do_eval=False,
do_predict=False,
)

# Check metrics
logs = TrainerState.load_from_json(Path(output_dir, "trainer_state.json")).log_history
gpu_peak_mem = logs[0]["train_mem_gpu_peaked_delta"]
gpu_alloc_mem = logs[0]["train_mem_gpu_alloc_delta"]

loss = logs[0]["train_loss"]
return gpu_peak_mem, gpu_alloc_mem, loss

gpu_peak_mem_orig, gpu_alloc_mem_orig, loss_orig = train_and_return_metrics(OptimizerNames.ADAMW_TORCH.value)
gpu_peak_mem_bnb, gpu_alloc_mem_bnb, loss_bnb = train_and_return_metrics(OptimizerNames.ADAMW_BNB.value)

gpu_peak_mem_diff_bytes = gpu_peak_mem_orig - gpu_peak_mem_bnb
gpu_peak_mem_diff_percent = gpu_peak_mem_diff_bytes / gpu_peak_mem_bnb

gpu_total_mem_orig = gpu_peak_mem_orig + gpu_alloc_mem_orig
gpu_total_mem_bnb = gpu_peak_mem_bnb + gpu_alloc_mem_bnb

gpu_total_mem_diff_bytes = gpu_total_mem_orig - gpu_total_mem_bnb
gpu_total_mem_diff_percent = gpu_total_mem_diff_bytes / gpu_total_mem_bnb

# leave this for now if CI gets very different results
# print(f"{gpu_alloc_mem_orig=:010d} {gpu_peak_mem_orig=:010d} {gpu_alloc_mem_orig+gpu_peak_mem_orig=:010d}" )
# print(f" {gpu_alloc_mem_bnb=:010d} {gpu_peak_mem_bnb=:010d} {gpu_alloc_mem_bnb+gpu_peak_mem_bnb=:010d}")
# print(f"{gpu_peak_mem_diff_bytes=}, {gpu_peak_mem_diff_percent=}")
# print(f"{gpu_total_mem_orig=}, {gpu_total_mem_bnb=}")
# print(f"{gpu_total_mem_diff_bytes=}, {gpu_total_mem_diff_percent=}")

self.assertGreater(
gpu_peak_mem_diff_percent,
10, # basically a huge difference - got ~30x on my desktop
"should use very little peak gpu memory with BNB, compared to without it"
f"but got gpu_peak_mem_orig={gpu_peak_mem_orig} and gpu_peak_mem_bnb={gpu_peak_mem_bnb}",
)

self.assertGreater(
gpu_total_mem_diff_percent,
0.20, # could easily be 0.50, but let's stay on the safe side
"Using BNB should use less total GPU memory than without it"
f"but got gpu_total_mem_orig={gpu_total_mem_orig} and gpu_total_mem_bnb={gpu_total_mem_bnb}",
)

self.assertEqual(
loss_orig, loss_bnb, "loss should be the same, but got loss_orig={loss_orig}, loss_bnb={loss_bnb}"
)

# Additionally let's test that the absolute gpu memory difference is larger or about the
# same as the expected saving coming from BNB (6 bytes per param)
model = AutoModel.from_pretrained(MARIAN_MODEL)
total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
bnb_saved_bytes = total_numel * 6 # 324MB

self.assertGreater(
gpu_total_mem_diff_bytes,
bnb_saved_bytes * 0.8, # add a safety margin, if it saved slightly less
f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were {gpu_total_mem_diff_bytes}",
)

def run_trainer(
self,
eval_steps: int,
Expand Down Expand Up @@ -300,6 +363,8 @@ def run_trainer(
{self.examples_dir_str}/pytorch/translation/run_translation.py
""".split()
cmd = [sys.executable] + distributed_args + args
# keep for quick debug
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
execute_subprocess_async(cmd, env=self.get_env())
else:
testargs = ["run_translation.py"] + args
Expand Down
45 changes: 42 additions & 3 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.training_args import OptimizerNames
from transformers.utils import WEIGHTS_NAME, is_apex_available
from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
from transformers.utils.hp_naming import TrialShortNamer


Expand Down Expand Up @@ -1870,6 +1870,7 @@ def hp_name(trial):
},
),
]

if is_apex_available():
import apex

Expand All @@ -1881,6 +1882,17 @@ def hp_name(trial):
)
)

if is_bitsandbytes_available():
import bitsandbytes as bnb

optim_test_params.append(
(
OptimizerNames.ADAMW_BNB,
bnb.optim.Adam8bit,
default_adam_kwargs,
)
)


@require_torch
class TrainerOptimizerChoiceTest(unittest.TestCase):
Expand All @@ -1905,8 +1917,8 @@ def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):

def test_fused_adam(self):
# Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
# class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the
# class given, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
# the test to run without requiring an apex installation.
mock = Mock()
modules = {
Expand All @@ -1930,6 +1942,33 @@ def test_fused_adam_no_apex(self):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)

def test_bnb_adam8bit(self):
# Pretend that Bits and Bytes is installed and mock bnb.optim.Adam8bit exists.
# Trainer.get_optimizer_cls_and_kwargs does not use Adam8bit. It only has to return the
# class given, so mocking bnb.optim.Adam8bit should be fine for testing and allow
# the test to run without requiring a bnb installation.
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
OptimizerNames.ADAMW_BNB,
default_adam_kwargs,
mock.optim.Adam8bit,
)

def test_bnb_adam8bit_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")

# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
# bnb will fail even if bnb is installed.
with patch.dict("sys.modules", {"bnb.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)


@require_torch
@require_wandb
Expand Down

0 comments on commit 3104036

Please sign in to comment.