[Benchmark] Add benchmarks for TF Training (huggingface#5594)
* tf_train

* adapt timing for tpu

* fix timing

* fix timing

* fix timing

* fix timing

* update notebook

* add tests
patrickvonplaten authored Jul 8, 2020
1 parent cfbb982 commit f82a2a5
Showing 4 changed files with 103 additions and 15 deletions.
6 changes: 3 additions & 3 deletions notebooks/05-benchmark.ipynb
@@ -312,16 +312,16 @@
":-- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |\n",
"**Speed - Inference** | ✔ | ✔ | ✔ | ✔ | ✔ | ✘ | ✔ |\n",
"**Memory - Inference** | ✔ | ✔ | ✔ | ✔ | ✔ | ✘ | ✘ |\n",
"**Speed - Train** | | ✘ | | ✘ | ✘ | ✘ | |\n",
"**Memory - Train** | | ✘ | | ✘ | ✘ | ✘ | ✘ |\n",
"**Speed - Train** | | ✘ | | ✘ | ✘ | ✘ | |\n",
"**Memory - Train** | | ✘ | | ✘ | ✘ | ✘ | ✘ |\n",
"\n",
"* *eager execution* means that the function is run in the eager execution environment of TensorFlow 2, see [here](https://www.tensorflow.org/guide/eager).\n",
"\n",
"* *XLA* stands for TensorFlow's Accelerated Linear Algebra (XLA) compiler, see [here](https://www.tensorflow.org/xla)\n",
"\n",
"* *FP16* stands for TensorFlow's mixed-precision package and is analogous to PyTorch's FP16 feature, see [here](https://www.tensorflow.org/guide/mixed_precision).\n",
"\n",
"***Note***: In ~1,2 weeks it will also be possible to benchmark training in TensorFlow.\n",
"***Note***: Benchmark training in TensorFlow is not included in v3.0.2, but available in master.\n",
"\n",
"\n",
"This notebook will show the user how to use `PyTorchBenchmark` and `TensorFlowBenchmark` for two different scenarios:\n",
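For orientation, here is a minimal sketch of how the training benchmark described in the notebook can be invoked once this commit is in. The argument names mirror the tests added in tests/test_benchmark_tf.py below; the model id is just a stand-in for any checkpoint:

```python
# Hedged sketch, not documented public API: invoke the new TF training
# benchmark the same way the tests below do.
from transformers import TensorFlowBenchmark, TensorFlowBenchmarkArguments

benchmark_args = TensorFlowBenchmarkArguments(
    models=["sshleifer/tiny-gpt2"],  # any model identifier works here
    training=True,                   # enable the new training benchmarks
    no_inference=True,               # benchmark training only
    sequence_lengths=[8],
    batch_sizes=[1],
)
benchmark = TensorFlowBenchmark(benchmark_args)
results = benchmark.run()  # fills time_train_result and memory_train_result
```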
6 changes: 3 additions & 3 deletions src/transformers/benchmark/benchmark.py
@@ -157,7 +157,7 @@ def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length:
        else:
            train_model = model

        model.eval()
        model.train()
        model.to(self.args.device)

        # encoder-decoder has vocab size saved differently
@@ -175,12 +175,12 @@ def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length:
        def compute_loss_and_backprob_encoder():
            loss = train_model(input_ids, labels=input_ids)[0]
            loss.backward()
            train_model.zero_grad()
            return loss

        def compute_loss_and_backprob_encoder_decoder():
            loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
            loss.backward()
            train_model.zero_grad()
            return loss

        _train = (
            compute_loss_and_backprob_encoder_decoder
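A note on the `train_model.zero_grad()` calls added above: PyTorch accumulates gradients across `backward()` calls, so without zeroing, every timed iteration would add new gradients on top of the old ones. A standalone illustration (not part of the diff):

```python
# Why the timed closures call zero_grad() after each backward():
# PyTorch accumulates gradients rather than overwriting them.
import torch

w = torch.ones(1, requires_grad=True)
(w * 2).sum().backward()
print(w.grad)   # tensor([2.])
(w * 2).sum().backward()
print(w.grad)   # tensor([4.]) -- accumulated, not replaced
w.grad.zero_()  # reset, as train_model.zero_grad() does once per step
```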
75 changes: 66 additions & 9 deletions src/transformers/benchmark/benchmark_tf.py
@@ -24,7 +24,13 @@
from functools import wraps
from typing import Callable, Optional, Tuple

from transformers import TF_MODEL_MAPPING, PretrainedConfig, is_py3nvml_available, is_tf_available
from transformers import (
    TF_MODEL_MAPPING,
    TF_MODEL_WITH_LM_HEAD_MAPPING,
    PretrainedConfig,
    is_py3nvml_available,
    is_tf_available,
)

from .benchmark_utils import (
    Benchmark,
@@ -92,10 +98,11 @@ def _inference_speed(self, model_name: str, batch_size: int, sequence_length: in
        _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
        return self._measure_speed(_inference)

    def _train_speed(self, model_name, batch_size, sequence_length):
        raise NotImplementedError(
            "Training is currently not really implemented." "Wait for TFTrainer to support CLM and MLM."
        )

    def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
        strategy = self.args.strategy
        assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
        _train = self._prepare_train_func(model_name, batch_size, sequence_length)
        return self._measure_speed(_train)

    def _inference_memory(
        self, model_name: str, batch_size: int, sequence_length: int
@@ -108,10 +115,16 @@ def _inference_memory(
        _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
        return self._measure_memory(_inference)

    def _train_memory(self, model_name, batch_size, sequence_length):
        raise NotImplementedError(
            "Training is currently not really implemented. Wait for TFTrainer to support CLM and MLM."
        )
    def _train_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> Tuple[Memory, Optional[MemorySummary]]:
        if self.args.is_gpu:
            tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
        strategy = self.args.strategy
        assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."

        _train = self._prepare_train_func(model_name, batch_size, sequence_length)
        return self._measure_memory(_train)
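Both new methods assert that `self.args.strategy` is set. The arguments object is expected to expose a `tf.distribute` strategy; the sketch below is a plausible reconstruction of that selection (the real logic lives in `TensorFlowBenchmarkArguments`, so the helper name and signature here are assumptions):

```python
# Hedged sketch: roughly how benchmark args could pick the tf.distribute
# strategy that the assertions above rely on.
import tensorflow as tf

def select_strategy(is_tpu: bool, device_name: str = "/gpu:0") -> tf.distribute.Strategy:
    if is_tpu:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.experimental.TPUStrategy(resolver)
    # single-device case: CPU or one GPU
    return tf.distribute.OneDeviceStrategy(device=device_name)
```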

    def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
        config = self.config_dict[model_name]
@@ -149,6 +162,50 @@ def encoder_forward():

        return _inference

    def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
        config = self.config_dict[model_name]

        assert (
            self.args.eager_mode is False
        ), "Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`."

        if self.args.fp16:
            raise NotImplementedError("Mixed precision is currently not supported.")

        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 1
        if not self.args.only_pretrain_model and has_model_class_in_config:
            try:
                model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
                transformers_module = __import__("transformers", fromlist=[model_class])
                model_cls = getattr(transformers_module, model_class)
                model = model_cls(config)
            except ImportError:
                raise ImportError(
                    f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
                )
        else:
            model = TF_MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)

        # encoder-decoder has vocab size saved differently
        vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
        input_ids = random_input_ids(batch_size, sequence_length, vocab_size)

        @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
        def encoder_decoder_train():
            loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids, training=True)[0]
            gradients = tf.gradients(loss, model.trainable_variables)
            return gradients

        @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
        def encoder_train():
            loss = model(input_ids, labels=input_ids, training=True)[0]
            gradients = tf.gradients(loss, model.trainable_variables)
            return gradients

        _train = encoder_decoder_train if config.is_encoder_decoder else encoder_train

        return _train
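The `run_with_tf_optimizations` decorator (imported from `benchmark_utils`) is what turns these closures into graph functions, optionally XLA-compiled. A minimal sketch of the pattern, assuming it wraps `tf.function` (the real helper may differ in detail):

```python
# Hedged sketch of a run_with_tf_optimizations-style decorator.
import tensorflow as tf

def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
    def decorator(func):
        if do_eager_mode:
            return func  # leave the closure as plain eager Python
        # trace into a graph; experimental_compile=True additionally uses XLA
        return tf.function(func, experimental_compile=use_xla)
    return decorator
```

This is also why `_prepare_train_func` asserts `args.eager_mode is False`: `tf.gradients` is a graph-mode API and cannot run eagerly.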

    def _measure_speed(self, func) -> float:
        with self.args.strategy.scope():
            try:
31 changes: 31 additions & 0 deletions tests/test_benchmark_tf.py
@@ -100,6 +100,37 @@ def test_inference_with_configs_graph(self):
        self.check_results_dict_not_empty(results.time_inference_result)
        self.check_results_dict_not_empty(results.memory_inference_result)

    def test_train_no_configs(self):
        MODEL_ID = "sshleifer/tiny-gpt2"
        benchmark_args = TensorFlowBenchmarkArguments(
            models=[MODEL_ID],
            training=True,
            no_inference=True,
            sequence_lengths=[8],
            batch_sizes=[1],
            no_multi_process=True,
        )
        benchmark = TensorFlowBenchmark(benchmark_args)
        results = benchmark.run()
        self.check_results_dict_not_empty(results.time_train_result)
        self.check_results_dict_not_empty(results.memory_train_result)

    def test_train_with_configs(self):
        MODEL_ID = "sshleifer/tiny-gpt2"
        config = AutoConfig.from_pretrained(MODEL_ID)
        benchmark_args = TensorFlowBenchmarkArguments(
            models=[MODEL_ID],
            training=True,
            no_inference=True,
            sequence_lengths=[8],
            batch_sizes=[1],
            no_multi_process=True,
        )
        benchmark = TensorFlowBenchmark(benchmark_args, [config])
        results = benchmark.run()
        self.check_results_dict_not_empty(results.time_train_result)
        self.check_results_dict_not_empty(results.memory_train_result)

    def test_inference_encoder_decoder_with_configs(self):
        MODEL_ID = "patrickvonplaten/t5-tiny-random"
        config = AutoConfig.from_pretrained(MODEL_ID)
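To inspect the result objects the new tests assert on: the train results plausibly follow the same nesting as the existing inference results, keyed by model name, batch size, and sequence length (an assumption, not documented API):

```python
# Hedged sketch: reading time_train_result, assuming the nesting
# {model: {"result": {batch_size: {sequence_length: value}}}}.
train_times = results.time_train_result["sshleifer/tiny-gpt2"]["result"]
print(train_times[1][8])  # seconds per train step at batch_size=1, seq_len=8
```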
