[Flax] Align GLUE training script with mlm training script (huggingface#11778)

* speed up flax glue

* remove unnecessary line

* remove folder

* remove run in loop

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
patrickvonplaten and Patrick von Platen committed May 21, 2021
1 parent 2239438 commit bd98716
Showing 2 changed files with 27 additions and 28 deletions.
45 changes: 22 additions & 23 deletions examples/flax/text-classification/README.md
@@ -59,20 +59,19 @@ On the tasks other than MRPC and WNLI we train for 3 epochs, because this is the standard,
but looking at the training curves of some of them (e.g., SST-2, STS-B), it appears the models
are undertrained and we could get better results when training longer.

-In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1wtcjX_fJLjYs6kXkoiej2qGjrl9ByfNhPulPAz71Ky4/edit?usp=sharing).
+In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1p3XzReMO75m_XdEJvPue-PIq_PN-96J2IJpJW1yS-10/edit?usp=sharing).
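
For instance, the five runs behind the "avg/5 runs" column can be reproduced with a seed sweep. A minimal sketch, with every flag other than `--seed` deliberately elided (copy them from the full command above):

```bash
# Hypothetical seed sweep; all other flags are elided here on purpose and
# must be taken from the full run_flax_glue.py command shown above.
for seed in 1 2 3 4 5; do
  python run_flax_glue.py --seed ${seed}  # plus the remaining flags
done
```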

| Task  | Metric                       | Score (best run) | Score (avg/5 runs) | Stdev     | Tensorboard logs                                                                |
|-------|------------------------------|------------------|--------------------|-----------|---------------------------------------------------------------------------------|
-| CoLA  | Matthew's corr               | 59.29            | 56.25              | 2.18      | [tensorboard.dev](https://tensorboard.dev/experiment/tNBiYyvsRv69ZlXRI7x0pQ/) |
-| SST-2 | Accuracy                     | 91.97            | 91.79              | 0.42      | [tensorboard.dev](https://tensorboard.dev/experiment/wQto9nBwQHOINUxjKAAblQ/) |
-| MRPC  | F1/Accuracy                  | 90.39/86.03      | 89.70/85.20        | 0.68/0.91 | [tensorboard.dev](https://tensorboard.dev/experiment/Q40mkOtDSYymFRfo4jKsgQ/) |
-| STS-B | Pearson/Spearman corr.       | 89.19/88.91      | 89.40/89.09        | 0.18/0.14 | [tensorboard.dev](https://tensorboard.dev/experiment/a2bfeAy6SveV0X0FjwxMXQ/) |
-| QQP   | Accuracy/F1                  | 91.02/87.90      | 90.96/87.75        | 0.08/0.14 | [tensorboard.dev](https://tensorboard.dev/experiment/kL2vGgoQQeyTVGetehbCpg/) |
-| MNLI  | Matched acc.                 | 83.82            | 83.65              | 0.28      | [tensorboard.dev](https://tensorboard.dev/experiment/nck6178dTpmTOPm7862urA/) |
-| QNLI  | Accuracy                     | 90.81            | 90.88              | 0.18      | [tensorboard.dev](https://tensorboard.dev/experiment/44slZTLKQtqGhWs1Rhedcg/) |
-| RTE   | Accuracy                     | 69.31            | 66.79              | 1.88      | [tensorboard.dev](https://tensorboard.dev/experiment/g0yvpEXKSAytDMvP8TP8Og/) |
-| WNLI  | Accuracy                     | 56.34            | 36.62              | 12.48     | [tensorboard.dev](https://tensorboard.dev/experiment/7DfXdlDnTWWKBEx4pXForA/) |
+| CoLA  | Matthew's corr               | 60.82            | 59.04              | 1.17      | [tensorboard.dev](https://tensorboard.dev/experiment/U2ncNFP3RpWW6YnA9PYJBA/) |
+| SST-2 | Accuracy                     | 92.43            | 92.13              | 0.38      | [tensorboard.dev](https://tensorboard.dev/experiment/vzxoOHZURcm0rO1I33x7uA/) |
+| MRPC  | F1/Accuracy                  | 89.90/88.98      | 88.98/85.30        | 0.73/2.33 | [tensorboard.dev](https://tensorboard.dev/experiment/EWPBIbfYSDGHjiYxrw2a2Q/) |
+| STS-B | Pearson/Spearman corr.       | 89.04/88.70      | 88.94/88.63        | 0.07/0.07 | [tensorboard.dev](https://tensorboard.dev/experiment/3aYHKL10TeiaZYwH1M8ogA/) |
+| QQP   | Accuracy/F1                  | 90.82/87.54      | 90.75/87.53        | 0.06/0.02 | [tensorboard.dev](https://tensorboard.dev/experiment/VfVDLS4AQnqr4NMbng6yUw/) |
+| MNLI  | Matched acc.                 | 84.10            | 83.84              | 0.16      | [tensorboard.dev](https://tensorboard.dev/experiment/Sz9UdhoORaaSjzuOHRB4Jw/) |
+| QNLI  | Accuracy                     | 91.07            | 90.83              | 0.19      | [tensorboard.dev](https://tensorboard.dev/experiment/zk6udb5MQAyAQ4eczrFBaQ/) |
+| RTE   | Accuracy                     | 66.06            | 64.76              | 1.04      | [tensorboard.dev](https://tensorboard.dev/experiment/BwxaUoAEQ5aa3oQilEjADw/) |
+| WNLI  | Accuracy                     | 46.48            | 37.01              | 6.83      | [tensorboard.dev](https://tensorboard.dev/experiment/b2Y8ouwMTRC8iBWzRzVYTA/) |

Some of these results are significantly different from the ones reported on the test set of the GLUE benchmark website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
@@ -85,18 +84,18 @@ overall training time below. For comparison we ran PyTorch's [run_glue.py](https

| Task  | TPU v3-8  | 8 GPU      | [1 GPU](https://tensorboard.dev/experiment/mkPS4Zh8TnGe1HB6Yzwj4Q) | 1 GPU (PyTorch) |
|-------|-----------|------------|------------|-----------------|
-| CoLA  | 1m 46s    | 1m 26s     | 3m 9s      | 4m 6s           |
-| SST-2 | 5m 30s    | 6m 28s     | 22m 33s    | 34m 37s         |
-| MRPC  | 1m 32s    | 1m 14s     | 2m 20s     | 2m 56s          |
-| STS-B | 1m 33s    | 1m 12s     | 2m 16s     | 2m 48s          |
-| QQP   | 24m 40s   | 31m 48s    | 1h 59m 41s | 2h 54m          |
-| MNLI  | 26m 30s   | 33m 55s    | 2h 9m 37s  | 3h 7m 6s        |
-| QNLI  | 8m        | 9m 40s     | 34m 40s    | 49m 8s          |
-| RTE   | 1m 21s    | 55s        | 1m 10s     | 1m 16s          |
-| WNLI  | 1m 12s    | 48s        | 39s        | 36s             |
+| CoLA  | 1m 42s    | 1m 26s     | 3m 9s      | 4m 6s           |
+| SST-2 | 5m 12s    | 6m 28s     | 22m 33s    | 34m 37s         |
+| MRPC  | 1m 29s    | 1m 14s     | 2m 20s     | 2m 56s          |
+| STS-B | 1m 30s    | 1m 12s     | 2m 16s     | 2m 48s          |
+| QQP   | 22m 50s   | 31m 48s    | 1h 59m 41s | 2h 54m          |
+| MNLI  | 25m 03s   | 33m 55s    | 2h 9m 37s  | 3h 7m 6s        |
+| QNLI  | 7m 30s    | 9m 40s     | 34m 40s    | 49m 8s          |
+| RTE   | 1m 20s    | 55s        | 1m 10s     | 1m 16s          |
+| WNLI  | 1m 11s    | 48s        | 39s        | 36s             |
|-------|
-| **TOTAL** | 1h 13m | 1h 28m | 5h 16m | 6h 37m |
-| **COST*** | $9.60  | $29.10 | $13.06 | $16.41 |
+| **TOTAL** | 1h 03m | 1h 28m | 5h 16m | 6h 37m |
+| **COST*** | $8.56  | $29.10 | $13.06 | $16.41 |


*All experiments were run on Google Cloud Platform. Prices are on-demand prices
@@ -106,4 +105,4 @@ the following tables:
[GPU pricing table](https://cloud.google.com/compute/gpus-pricing) ($2.48/h per
V100 GPU). GPU experiments were run without further optimizations besides JAX
transformations. GPU experiments were run with full precision (fp32). "TPU v3-8"
are 8 TPU cores on 4 chips (each chip has 2 cores), while "8 GPU" are 8 GPU chips.
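
As a quick check on the cost arithmetic: the 8-GPU runs total 1h 28m at 8 × $2.48/h ≈ $19.84/h, i.e. about $29.10, and the single-GPU runs total 5h 16m at $2.48/h ≈ $13.06, both matching the table above; the TPU v3-8 total of just over an hour at $8.00/h similarly lands near the quoted $8.56.
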
10 changes: 5 additions & 5 deletions examples/flax/text-classification/run_flax_glue.py
@@ -34,7 +34,7 @@
from flax.jax_utils import replicate, unreplicate
from flax.metrics import tensorboard
from flax.training import train_state
-from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+from flax.training.common_utils import get_metrics, onehot, shard
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig


@@ -407,6 +407,7 @@ def write_metric(train_metrics, eval_metrics, train_time, step):

num_epochs = int(args.num_train_epochs)
rng = jax.random.PRNGKey(args.seed)
+dropout_rngs = jax.random.split(rng, jax.local_device_count())

train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()
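
(For example, on a TPU v3-8 `jax.local_device_count()` is 8, so a per-device train batch size of 4 yields a global train batch size of 4 × 8 = 32.)
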
@@ -424,6 +425,7 @@ def train_step(
state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
) -> Tuple[train_state.TrainState, float]:
"""Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
+dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
targets = batch.pop("labels")

def loss_fn(params):
@@ -436,7 +438,7 @@ def loss_fn(params):
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
-return new_state, metrics
+return new_state, metrics, new_dropout_rng

p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
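
(`donate_argnums=(0,)` marks the first argument, the replicated train state, as donated: XLA may reuse its device buffers for the step's outputs rather than allocating fresh ones each step, which lowers peak device memory.)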

@@ -467,9 +469,7 @@ def eval_step(state, batch):

# train
for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
-rng, dropout_rng = jax.random.split(rng)
-dropout_rngs = shard_prng_key(dropout_rng)
-state, metrics = p_train_step(state, batch, dropout_rngs)
+state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
train_metrics.append(metrics)
train_time += time.time() - train_start
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
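
The net effect of the change above: the per-device dropout keys are created once, split *inside* the pmapped train step, and threaded back out, so the host no longer has to split and re-shard a key on every batch. Below is a minimal, self-contained sketch of this pattern (a hypothetical scalar `params` and a dummy loss stand in for the script's real model and gradient update):

```python
import jax
import jax.numpy as jnp

def train_step(params, batch, dropout_rng):
    # Split on-device: consume one key this step and return the other,
    # so the caller can feed it straight back in on the next step.
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    loss = jnp.mean(batch)  # placeholder for a real dropout-using forward pass
    return params, loss, new_dropout_rng

# donate_argnums=(0,) mirrors the script: XLA may reuse the state's buffers.
p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

n_dev = jax.local_device_count()
dropout_rngs = jax.random.split(jax.random.PRNGKey(0), n_dev)  # once, up front
params = jax.device_put_replicated(jnp.zeros(()), jax.local_devices())
batch = jnp.ones((n_dev, 8))  # one shard per device

for _ in range(3):
    # dropout_rngs comes back refreshed from the devices each step.
    params, loss, dropout_rngs = p_train_step(params, batch, dropout_rngs)
```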
