From 3104036e7f1a3cd6e07a69d648c3597de32f72fe Mon Sep 17 00:00:00 2001
From: "Manuel R. Ciosici"
Date: Tue, 19 Apr 2022 13:01:29 -0700
Subject: [PATCH] Add support for bitsandbytes (#15622)

* Add initial BNB integration

* fixup! Add initial BNB integration

* Add bnb test decorator

* Update Adamw8bit option name

* Use the full bnb package name

* Overide bnb for all embedding layers

* Fix package name

* Formatting

* Remove unnecessary import

* Update src/transformers/trainer.py

Co-authored-by: Stas Bekman

* Rename AdamwBNB optimizer option

* Add training test checking that bnb memory utilization is lower

* fix merge

* fix merge; fix + extend new test

* cleanup

* expand bnb

* move all require_* candidates to testing_utils.py

Co-authored-by: Stas Bekman
Co-authored-by: Stas Bekman
---
 src/transformers/testing_utils.py      |  40 ++++++++-
 src/transformers/trainer.py            |  17 ++++
 src/transformers/training_args.py      |   1 +
 src/transformers/utils/__init__.py     |   1 +
 src/transformers/utils/import_utils.py |   4 +
 tests/extended/test_trainer_ext.py     | 115 +++++++++++++++++++------
 tests/trainer/test_trainer.py          |  45 +++++++++-
 7 files changed, 194 insertions(+), 29 deletions(-)

diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index b60c7942097a14..36f56d2eeb29c6 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -31,8 +31,16 @@
 from transformers import logging as transformers_logging
 
 from .deepspeed import is_deepspeed_available
-from .integrations import is_optuna_available, is_ray_available, is_sigopt_available, is_wandb_available
+from .integrations import (
+    is_fairscale_available,
+    is_optuna_available,
+    is_ray_available,
+    is_sigopt_available,
+    is_wandb_available,
+)
 from .utils import (
+    is_apex_available,
+    is_bitsandbytes_available,
     is_detectron2_available,
     is_faiss_available,
     is_flax_available,
@@ -638,6 +646,36 @@ def require_deepspeed(test_case):
     return test_case
 
 
+def require_fairscale(test_case):
+    """
+    Decorator marking a test that requires fairscale
+    """
+    if not is_fairscale_available():
+        return unittest.skip("test requires fairscale")(test_case)
+    else:
+        return test_case
+
+
+def require_apex(test_case):
+    """
+    Decorator marking a test that requires apex
+    """
+    if not is_apex_available():
+        return unittest.skip("test requires apex")(test_case)
+    else:
+        return test_case
+
+
+def require_bitsandbytes(test_case):
+    """
+    Decorator for bits and bytes (bnb) dependency
+    """
+    if not is_bitsandbytes_available():
+        return unittest.skip("test requires bnb")(test_case)
+    else:
+        return test_case
+
+
 def require_phonemizer(test_case):
     """
     Decorator marking a test that requires phonemizer
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 6ec32667dc7a84..9e61a36ecf4d81 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -867,6 +867,15 @@ def create_optimizer(self):
             )
         else:
             self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+            if optimizer_cls.__name__ == "Adam8bit":
+                import bitsandbytes
+
+                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                for module in self.model.modules():
+                    if isinstance(module, nn.Embedding):
+                        manager.register_module_override(module, "weight", {"optim_bits": 32})
+                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
 
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(self.optimizer)
@@ -917,6 +926,14 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
                 optimizer_kwargs.update(adam_kwargs)
             except ImportError:
                 raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+        elif args.optim == OptimizerNames.ADAMW_BNB:
+            try:
+                from bitsandbytes.optim import Adam8bit
+
+                optimizer_cls = Adam8bit
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 68efcbd5cf1c5c..cc0a5ec835704d 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -79,6 +79,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_TORCH_XLA = "adamw_torch_xla"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
+    ADAMW_BNB = "adamw_bnb_8bit"
 
 
 @dataclass
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 45364fb8fd335f..6101a924f969a0 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -85,6 +85,7 @@
     DummyObject,
     _LazyModule,
     is_apex_available,
+    is_bitsandbytes_available,
     is_coloredlogs_available,
     is_datasets_available,
     is_detectron2_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 6207d0df7ceaa6..505ba94e0b193c 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -400,6 +400,10 @@ def is_apex_available():
     return importlib.util.find_spec("apex") is not None
 
 
+def is_bitsandbytes_available():
+    return importlib.util.find_spec("bitsandbytes") is not None
+
+
 def is_faiss_available():
     return _faiss_available
 
diff --git a/tests/extended/test_trainer_ext.py b/tests/extended/test_trainer_ext.py
index f13753bdcfd01b..af8c5d4dd785de 100644
--- a/tests/extended/test_trainer_ext.py
+++ b/tests/extended/test_trainer_ext.py
@@ -17,10 +17,11 @@
 import re
 import sys
 import unittest
+from typing import Tuple
 from unittest.mock import patch
 
 from parameterized import parameterized
-from transformers.integrations import is_fairscale_available
+from transformers import AutoModel
 from transformers.testing_utils import (
     CaptureStderr,
     ExtendSysPath,
@@ -28,6 +29,9 @@
     execute_subprocess_async,
     get_gpu_count,
     get_torch_dist_unique_port,
+    require_apex,
+    require_bitsandbytes,
+    require_fairscale,
     require_torch,
     require_torch_gpu,
     require_torch_multi_gpu,
@@ -36,7 +40,6 @@
 )
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
-from transformers.utils import is_apex_available
 
 
 bindir = os.path.abspath(os.path.dirname(__file__))
@@ -49,28 +52,6 @@
 MBART_TINY = "sshleifer/tiny-mbart"
 
 
-# a candidate for testing_utils
-def require_fairscale(test_case):
-    """
-    Decorator marking a test that requires fairscale
-    """
-    if not is_fairscale_available():
-        return unittest.skip("test requires fairscale")(test_case)
-    else:
-        return test_case
-
-
-# a candidate for testing_utils
-def require_apex(test_case):
-    """
-    Decorator marking a test that requires apex
-    """
-    if not is_apex_available():
-        return unittest.skip("test requires apex")(test_case)
-    else:
-        return test_case
-
-
 @require_torch
 class TestTrainerExt(TestCasePlus):
     def run_seq2seq_quick(
@@ -193,7 +174,7 @@ def test_trainer_log_level_replica(self, experiment_id):
         self.assertEqual(n_matches, data["n_matches"])
 
     @slow
-    def test_run_seq2seq_slow(self):
+    def test_run_seq2seq(self):
         output_dir = self.run_trainer(
             eval_steps=2,
             max_len=128,
@@ -218,6 +199,88 @@ def test_run_seq2seq_slow(self):
         assert "generated_predictions.txt" in contents
         assert "predict_results.json" in contents
 
+    @slow
+    @require_bitsandbytes
+    def test_run_seq2seq_bnb(self):
+        from transformers.training_args import OptimizerNames
+
+        def train_and_return_metrics(optim: str) -> Tuple[int, float]:
+            from pathlib import Path
+
+            extra_args = (
+                f"--skip_memory_metrics 0 --optim {optim} --do_eval False --do_predict "
+                "False --adafactor False --log_level debug"
+            )
+
+            output_dir = self.run_trainer(
+                eval_steps=2,
+                max_len=128,
+                model_name=MARIAN_MODEL,
+                learning_rate=3e-4,
+                num_train_epochs=1,
+                distributed=True,  # force run in a new process
+                extra_args_str=extra_args,
+                do_eval=False,
+                do_predict=False,
+            )
+
+            # Check metrics
+            logs = TrainerState.load_from_json(Path(output_dir, "trainer_state.json")).log_history
+            gpu_peak_mem = logs[0]["train_mem_gpu_peaked_delta"]
+            gpu_alloc_mem = logs[0]["train_mem_gpu_alloc_delta"]
+
+            loss = logs[0]["train_loss"]
+            return gpu_peak_mem, gpu_alloc_mem, loss
+
+        gpu_peak_mem_orig, gpu_alloc_mem_orig, loss_orig = train_and_return_metrics(OptimizerNames.ADAMW_TORCH.value)
+        gpu_peak_mem_bnb, gpu_alloc_mem_bnb, loss_bnb = train_and_return_metrics(OptimizerNames.ADAMW_BNB.value)
+
+        gpu_peak_mem_diff_bytes = gpu_peak_mem_orig - gpu_peak_mem_bnb
+        gpu_peak_mem_diff_percent = gpu_peak_mem_diff_bytes / gpu_peak_mem_bnb
+
+        gpu_total_mem_orig = gpu_peak_mem_orig + gpu_alloc_mem_orig
+        gpu_total_mem_bnb = gpu_peak_mem_bnb + gpu_alloc_mem_bnb
+
+        gpu_total_mem_diff_bytes = gpu_total_mem_orig - gpu_total_mem_bnb
+        gpu_total_mem_diff_percent = gpu_total_mem_diff_bytes / gpu_total_mem_bnb
+
+        # leave this for now if CI gets very different results
+        # print(f"{gpu_alloc_mem_orig=:010d} {gpu_peak_mem_orig=:010d} {gpu_alloc_mem_orig+gpu_peak_mem_orig=:010d}" )
+        # print(f" {gpu_alloc_mem_bnb=:010d} {gpu_peak_mem_bnb=:010d} {gpu_alloc_mem_bnb+gpu_peak_mem_bnb=:010d}")
+        # print(f"{gpu_peak_mem_diff_bytes=}, {gpu_peak_mem_diff_percent=}")
+        # print(f"{gpu_total_mem_orig=}, {gpu_total_mem_bnb=}")
+        # print(f"{gpu_total_mem_diff_bytes=}, {gpu_total_mem_diff_percent=}")
+
+        self.assertGreater(
+            gpu_peak_mem_diff_percent,
+            10,  # basically a huge difference - got ~30x on my desktop
+            "should use very little peak gpu memory with BNB, compared to without it"
+            f"but got gpu_peak_mem_orig={gpu_peak_mem_orig} and gpu_peak_mem_bnb={gpu_peak_mem_bnb}",
+        )
+
+        self.assertGreater(
+            gpu_total_mem_diff_percent,
+            0.20,  # could easily be 0.50, but let's stay on the safe side
+            "Using BNB should use less total GPU memory than without it"
+            f"but got gpu_total_mem_orig={gpu_total_mem_orig} and gpu_total_mem_bnb={gpu_total_mem_bnb}",
+        )
+
+        self.assertEqual(
+            loss_orig, loss_bnb, "loss should be the same, but got loss_orig={loss_orig}, loss_bnb={loss_bnb}"
+        )
+
+        # Additionally let's test that the absolute gpu memory difference is larger or about the
+        # same as the expected saving coming from BNB (6 bytes per param)
+        model = AutoModel.from_pretrained(MARIAN_MODEL)
+        total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
+        bnb_saved_bytes = total_numel * 6  # 324MB
+
+        self.assertGreater(
+            gpu_total_mem_diff_bytes,
+            bnb_saved_bytes * 0.8,  # add a safety margin, if it saved slightly less
+            f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were {gpu_total_mem_diff_bytes}",
+        )
+
     def run_trainer(
         self,
         eval_steps: int,
@@ -300,6 +363,8 @@ def run_trainer(
                 {self.examples_dir_str}/pytorch/translation/run_translation.py
             """.split()
             cmd = [sys.executable] + distributed_args + args
+            # keep for quick debug
+            # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
             execute_subprocess_async(cmd, env=self.get_env())
         else:
             testargs = ["run_translation.py"] + args
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 81982bb53cd1da..1d80a85f0ef5a4 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -65,7 +65,7 @@
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.training_args import OptimizerNames
-from transformers.utils import WEIGHTS_NAME, is_apex_available
+from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
 from transformers.utils.hp_naming import TrialShortNamer
 
 
@@ -1870,6 +1870,7 @@ def hp_name(trial):
             },
         ),
     ]
+
     if is_apex_available():
         import apex
 
@@ -1881,6 +1882,17 @@ def hp_name(trial):
             )
         )
 
+    if is_bitsandbytes_available():
+        import bitsandbytes as bnb
+
+        optim_test_params.append(
+            (
+                OptimizerNames.ADAMW_BNB,
+                bnb.optim.Adam8bit,
+                default_adam_kwargs,
+            )
+        )
+
 
 @require_torch
 class TrainerOptimizerChoiceTest(unittest.TestCase):
@@ -1905,8 +1917,8 @@ def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):
 
     def test_fused_adam(self):
         # Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
-        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
-        # class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
+        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the
+        # class given, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
         # the test to run without requiring an apex installation.
         mock = Mock()
         modules = {
@@ -1930,6 +1942,33 @@ def test_fused_adam_no_apex(self):
         with self.assertRaises(ValueError):
             Trainer.get_optimizer_cls_and_kwargs(args)
 
+    def test_bnb_adam8bit(self):
+        # Pretend that Bits and Bytes is installed and mock bnb.optim.Adam8bit exists.
+        # Trainer.get_optimizer_cls_and_kwargs does not use Adam8bit. It only has to return the
+        # class given, so mocking bnb.optim.Adam8bit should be fine for testing and allow
+        # the test to run without requiring a bnb installation.
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                OptimizerNames.ADAMW_BNB,
+                default_adam_kwargs,
+                mock.optim.Adam8bit,
+            )
+
+    def test_bnb_adam8bit_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if bnb is installed.
+        with patch.dict("sys.modules", {"bnb.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
 
 @require_torch
 @require_wandb