Add support for bitsandbytes #15622

Merged
merged 18 commits on Apr 19, 2022
40 changes: 39 additions & 1 deletion src/transformers/testing_utils.py
@@ -31,8 +31,16 @@
from transformers import logging as transformers_logging

from .deepspeed import is_deepspeed_available
from .integrations import is_optuna_available, is_ray_available, is_sigopt_available, is_wandb_available
from .integrations import (
is_fairscale_available,
is_optuna_available,
is_ray_available,
is_sigopt_available,
is_wandb_available,
)
from .utils import (
is_apex_available,
is_bitsandbytes_available,
is_detectron2_available,
is_faiss_available,
is_flax_available,
@@ -638,6 +646,36 @@ def require_deepspeed(test_case):
return test_case


def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
if not is_fairscale_available():
return unittest.skip("test requires fairscale")(test_case)
else:
return test_case


def require_apex(test_case):
"""
Decorator marking a test that requires apex
"""
if not is_apex_available():
return unittest.skip("test requires apex")(test_case)
else:
return test_case


def require_bitsandbytes(test_case):
"""
    Decorator marking a test that requires bitsandbytes (bnb)
"""
if not is_bitsandbytes_available():
return unittest.skip("test requires bnb")(test_case)
else:
return test_case


def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer
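As a quick usage illustration (not part of the diff), the new decorator composes with the existing require_* helpers; the test class below is hypothetical and only sketches the intended pattern.

import unittest

import torch

from transformers.testing_utils import require_bitsandbytes, require_torch_gpu


class HypotheticalBnbTest(unittest.TestCase):
    @require_bitsandbytes
    @require_torch_gpu
    def test_adam8bit_step(self):
        # Skipped automatically unless both bitsandbytes and a CUDA GPU are available.
        import bitsandbytes as bnb

        param = torch.nn.Parameter(torch.randn(4, 4, device="cuda"))
        optimizer = bnb.optim.Adam8bit([param], lr=1e-3)
        param.grad = torch.randn_like(param)
        optimizer.step()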
17 changes: 17 additions & 0 deletions src/transformers/trainer.py
@@ -860,6 +860,15 @@ def create_optimizer(self):
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

for module in self.model.modules():
if isinstance(module, nn.Embedding):
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")

if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
@@ -910,6 +919,14 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_kwargs.update(adam_kwargs)
except ImportError:
raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
elif args.optim == OptimizerNames.ADAMW_BNB:
try:
from bitsandbytes.optim import Adam8bit

optimizer_cls = Adam8bit
optimizer_kwargs.update(adam_kwargs)
except ImportError:
raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
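For context (not part of the diff), here is a minimal sketch of how the new optimizer choice is selected end to end, assuming a transformers build that contains this change; "out" is just a placeholder output directory.

from transformers import Trainer, TrainingArguments

# "adamw_bnb_8bit" is the string value of the new OptimizerNames.ADAMW_BNB member.
args = TrainingArguments(output_dir="out", optim="adamw_bnb_8bit")

# Returns (bitsandbytes.optim.Adam8bit, {...adam kwargs...}) when bitsandbytes is
# installed, and raises ValueError otherwise.
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)

# Trainer.create_optimizer() then instantiates Adam8bit and, as shown above, registers
# every nn.Embedding with GlobalOptimManager so embedding weights keep fp32 optimizer state.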
1 change: 1 addition & 0 deletions src/transformers/training_args.py
@@ -79,6 +79,7 @@ class OptimizerNames(ExplicitEnum):
ADAMW_TORCH_XLA = "adamw_torch_xla"
ADAMW_APEX_FUSED = "adamw_apex_fused"
ADAFACTOR = "adafactor"
ADAMW_BNB = "adamw_bnb_8bit"


@dataclass
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -85,6 +85,7 @@
DummyObject,
_LazyModule,
is_apex_available,
is_bitsandbytes_available,
is_coloredlogs_available,
is_datasets_available,
is_detectron2_available,
4 changes: 4 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -400,6 +400,10 @@ def is_apex_available():
return importlib.util.find_spec("apex") is not None


def is_bitsandbytes_available():
return importlib.util.find_spec("bitsandbytes") is not None


def is_faiss_available():
return _faiss_available

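A small sketch (not part of the diff) of the guarded-import pattern the new helper enables, mirroring how is_apex_available is already used elsewhere in the codebase.

from transformers.utils import is_bitsandbytes_available

if is_bitsandbytes_available():
    import bitsandbytes as bnb  # only imported when the package is actually installed

    optimizer_cls = bnb.optim.Adam8bit
else:
    optimizer_cls = None  # callers fall back to a regular torch optimizer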
115 changes: 90 additions & 25 deletions tests/extended/test_trainer_ext.py
@@ -17,17 +17,21 @@
import re
import sys
import unittest
from typing import Tuple
from unittest.mock import patch

from parameterized import parameterized
from transformers.integrations import is_fairscale_available
from transformers import AutoModel
from transformers.testing_utils import (
CaptureStderr,
ExtendSysPath,
TestCasePlus,
execute_subprocess_async,
get_gpu_count,
get_torch_dist_unique_port,
require_apex,
require_bitsandbytes,
require_fairscale,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
@@ -36,7 +40,6 @@
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from transformers.utils import is_apex_available


bindir = os.path.abspath(os.path.dirname(__file__))
@@ -49,28 +52,6 @@
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
if not is_fairscale_available():
return unittest.skip("test requires fairscale")(test_case)
else:
return test_case


# a candidate for testing_utils
def require_apex(test_case):
"""
Decorator marking a test that requires apex
"""
if not is_apex_available():
return unittest.skip("test requires apex")(test_case)
else:
return test_case


@require_torch
class TestTrainerExt(TestCasePlus):
def run_seq2seq_quick(
@@ -193,7 +174,7 @@ def test_trainer_log_level_replica(self, experiment_id):
self.assertEqual(n_matches, data["n_matches"])

@slow
def test_run_seq2seq_slow(self):
def test_run_seq2seq(self):
output_dir = self.run_trainer(
eval_steps=2,
max_len=128,
@@ -218,6 +199,88 @@ def test_run_seq2seq_slow(self):
assert "generated_predictions.txt" in contents
assert "predict_results.json" in contents

@slow
@require_bitsandbytes
def test_run_seq2seq_bnb(self):
from transformers.training_args import OptimizerNames

        def train_and_return_metrics(optim: str) -> Tuple[int, int, float]:
from pathlib import Path

extra_args = (
f"--skip_memory_metrics 0 --optim {optim} --do_eval False --do_predict "
"False --adafactor False --log_level debug"
)

output_dir = self.run_trainer(
eval_steps=2,
max_len=128,
model_name=MARIAN_MODEL,
learning_rate=3e-4,
num_train_epochs=1,
distributed=True, # force run in a new process
extra_args_str=extra_args,
do_eval=False,
do_predict=False,
)

# Check metrics
logs = TrainerState.load_from_json(Path(output_dir, "trainer_state.json")).log_history
gpu_peak_mem = logs[0]["train_mem_gpu_peaked_delta"]
gpu_alloc_mem = logs[0]["train_mem_gpu_alloc_delta"]

loss = logs[0]["train_loss"]
return gpu_peak_mem, gpu_alloc_mem, loss

gpu_peak_mem_orig, gpu_alloc_mem_orig, loss_orig = train_and_return_metrics(OptimizerNames.ADAMW_TORCH.value)
gpu_peak_mem_bnb, gpu_alloc_mem_bnb, loss_bnb = train_and_return_metrics(OptimizerNames.ADAMW_BNB.value)

gpu_peak_mem_diff_bytes = gpu_peak_mem_orig - gpu_peak_mem_bnb
gpu_peak_mem_diff_percent = gpu_peak_mem_diff_bytes / gpu_peak_mem_bnb

gpu_total_mem_orig = gpu_peak_mem_orig + gpu_alloc_mem_orig
gpu_total_mem_bnb = gpu_peak_mem_bnb + gpu_alloc_mem_bnb

gpu_total_mem_diff_bytes = gpu_total_mem_orig - gpu_total_mem_bnb
gpu_total_mem_diff_percent = gpu_total_mem_diff_bytes / gpu_total_mem_bnb

# leave this for now if CI gets very different results
# print(f"{gpu_alloc_mem_orig=:010d} {gpu_peak_mem_orig=:010d} {gpu_alloc_mem_orig+gpu_peak_mem_orig=:010d}" )
# print(f" {gpu_alloc_mem_bnb=:010d} {gpu_peak_mem_bnb=:010d} {gpu_alloc_mem_bnb+gpu_peak_mem_bnb=:010d}")
# print(f"{gpu_peak_mem_diff_bytes=}, {gpu_peak_mem_diff_percent=}")
# print(f"{gpu_total_mem_orig=}, {gpu_total_mem_bnb=}")
# print(f"{gpu_total_mem_diff_bytes=}, {gpu_total_mem_diff_percent=}")

self.assertGreater(
gpu_peak_mem_diff_percent,
10, # basically a huge difference - got ~30x on my desktop
"should use very little peak gpu memory with BNB, compared to without it"
f"but got gpu_peak_mem_orig={gpu_peak_mem_orig} and gpu_peak_mem_bnb={gpu_peak_mem_bnb}",
)

self.assertGreater(
gpu_total_mem_diff_percent,
0.20, # could easily be 0.50, but let's stay on the safe side
"Using BNB should use less total GPU memory than without it"
f"but got gpu_total_mem_orig={gpu_total_mem_orig} and gpu_total_mem_bnb={gpu_total_mem_bnb}",
)

self.assertEqual(
            loss_orig, loss_bnb, f"loss should be the same, but got loss_orig={loss_orig}, loss_bnb={loss_bnb}"
)

# Additionally let's test that the absolute gpu memory difference is larger or about the
# same as the expected saving coming from BNB (6 bytes per param)
model = AutoModel.from_pretrained(MARIAN_MODEL)
total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
bnb_saved_bytes = total_numel * 6 # 324MB

self.assertGreater(
gpu_total_mem_diff_bytes,
bnb_saved_bytes * 0.8, # add a safety margin, if it saved slightly less
f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were {gpu_total_mem_diff_bytes}",
)

def run_trainer(
self,
eval_steps: int,
@@ -300,6 +363,8 @@ def run_trainer(
{self.examples_dir_str}/pytorch/translation/run_translation.py
""".split()
cmd = [sys.executable] + distributed_args + args
# keep for quick debug
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
execute_subprocess_async(cmd, env=self.get_env())
else:
testargs = ["run_translation.py"] + args
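For reference (not part of the diff), the "6 bytes per param" figure in test_run_seq2seq_bnb follows from the optimizer-state sizes; the parameter count below is a rough, hypothetical stand-in for the Marian test model, used only for illustration.

# Adam keeps two fp32 moment estimates per parameter; Adam8bit stores the same two
# moments quantized to int8, so the expected optimizer-state saving per parameter is:
adam_fp32_bytes = 2 * 4      # 8 bytes per parameter
adam_8bit_bytes = 2 * 1      # 2 bytes per parameter
saving_per_param = adam_fp32_bytes - adam_8bit_bytes  # 6 bytes per parameter

total_numel = 56_000_000     # hypothetical parameter count, roughly Marian-sized
print(f"expected saving: ~{total_numel * saving_per_param / 2**20:.0f} MiB")  # ~320 MiB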
45 changes: 42 additions & 3 deletions tests/trainer/test_trainer.py
@@ -65,7 +65,7 @@
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.training_args import OptimizerNames
from transformers.utils import WEIGHTS_NAME, is_apex_available
from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
from transformers.utils.hp_naming import TrialShortNamer


@@ -1870,6 +1870,7 @@ def hp_name(trial):
},
),
]

if is_apex_available():
import apex

@@ -1881,6 +1882,17 @@
)
)

if is_bitsandbytes_available():
import bitsandbytes as bnb

optim_test_params.append(
(
OptimizerNames.ADAMW_BNB,
bnb.optim.Adam8bit,
default_adam_kwargs,
)
)


@require_torch
class TrainerOptimizerChoiceTest(unittest.TestCase):
@@ -1905,8 +1917,8 @@ def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs):

def test_fused_adam(self):
# Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
# class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the
# class given, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
# the test to run without requiring an apex installation.
mock = Mock()
modules = {
@@ -1930,6 +1942,33 @@ def test_fused_adam_no_apex(self):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)

def test_bnb_adam8bit(self):
# Pretend that Bits and Bytes is installed and mock bnb.optim.Adam8bit exists.
# Trainer.get_optimizer_cls_and_kwargs does not use Adam8bit. It only has to return the
# class given, so mocking bnb.optim.Adam8bit should be fine for testing and allow
# the test to run without requiring a bnb installation.
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
OptimizerNames.ADAMW_BNB,
default_adam_kwargs,
mock.optim.Adam8bit,
)

def test_bnb_adam8bit_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")

        # Pretend that bitsandbytes does not exist, even if installed. By setting
        # bitsandbytes.optim to None in sys.modules, importing Adam8bit from it will fail
        # even when bnb is installed.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)


@require_torch
@require_wandb