Improvements to Flax finetuning script (huggingface#11727)
* Add Cloud details to README

* Flax script and readme updates

* Some simplifications of Flax script
marcvanzee committed May 17, 2021
1 parent 86d5fb0 commit 726e953
Showing 2 changed files with 24 additions and 27 deletions.
20 changes: 10 additions & 10 deletions examples/flax/text-classification/README.md
@@ -59,20 +59,20 @@ On tasks other than MRPC and WNLI we train for 3 epochs because this is
but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models
are undertrained and we could get better results when training longer.

In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1zKL_xn32HwbxkFMxB3ftca-soTHAuBFgIhYhOhCnZ4E/edit?usp=sharing).
In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1wtcjX_fJLjYs6kXkoiej2qGjrl9ByfNhPulPAz71Ky4/edit?usp=sharing).


| Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics |
|-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
| CoLA | Matthew's corr | 59.57 | 58.04 | 1.81 | [tfhub.dev](https://tensorboard.dev/experiment/f4OvQpWtRq6CvddpxGBd0A/) |
| SST-2 | Accuracy | 92.43 | 91.79 | 0.59 | [tfhub.dev](https://tensorboard.dev/experiment/BYFwa49MRTaLIn93DgAEtA/) |
| MRPC | F1/Accuracy | 89.50/84.8 | 88.70/84.02 | 0.56/0.48 | [tfhub.dev](https://tensorboard.dev/experiment/9ZWH5xwXRS6zEEUE4RaBhQ/) |
| STS-B | Pearson/Spearman corr. | 90.00/88.71 | 89.09/88.61 | 0.51/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/mUlI5B9QQ0WGEJip7p3Tng/) |
| QQP | Accuracy/F1 | 90.88/87.64 | 90.75/87.53 | 0.11/0.13 | [tfhub.dev](https://tensorboard.dev/experiment/pO6h75L3SvSXSWRcgljXKA/) |
| MNLI | Matched acc. | 84.06 | 83.88 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/LKwaOH18RMuo7nJkESrpKg/) |
| QNLI | Accuracy | 91.01 | 90.86 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/qesXxNcaQhmKxPmbw1sOoA/) |
| RTE | Accuracy | 66.80 | 65.27 | 1.07 | [tfhub.dev](https://tensorboard.dev/experiment/Z84xC0r6RjyzT4SLqiAbzQ/) |
| WNLI | Accuracy | 39.44 | 32.96 | 5.85 | [tfhub.dev](https://tensorboard.dev/experiment/gV73w9v0RIKrqVw32PZbAQ/) |
| CoLA | Matthew's corr | 59.29 | 56.25 | 2.18 | [tfhub.dev](https://tensorboard.dev/experiment/tNBiYyvsRv69ZlXRI7x0pQ/) |
| SST-2 | Accuracy | 91.97 | 91.79 | 0.42 | [tfhub.dev](https://tensorboard.dev/experiment/wQto9nBwQHOINUxjKAAblQ/) |
| MRPC | F1/Accuracy | 90.39/86.03 | 89.70/85.20 | 0.68/0.91 | [tfhub.dev](https://tensorboard.dev/experiment/Q40mkOtDSYymFRfo4jKsgQ/) |
| STS-B | Pearson/Spearman corr. | 89.19/88.91 | 89.40/89.09 | 0.18/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/a2bfeAy6SveV0X0FjwxMXQ/) |
| QQP | Accuracy/F1 | 91.02/87.90 | 90.96/87.75 | 0.08/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/kL2vGgoQQeyTVGetehbCpg/) |
| MNLI | Matched acc. | 83.82 | 83.65 | 0.28 | [tfhub.dev](https://tensorboard.dev/experiment/nck6178dTpmTOPm7862urA/) |
| QNLI | Accuracy | 90.81 | 90.88 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/44slZTLKQtqGhWs1Rhedcg/) |
| RTE | Accuracy | 69.31 | 66.79 | 1.88 | [tfhub.dev](https://tensorboard.dev/experiment/g0yvpEXKSAytDMvP8TP8Og/) |
| WNLI | Accuracy | 56.34 | 36.62 | 12.48 | [tfhub.dev](https://tensorboard.dev/experiment/7DfXdlDnTWWKBEx4pXForA/) |

Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
31 changes: 14 additions & 17 deletions examples/flax/text-classification/run_flax_glue.py
@@ -123,7 +123,7 @@ def parse_args():
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=2, help="A seed for reproducible training.")
parser.add_argument("--seed", type=int, default=5, help="A seed for reproducible training.")
args = parser.parse_args()

# Sanity checks
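
For context, the `--seed` argument is what the script turns into JAX's functional PRNG state, which is why re-running with the same seed reproduces a run. A minimal sketch of that idea (variable names here are illustrative, not copied from the script):

```python
import jax

seed = 5  # e.g. the value passed via --seed
rng = jax.random.PRNGKey(seed)  # root key derived from the seed
# Independent sub-keys for data shuffling and dropout; same seed -> same keys.
rng, input_rng, dropout_rng = jax.random.split(rng, 3)
```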
@@ -148,6 +148,7 @@ def create_train_state(
learning_rate_fn: Callable[[int], float],
is_regression: bool,
num_labels: int,
weight_decay: float,
) -> train_state.TrainState:
"""Create initial training state."""

@@ -166,8 +167,8 @@ class TrainState(train_state.TrainState):
loss_fn: Callable = struct.field(pytree_node=False)

# Creates a multi-optimizer consisting of two "Adam with weight decay" optimizers.
def adamw(weight_decay):
return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay)
def adamw(decay):
return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=decay)

def traverse(fn):
def mask(data):
@@ -183,7 +184,7 @@ def mask(data):

tx = optax.chain(
optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))),
optax.masked(adamw(0.01), mask=traverse(lambda path, _: not decay_path(path))),
optax.masked(adamw(weight_decay), mask=traverse(lambda path, _: not decay_path(path))),
)

if is_regression:
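
The `optax.chain`/`optax.masked` pair above builds a single optimizer that applies AdamW with weight decay to most parameters and a zero-decay AdamW to the excluded ones. A standalone sketch of the same pattern, assuming an illustrative `exclude_from_decay` rule in place of the script's `decay_path` helper and plain nested-dict params:

```python
import optax
from flax import traverse_util

def exclude_from_decay(path: str) -> bool:
    # Illustrative rule: no decay for biases and LayerNorm parameters.
    return path.endswith("bias") or "layernorm" in path.lower()

def make_tx(learning_rate_fn, weight_decay: float):
    def adamw(decay):
        return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999,
                           eps=1e-6, weight_decay=decay)

    def mask(predicate):
        # optax.masked accepts a callable mapping params to a pytree of bools.
        def to_mask(params):
            flat = traverse_util.flatten_dict(params)
            return traverse_util.unflatten_dict(
                {k: predicate("/".join(k)) for k in flat}
            )
        return to_mask

    return optax.chain(
        optax.masked(adamw(0.0), mask=mask(exclude_from_decay)),
        optax.masked(adamw(weight_decay), mask=mask(lambda p: not exclude_from_decay(p))),
    )
```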
@@ -414,7 +415,9 @@ def write_metric(train_metrics, eval_metrics, train_time, step):
len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
)

state = create_train_state(model, learning_rate_fn, is_regression, num_labels=num_labels)
state = create_train_state(
model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
)

# define step functions
def train_step(
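
For reference, `learning_rate_fn` here is a step-indexed schedule built from the dataset size, batch size, epoch count, and warmup steps. A hedged sketch of one common way to build such a schedule with optax (linear warmup followed by linear decay; the script's exact shape may differ):

```python
import optax

def make_learning_rate_fn(num_train_samples: int, batch_size: int, num_epochs: int,
                          num_warmup_steps: int, learning_rate: float):
    steps_per_epoch = num_train_samples // batch_size
    total_steps = steps_per_epoch * num_epochs
    warmup = optax.linear_schedule(init_value=0.0, end_value=learning_rate,
                                   transition_steps=num_warmup_steps)
    decay = optax.linear_schedule(init_value=learning_rate, end_value=0.0,
                                  transition_steps=total_steps - num_warmup_steps)
    # A single callable from step -> learning rate, switching at the warmup boundary.
    return optax.join_schedules([warmup, decay], boundaries=[num_warmup_steps])
```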
@@ -426,10 +429,10 @@ def train_step(
def loss_fn(params):
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = state.loss_fn(logits, targets)
return loss, logits
return loss

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grad = grad_fn(state.params)
grad_fn = jax.value_and_grad(loss_fn)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
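
With the auxiliary logits removed, the step reduces to the standard JAX pattern: differentiate a scalar loss, average gradients across devices with `pmean`, and apply the update. A self-contained toy sketch of that pattern (the linear model and names are illustrative stand-ins for the Transformer and `TrainState`):

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3  # illustrative constant

def toy_train_step(params, batch):
    def loss_fn(p):
        preds = batch["inputs"] @ p["w"] + p["b"]
        return jnp.mean((preds - batch["targets"]) ** 2)

    # Scalar loss in, (loss, gradients) out -- no has_aux needed.
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Average gradients over the device axis so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="batch")
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    return new_params, metrics

p_toy_train_step = jax.pmap(toy_train_step, axis_name="batch")
```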
@@ -460,18 +463,18 @@ def eval_step(state, batch):

train_start = time.time()
train_metrics = []
rng, input_rng, dropout_rng = jax.random.split(rng, 3)
rng, input_rng = jax.random.split(rng)

# train
for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
rng, dropout_rng = jax.random.split(rng)
dropout_rngs = shard_prng_key(dropout_rng)
state, metrics = p_train_step(state, batch, dropout_rngs)
train_metrics.append(metrics)
train_time += time.time() - train_start
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")

logger.info(" Evaluating...")
rng, input_rng = jax.random.split(rng)

# evaluate
for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
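
The training loop above splits a fresh dropout key for every batch and fans it out per device with `shard_prng_key`. A self-contained sketch of that key handling, using a toy dropout function in place of `p_train_step`:

```python
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard, shard_prng_key

n_dev = jax.local_device_count()
batches = [jnp.ones((n_dev * 2, 4)) for _ in range(3)]  # toy stand-in for the data iterator

@jax.pmap
def p_apply_dropout(x, dropout_rng):
    keep = jax.random.bernoulli(dropout_rng, p=0.9, shape=x.shape)
    return x * keep

rng = jax.random.PRNGKey(0)  # illustrative seed
for batch in batches:
    rng, dropout_rng = jax.random.split(rng)    # fresh key per batch
    dropout_rngs = shard_prng_key(dropout_rng)  # distinct key per device
    out = p_apply_dropout(shard(batch), dropout_rngs)
```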
@@ -484,20 +487,14 @@ def eval_step(state, batch):

# make sure leftover batch is evaluated on one device
if num_leftover_samples > 0 and jax.process_index() == 0:
# put weights on single device
state = unreplicate(state)

# take leftover samples
batch = eval_dataset[-num_leftover_samples:]
batch = {k: jnp.array(v) for k, v in batch.items()}

labels = batch.pop("labels")
predictions = eval_step(state, batch)
predictions = eval_step(unreplicate(state), batch)
metric.add_batch(predictions=predictions, references=labels)

# make sure weights are replicated on each device
state = replicate(state)

eval_metric = metric.compute()
logger.info(f" Done! Eval metrics: {eval_metric}")
