
Clarify the ELMo readme #1167

Merged · 6 commits · May 3, 2018
50 changes: 18 additions & 32 deletions allennlp/commands/elmo.py
@@ -51,12 +51,9 @@

from allennlp.common.tqdm import Tqdm
from allennlp.common.util import lazy_groups_of
from allennlp.data.dataset import Batch
from allennlp.data import Token, Vocabulary, Instance
from allennlp.data.fields import TextField
from allennlp.data.token_indexers.elmo_indexer import ELMoTokenCharactersIndexer
from allennlp.nn.util import remove_sentence_boundaries
from allennlp.modules.elmo import _ElmoBiLm
from allennlp.modules.elmo import _ElmoBiLm, batch_to_ids
from allennlp.commands.subcommand import Subcommand

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -67,6 +64,13 @@


class Elmo(Subcommand):
"""
Note that ELMo maintains an internal state dependent on previous batches.
As a result, ELMo will return differing results if the same sentence is
passed to the same ``Elmo`` instance multiple times.

See https://github.com/allenai/allennlp/blob/master/tutorials/how_to/elmo.md for more details.
"""
def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
# pylint: disable=protected-access
description = '''Create word vectors using ELMo.'''
@@ -127,33 +131,6 @@ def __init__(self,

self.cuda_device = cuda_device

def batch_to_ids(self, batch: List[List[str]]) -> torch.Tensor:
"""
Converts a batch of tokenized sentences to a tensor representing the sentences with encoded characters
(len(batch), max sentence length, max word length).

Parameters
----------
batch : ``List[List[str]]``, required
A list of tokenized sentences.

Returns
-------
A tensor of padded character ids.
"""
instances = []
for sentence in batch:
tokens = [Token(token) for token in sentence]
field = TextField(tokens,
{'character_ids': self.indexer})
instance = Instance({"elmo": field})
instances.append(instance)

dataset = Batch(instances)
vocab = Vocabulary()
dataset.index_instances(vocab)
return dataset.as_tensor_dict()['elmo']['character_ids']

def batch_to_embeddings(self, batch: List[List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Parameters
@@ -166,7 +143,7 @@ def batch_to_embeddings(self, batch: List[List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
A tuple of tensors, the first representing activations (batch_size, 3, num_timesteps, 1024) and
the second a mask (batch_size, num_timesteps).
"""
character_ids = self.batch_to_ids(batch)
character_ids = batch_to_ids(batch)
if self.cuda_device >= 0:
character_ids = character_ids.cuda(device=self.cuda_device)

@@ -188,6 +165,9 @@ def embed_sentence(self, sentence: List[str]) -> numpy.ndarray:
"""
Computes the ELMo embeddings for a single tokenized sentence.

Please note that ELMo has internal state and will give different results for the same input.
See the comment under the class definition.

Parameters
----------
sentence : ``List[str]``, required
@@ -204,6 +184,9 @@ def embed_batch(self, batch: List[List[str]]) -> List[numpy.ndarray]:
"""
Computes the ELMo embeddings for a batch of tokenized sentences.

Please note that ELMo has internal state and will give different results for the same input.
See the comment under the class definition.

Parameters
----------
batch : ``List[List[str]]``, required
@@ -237,6 +220,9 @@ def embed_sentences(self,
"""
Computes the ELMo embeddings for an iterable of sentences.

Please note that ELMo has internal state and will give different results for the same input.
See the comment under the class definition.

Parameters
----------
sentences : ``Iterable[List[str]]``, required
35 changes: 34 additions & 1 deletion allennlp/modules/elmo.py
@@ -17,7 +17,11 @@
from allennlp.modules.highway import Highway
from allennlp.modules.scalar_mix import ScalarMix
from allennlp.nn.util import remove_sentence_boundaries, add_sentence_boundary_token_ids
from allennlp.data.token_indexers.elmo_indexer import ELMoCharacterMapper
from allennlp.data.token_indexers.elmo_indexer import ELMoCharacterMapper, ELMoTokenCharactersIndexer
from allennlp.data.dataset import Batch
from allennlp.data import Token, Vocabulary, Instance
from allennlp.data.fields import TextField


logger = logging.getLogger(__name__) # pylint: disable=invalid-name

@@ -160,6 +164,35 @@ def from_params(cls, params: Params) -> 'Elmo':
requires_grad=requires_grad, do_layer_norm=do_layer_norm)


def batch_to_ids(batch: List[List[str]]) -> torch.Tensor:
"""
Converts a batch of tokenized sentences to a tensor representing the sentences with encoded characters
(len(batch), max sentence length, max word length).

Parameters
----------
batch : ``List[List[str]]``, required
A list of tokenized sentences.

Returns
-------
A tensor of padded character ids.
"""
instances = []
indexer = ELMoTokenCharactersIndexer()
for sentence in batch:
tokens = [Token(token) for token in sentence]
field = TextField(tokens,
{'character_ids': indexer})
instance = Instance({"elmo": field})
instances.append(instance)

dataset = Batch(instances)
vocab = Vocabulary()
dataset.index_instances(vocab)
return dataset.as_tensor_dict()['elmo']['character_ids']


class _ElmoCharacterEncoder(torch.nn.Module):
"""
Compute context sensitive token representation using pretrained biLM.
54 changes: 36 additions & 18 deletions tutorials/how_to/elmo.md
@@ -25,32 +25,39 @@ For more details, see `allennlp elmo -h`.

## Using ELMo programmatically

If you need to include ELMo at multiple layers in a task model or you have other advanced use cases, you will need to create ELMo vectors
programmatically. This is easily done with the ElmoEmbedder class [(API doc)](https://github.com/allenai/allennlp/tree/master/allennlp/commands/elmo.py).
If you need to include ELMo at multiple layers in a task model or you have other advanced use cases, you will need to create ELMo vectors programmatically.
This is easily done with the `Elmo` class [(API doc)](https://github.com/allenai/allennlp/blob/master/allennlp/modules/elmo.py#L27), which provides a mechanism to compute the weighted ELMo representations (Equation (1) in the paper).

This is a `torch.nn.Module` subclass that computes any number of ELMo
representations and introduces trainable scalar weights for each.
For example, this code snippet computes two layers of representations
(as in the SNLI and SQuAD models from our paper):

```python
from allennlp.commands.elmo import ElmoEmbedder
from allennlp.modules.elmo import Elmo, batch_to_ids

ee = ElmoEmbedder()
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

embeddings = ee.embed_sentence("Bitcoin alone has a sixty percent share of global search .".split())
elmo = Elmo(options_file, weight_file, 2, dropout=0)

# embeddings has shape (3, 11, 1024)
# 3 - the number of ELMo vectors.
# 11 - the number of words in the input sentence
# use batch_to_ids to convert sentences to character ids
sentences = [['First', 'sentence', '.'], ['Another', '.']]
character_ids = batch_to_ids(sentences)

embeddings = elmo(character_ids)

# embeddings['elmo_representations'] is a length-two list of tensors.
# Each element contains one layer of ELMo representations with shape
# (2, 3, 1024).
# 2 - the batch size
# 3 - the sequence length of the batch
# 1024 - the length of each ELMo vector
```
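
The dictionary returned by `Elmo` also carries a `mask` entry alongside `elmo_representations`. As a rough sketch of how the module might sit inside a downstream network, one layer of ELMo output can be masked, mean-pooled, and projected; the classifier below is a hypothetical example for illustration, not an allennlp model:

```python
import torch
from allennlp.modules.elmo import Elmo, batch_to_ids

class ElmoMeanPoolClassifier(torch.nn.Module):
    """Hypothetical downstream model: mean-pool one ELMo layer and classify."""

    def __init__(self, options_file: str, weight_file: str, num_classes: int) -> None:
        super().__init__()
        # One output representation is enough here; dropout is applied inside Elmo.
        self.elmo = Elmo(options_file, weight_file, num_output_representations=1, dropout=0.5)
        self.projection = torch.nn.Linear(1024, num_classes)

    def forward(self, character_ids: torch.Tensor) -> torch.Tensor:
        elmo_output = self.elmo(character_ids)
        representations = elmo_output['elmo_representations'][0]  # (batch, timesteps, 1024)
        mask = elmo_output['mask'].float().unsqueeze(-1)          # (batch, timesteps, 1)
        pooled = (representations * mask).sum(dim=1) / mask.sum(dim=1)  # masked mean pooling
        return self.projection(pooled)

# model = ElmoMeanPoolClassifier(options_file, weight_file, num_classes=2)
# logits = model(batch_to_ids(sentences))
```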

For larger datasets, batching the sentences by using the `batch_to_embeddings` method
will speed up the computation significantly.
If you are not training a pytorch model and just want numpy arrays as output,
then use `allennlp.commands.elmo.ElmoEmbedder`.

Also note that `ElmoEmbedder` is a utility class that bundles together several
tasks related to computing ELMo representations including mapping strings to character ids and
running the pre-trained biLM. It is not designed to be used when training a model and
is not a subclass of `torch.nn.Module`. To train a model with ELMo, we recommend using
the `allennlp.modules.elmo.Elmo` class, which does subclass `torch.nn.Module` and implements
`forward`.
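
For example, a minimal sketch along those lines, assuming the default pre-trained options and weights that `ElmoEmbedder` uses when constructed without arguments:

```python
from allennlp.commands.elmo import ElmoEmbedder

ee = ElmoEmbedder()  # uses the default pre-trained options and weights

# embed_sentence takes a tokenized sentence and returns a numpy array of
# shape (3, num_tokens, 1024): one vector per token for each biLM layer.
vectors = ee.embed_sentence("Bitcoin alone has a sixty percent share of global search .".split())
print(vectors.shape)  # (3, 11, 1024)
```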

## Using ELMo with existing `allennlp` models

@@ -126,5 +133,16 @@ general guidelines for an initial training run.
* Add some dropout (0.5 is a good default value), either in the `Elmo` class directly, or in the next layer of your network. If the next layer of the network includes dropout then set `dropout=0` when constructing the `Elmo` class.
* Add a small amount of L2 regularization to the scalar weighting parameters (`lambda=0.001` in the paper). These are the parameters named `scalar_mix_L.scalar_parameters.X` where `X=[0, 1, 2]` indexes the biLM layer and `L` indexes the number of ELMo representations included in the downstream model. Often performance is slightly higher for larger datasets without regularizing these parameters, but it can sometimes cause training to be unstable.
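
As a rough sketch of how the dropout and L2 suggestions above might be wired up with a plain PyTorch optimizer (the parameter-name filter follows the `scalar_mix_L.scalar_parameters.X` naming noted above; the choice of optimizer and learning rate is an illustrative assumption):

```python
import torch
from allennlp.modules.elmo import Elmo

options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

# Dropout inside the Elmo module; use dropout=0 instead if the next layer
# of your network already applies dropout.
elmo = Elmo(options_file, weight_file, num_output_representations=2, dropout=0.5)

# Apply L2 regularization (weight_decay) only to the scalar weighting
# parameters; leave the remaining trainable parameters unregularized.
scalar_mix_params = [param for name, param in elmo.named_parameters()
                     if 'scalar_parameters' in name]
other_params = [param for name, param in elmo.named_parameters()
                if 'scalar_parameters' not in name and param.requires_grad]

# Your task model's parameters would be added alongside these groups.
optimizer = torch.optim.Adam([
    {'params': scalar_mix_params, 'weight_decay': 0.001},
    {'params': other_params, 'weight_decay': 0.0},
], lr=0.001)
```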

Finally, we have found that including pre-trained GloVe or other word vectors in addition to ELMo
provides little to no improvement over just using ELMo and slows down training.
Finally, we have found that in some cases including pre-trained GloVe or other word vectors in addition to ELMo provides little to no improvement over just using ELMo and slows down training. However, we recommend experimenting with your dataset and model architecture for best results.

## Notes on statefulness and non-determinism

The pre-trained biLM used to compute ELMo representations was trained without resetting the internal LSTM states between sentences.
Accordingly, the re-implementation in allennlp is stateful, and carries the LSTM states forward from batch to batch.
Since the biLM was trained on randomly shuffled sentences padded with special `<S>` and `</S>` tokens, it will reset the internal states to its own internal representation of a sentence break when it sees these tokens.

There are a few practical implications of this:

* Due to the statefulness, the ELMo vectors are not deterministic and running the same batch multiple times will result in slightly different embeddings.
* After loading the pre-trained model, the first few batches will be negatively impacted until the biLM can reset its internal states. You may want to run a few batches through the model to warm up the states before making predictions (although we have not worried about this issue in practice).
* It is important to always add the `<S>` and `</S>` tokens to each sentence. The `allennlp` code handles this behind the scenes, but if you are handling padding and indexing in a different manner, take care to ensure this is done appropriately.
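
A small sketch of the first point, reusing the `ElmoEmbedder` defaults shown earlier (how large the difference is will depend on what the biLM has seen beforehand):

```python
import numpy
from allennlp.commands.elmo import ElmoEmbedder

ee = ElmoEmbedder()
sentence = "The quick brown fox jumped over the lazy dog .".split()

first = ee.embed_sentence(sentence)
second = ee.embed_sentence(sentence)

# The biLM carries its LSTM states from the first call into the second,
# so the two sets of vectors are close but generally not identical.
print(numpy.allclose(first, second))    # typically False
print(numpy.abs(first - second).max())  # small but non-zero
```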