Skip to content

Commit

Permalink
Generate: use GenerationConfig as the basis for .generate() param…
Browse files Browse the repository at this point in the history
…etrization (huggingface#20388)

* generate from config mvp

* fix failing tests

* max_time test

* Load default gen config at model load time; Update docs

* further documentation; add tests

* adapt rag to the new structure

* handle models not instantiated with from_pretained (like in tests)

* better default generation config

* add can_generate fn

* handle legacy use case of ad hoc model config changes

* initialize gen config from config in individual methods, if gen config is none

* fix _get_decoder_start_token_id when called outside GenerationMixin

* correct model config load order (set attr > model config > decoder config)

* update rag to match latest changes

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* load gen config from model config in model.from_pretrained

* fix can_generate fn

* handle generate calls without a previous from_pretrained (e.g. tests)

* add legacy behavior (and a warning)

* lower logger severity

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
3 people authored and venkat-natchi committed Jan 22, 2023
1 parent 07eb815 commit 9c842b0
Show file tree
Hide file tree
Showing 6 changed files with 692 additions and 731 deletions.
69 changes: 68 additions & 1 deletion docs/source/en/main_classes/text_generation.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,79 @@ Each framework has a generate method for auto-regressive text generation impleme
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].

<!--- TODO: add a brief description of GenerationConfig (with examples) when it becomes usable with generate --->
Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`]
class instance. Please refer to this class for the complete list of generation parameters, which control the behavior
of the generation method.

All models have a default generation configuration that will be used if you don't provide one. If you have a loaded
model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd
like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and
store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty.

```python
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("my_account/my_model")

# Inspect the default generation configuration
print(model.generation_config)

# Set a new default generation configuration
generation_config = GenerationConfig(
max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
)
generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
```

<Tip>

If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that
default values are omitted. Some attributes, like `max_length`, have a conservative default value, to avoid running
into resource limitations. Make sure you double-check the defaults in the documentation.

</Tip>

You can also store several generation parametrizations in a single directory, making use of the `config_file_name`
argument in `save_pretrained`. You can latter instantiate them with `from_pretrained`. This is useful if you want to
store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
other for summarization with beam search).

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

translation_generation_config = GenerationConfig(
num_beams=4,
early_stopping=True,
decoder_start_token_id=0,
eos_token_id=model.config.eos_token_id,
pad_token=model.config.pad_token_id,
)
# If you were working on a model for which your had the right Hub permissions, you could store a named generation
# config as follows
translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)

# You could then use the named generation config file to parameterize generation
generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# ['Les fichiers de configuration sont faciles à utiliser !']
```

Finally, you can specify ad hoc modifications to the used generation configuration by passing the attribute you
wish to override directly to the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each
framework's `generate` method docstring (available below) has a few illustrative examples on the different strategies
to parameterize it.


## GenerationConfig

[[autodoc]] generation.GenerationConfig
- from_pretrained
- from_model_config
- save_pretrained

## GenerationMixin
Expand Down
110 changes: 94 additions & 16 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Any, Dict, Optional, Union

from .. import __version__
from ..configuration_utils import PretrainedConfig
from ..utils import (
GENERATION_CONFIG_NAME,
PushToHubMixin,
Expand All @@ -36,7 +37,23 @@

class GenerationConfig(PushToHubMixin):
r"""
Class that holds a configuration for a generation task.
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
`do_sample=False`.
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
and `top_k>1`
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
`do_sample=False`.
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
`num_beams>1` and `do_sample=True`.
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`.
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`.
<Tip>
Expand All @@ -45,6 +62,9 @@ class GenerationConfig(PushToHubMixin):
</Tip>
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
Arg:
> Parameters that control the length of the output
Expand Down Expand Up @@ -73,6 +93,9 @@ class GenerationConfig(PushToHubMixin):
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
penalty_alpha (`float`, *optional*):
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
> Parameters for manipulation of the model output logits
Expand Down Expand Up @@ -108,13 +131,13 @@ class GenerationConfig(PushToHubMixin):
words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
can allow different forms of each word.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the custom
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
are normalized but some logit processors or warpers break the normalization.
constraints (`List[Constraint]`, *optional*):
Custom constraints that can be added to the generation to ensure that the output will contain the use of
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
Expand Down Expand Up @@ -191,6 +214,7 @@ def __init__(self, **kwargs):
self.num_beams = kwargs.pop("num_beams", 1)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
self.use_cache = kwargs.pop("use_cache", True)

# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)
Expand All @@ -202,7 +226,9 @@ def __init__(self, **kwargs):
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.force_word_ids = kwargs.pop("force_word_ids", None)
self.force_words_ids = kwargs.pop("force_words_ids", None)
self.renormalize_logits = kwargs.pop("renormalize_logits", False)
self.constraints = kwargs.pop("constraints", None)
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
Expand Down Expand Up @@ -230,12 +256,20 @@ def __init__(self, **kwargs):
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})

# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the the hub interface.
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the the hub
# interface.
self._from_model_config = kwargs.pop("_from_model_config", False)
self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__)

def __eq__(self, other):
return self.__dict__ == other.__dict__
self_dict = self.__dict__.copy()
other_dict = other.__dict__.copy()
# ignore metadata
for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"):
self_dict.pop(metadata_field, None)
other_dict.pop(metadata_field, None)
return self_dict == other_dict

def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
Expand Down Expand Up @@ -484,18 +518,11 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
kwargs["_commit_hash"] = config_dict["_commit_hash"]

config = cls(**config_dict)

to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
unused_kwargs = config.update(**kwargs)

logger.info(f"Generate config {config}")
if return_unused_kwargs:
return config, kwargs
return config, unused_kwargs
else:
return config

Expand Down Expand Up @@ -568,3 +595,54 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool =
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string(use_diff=use_diff))

@classmethod
def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
"""
Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
[`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
Args:
model_config (`PretrainedConfig`):
The model config that will be used to instantiate the generation config.
Returns:
[`GenerationConfig`]: The configuration object instantiated from those parameters.
"""
config_dict = model_config.to_dict()
config = cls.from_dict(config_dict, return_unused_kwargs=False)

# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config.
for decoder_name in ("decoder", "generator"):
if decoder_name in config_dict:
default_generation_config = GenerationConfig()
decoder_config = config_dict[decoder_name]
for attr in config.to_dict().keys():
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
setattr(config, attr, decoder_config[attr])

config._from_model_config = True
return config

def update(self, **kwargs):
"""
Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes,
returning all the unused kwargs.
Args:
kwargs (`Dict[str, Any]`):
Dictionary of attributes to tentatively update this class.
Returns:
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
"""
to_remove = []
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
to_remove.append(key)

# remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs
Loading

0 comments on commit 9c842b0

Please sign in to comment.