Merge branch 'master' into patch-5
AkshitaB authored Aug 7, 2020
2 parents 39c5e4f + 660fdaf commit f94353d
Showing 6 changed files with 72 additions and 2 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed

- Fixed how truncation was handled with `PretrainedTransformerTokenizer`.
Previously, if `max_length` was set to `None`, the tokenizer would still do truncation if the
transformer model had a default max length in its config.
Also, when `max_length` was set to a non-`None` value, several warnings would appear
for certain transformer models around the use of the `truncation` parameter.

## [v1.1.0rc2](https://github.com/allenai/allennlp/releases/tag/v1.1.0rc2) - 2020-07-31

### Changed
@@ -26,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added the option to specify `requires_grad: false` within an optimizer's parameter groups.
- Added the `file-friendly-logging` flag back to the `train` command. Also added this flag to the `predict`, `evaluate`, and `find-learning-rate` commands.
- Added a `TrackEpochCallback` (an `EpochCallback`) that tracks the current epoch as a model class member.

### Removed

2 changes: 1 addition & 1 deletion allennlp/data/tokenizers/pretrained_transformer_tokenizer.py
@@ -236,7 +236,7 @@ def tokenize(self, text: str) -> List[Token]:
add_special_tokens=False,
max_length=self._max_length,
stride=self._stride,
truncation_strategy=self._truncation_strategy,
truncation=self._truncation_strategy if self._max_length is not None else False,
return_tensors=None,
return_offsets_mapping=self.tokenizer.is_fast,
return_attention_mask=False,
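
The single-argument change above is the heart of the fix: truncation is only requested when the caller actually set a `max_length`. As a hedged illustration (not AllenNLP's code), the sketch below applies the same pattern directly to the Hugging Face `transformers` tokenizer API; the helper function and the default strategy name are assumptions for the example.

```python
# A minimal sketch of the conditional-truncation pattern, calling the
# Hugging Face `transformers` API directly; not AllenNLP's implementation.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def encode(text, max_length=None, truncation_strategy="longest_first"):
    return tokenizer.encode_plus(
        text,
        add_special_tokens=False,
        max_length=max_length,
        # Only ask for truncation when a max_length was actually given;
        # otherwise pass `truncation=False` so the model's default limit
        # (512 for BERT) is not silently applied and no truncation-related
        # warnings are emitted.
        truncation=truncation_strategy if max_length is not None else False,
    )


assert len(encode(" ".join(["a"] * 550))["input_ids"]) == 550  # no truncation
assert len(encode("hi there " * 20, max_length=10)["input_ids"]) == 10  # truncated
```
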
8 changes: 7 additions & 1 deletion allennlp/training/__init__.py
@@ -1,4 +1,10 @@
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.tensorboard_writer import TensorboardWriter
from allennlp.training.no_op_trainer import NoOpTrainer
from allennlp.training.trainer import Trainer, GradientDescentTrainer, BatchCallback, EpochCallback
from allennlp.training.trainer import (
Trainer,
GradientDescentTrainer,
BatchCallback,
EpochCallback,
TrackEpochCallback,
)
23 changes: 23 additions & 0 deletions allennlp/training/trainer.py
@@ -176,6 +176,29 @@ def __call__(
EpochCallback.register("null")(EpochCallback)


@EpochCallback.register("track_epoch_callback")
class TrackEpochCallback:
"""
    A callback that you can pass to the `GradientDescentTrainer` to access the current epoch number
    in your model during training. This callback sets `model.epoch`, which can be read inside
    `model.forward()`. Since epoch callbacks are called with `epoch=-1` at the start of training,
    we set `model.epoch = epoch + 1`, so the value always equals the number of completed epochs
    at the current point in training.
"""

def __init__(self):
super().__init__()

def __call__(
self,
trainer: "GradientDescentTrainer",
metrics: Dict[str, Any],
epoch: int,
is_master: bool,
) -> None:
trainer.model.epoch = epoch + 1


@Trainer.register("gradient_descent", constructor="from_partial_objects")
class GradientDescentTrainer(Trainer):
"""
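
For context on how the new attribute is typically consumed, here is a hypothetical sketch (not part of this commit): a plain PyTorch module whose `forward()` reads the `epoch` value that `TrackEpochCallback` writes before each epoch. The module name and the annealing use case are illustrative assumptions.

```python
import torch


class EpochAwareModule(torch.nn.Module):
    """Illustrative only: any model whose forward() reads `self.epoch`."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)
        # TrackEpochCallback overwrites this with `epoch + 1`, i.e. the
        # number of completed epochs (0 at the start of training).
        self.epoch = 0

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Example use: decay an auxiliary scaling factor as training progresses.
        scale = 1.0 / (1.0 + float(self.epoch))
        return self.linear(features) * scale
```
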
18 changes: 18 additions & 0 deletions tests/data/tokenizers/pretrained_transformer_tokenizer_test.py
@@ -139,6 +139,24 @@ def test_token_idx_bert_cased(self):
idxs = [t.idx for t in tokenized]
assert idxs == expected_idxs

def test_max_length(self):
tokenizer = PretrainedTransformerTokenizer(
"bert-base-cased", max_length=10, add_special_tokens=False
)
tokens = tokenizer.tokenize(
"hi there, this should be at least 10 tokens, but some will be truncated"
)
assert len(tokens) == 10

def test_no_max_length(self):
tokenizer = PretrainedTransformerTokenizer(
"bert-base-cased", max_length=None, add_special_tokens=False
)
# Even though the bert model has a max input length of 512, when we tokenize
# with `max_length = None`, we should not get any truncation.
tokens = tokenizer.tokenize(" ".join(["a"] * 550))
assert len(tokens) == 550

def test_token_idx_roberta(self):
sentence = "A, naïve <mask> AllenNLP sentence."
expected_tokens = [
14 changes: 14 additions & 0 deletions tests/training/trainer_test.py
@@ -28,6 +28,7 @@
TensorboardWriter,
BatchCallback,
EpochCallback,
TrackEpochCallback,
)
from allennlp.training.learning_rate_schedulers import CosineWithRestarts
from allennlp.training.learning_rate_schedulers import ExponentialLearningRateScheduler
@@ -986,6 +987,19 @@ def __call__(
expected_calls = [epoch for epoch in range(-1, 4)]
assert trainer.epoch_callback_calls == expected_calls

def test_track_epoch_callback(self):
num_epochs = 4
trainer = GradientDescentTrainer(
self.model,
self.optimizer,
self.data_loader,
num_epochs=num_epochs,
validation_data_loader=self.validation_data_loader,
epoch_callbacks=[TrackEpochCallback()],
)
trainer.train()
assert trainer.model.epoch == num_epochs

def test_total_loss_is_average_of_batch_loss(self):

batches_per_epoch = 3
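
Finally, because the callback is registered as `"track_epoch_callback"`, it should also be attachable declaratively. The fragment below is a hypothetical config sketch, written as a Python dict rather than jsonnet, and it assumes the trainer's `from_partial_objects` constructor accepts an `epoch_callbacks` key, as the programmatic test above suggests; it is not an excerpt from this repository.

```python
# Hypothetical trainer-config fragment using the registered callback name.
trainer_config = {
    "type": "gradient_descent",
    "num_epochs": 4,
    "epoch_callbacks": [{"type": "track_epoch_callback"}],
}
```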
