From 9c842b0928c3ed39e1b016953ffb385683668d34 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 15 Dec 2022 18:27:20 +0000 Subject: [PATCH] Generate: use `GenerationConfig` as the basis for `.generate()` parametrization (#20388) * generate from config mvp * fix failing tests * max_time test * Load default gen config at model load time; Update docs * further documentation; add tests * adapt rag to the new structure * handle models not instantiated with from_pretained (like in tests) * better default generation config * add can_generate fn * handle legacy use case of ad hoc model config changes * initialize gen config from config in individual methods, if gen config is none * fix _get_decoder_start_token_id when called outside GenerationMixin * correct model config load order (set attr > model config > decoder config) * update rag to match latest changes * Apply suggestions from code review Co-authored-by: Patrick von Platen Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * load gen config from model config in model.from_pretrained * fix can_generate fn * handle generate calls without a previous from_pretrained (e.g. tests) * add legacy behavior (and a warning) * lower logger severity Co-authored-by: Patrick von Platen Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../en/main_classes/text_generation.mdx | 69 +- .../generation/configuration_utils.py | 110 +- src/transformers/generation/utils.py | 977 ++++++++---------- src/transformers/modeling_utils.py | 38 +- src/transformers/models/rag/modeling_rag.py | 196 +--- tests/generation/test_configuration_utils.py | 33 +- 6 files changed, 692 insertions(+), 731 deletions(-) diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index 78bef8bd5a2537..1d00406ac1e584 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -18,12 +18,79 @@ Each framework has a generate method for auto-regressive text generation impleme - TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`]. - Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`]. - +Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`] +class instance. Please refer to this class for the complete list of generation parameters, which control the behavior +of the generation method. + +All models have a default generation configuration that will be used if you don't provide one. If you have a loaded +model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd +like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and +store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty. 
+
+```python
+from transformers import AutoModelForCausalLM, GenerationConfig
+
+model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
+
+# Inspect the default generation configuration
+print(model.generation_config)
+
+# Set a new default generation configuration
+generation_config = GenerationConfig(
+    max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
+)
+generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
+```
+
+<Tip>
+
+If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that
+default values are omitted. Some attributes, like `max_length`, have a conservative default value, to avoid running
+into resource limitations. Make sure you double-check the defaults in the documentation.
+
+</Tip>
+
+You can also store several generation parametrizations in a single directory, making use of the `config_file_name`
+argument in `save_pretrained`. You can later instantiate them with `from_pretrained`. This is useful if you want to
+store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
+another for summarization with beam search).
+
+```python
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
+
+tokenizer = AutoTokenizer.from_pretrained("t5-small")
+model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
+
+translation_generation_config = GenerationConfig(
+    num_beams=4,
+    early_stopping=True,
+    decoder_start_token_id=0,
+    eos_token_id=model.config.eos_token_id,
+    pad_token=model.config.pad_token_id,
+)
+# If you were working on a model for which you had the right Hub permissions, you could store a named generation
+# config as follows
+translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)
+
+# You could then use the named generation config file to parameterize generation
+generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
+inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
+outputs = model.generate(**inputs, generation_config=generation_config)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+# ['Les fichiers de configuration sont faciles à utiliser !']
+```
+
+Finally, you can override individual attributes of the generation configuration in use by passing them directly to
+the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each framework's `generate` method docstring
+(available below) has a few illustrative examples of the different strategies to parameterize it.
+
 ## GenerationConfig

 [[autodoc]] generation.GenerationConfig
 	- from_pretrained
+	- from_model_config
 	- save_pretrained

 ## GenerationMixin
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 07a97c7f2522f8..a477ebe4203c43 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -20,6 +20,7 @@
 from typing import Any, Dict, Optional, Union

 from .. import __version__
+from ..configuration_utils import PretrainedConfig
 from ..utils import (
     GENERATION_CONFIG_NAME,
     PushToHubMixin,
@@ -36,7 +37,23 @@

 class GenerationConfig(PushToHubMixin):
     r"""
-    Class that holds a configuration for a generation task.
+    Class that holds a configuration for a generation task.
A `generate` call supports the following generation methods + for text-decoder, text-to-text, speech-to-text, and vision-to-text models: + + - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.` + and `top_k>1` + - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. + - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if + `num_beams>1` and `do_sample=True`. + - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if + `num_beams>1` and `num_beam_groups>1`. + - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if + `constraints!=None` or `force_words_ids!=None`. @@ -45,6 +62,9 @@ class GenerationConfig(PushToHubMixin): + Most of these parameters are explained in more detail in [this blog + post](https://huggingface.co/blog/how-to-generate). + Arg: > Parameters that control the length of the output @@ -73,6 +93,9 @@ class GenerationConfig(PushToHubMixin): [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. penalty_alpha (`float`, *optional*): The values balance the model confidence and the degeneration penalty in contrastive search decoding. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should use the past last key/values attentions (if applicable to the model) to + speed up decoding. > Parameters for manipulation of the model output logits @@ -108,13 +131,13 @@ class GenerationConfig(PushToHubMixin): words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should use the past last key/values attentions (if applicable to the model) to - speed up decoding. renormalize_logits (`bool`, *optional*, defaults to `False`): Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization. + constraints (`List[Constraint]`, *optional*): + Custom constraints that can be added to the generation to ensure that the output will contain the use of + certain tokens as defined by `Constraint` objects, in the most sensible way possible. forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`): The id of the token to force as the first generated token after the `decoder_start_token_id`. 
Useful for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target @@ -191,6 +214,7 @@ def __init__(self, **kwargs): self.num_beams = kwargs.pop("num_beams", 1) self.num_beam_groups = kwargs.pop("num_beam_groups", 1) self.penalty_alpha = kwargs.pop("penalty_alpha", None) + self.use_cache = kwargs.pop("use_cache", True) # Parameters for manipulation of the model output logits self.temperature = kwargs.pop("temperature", 1.0) @@ -202,7 +226,9 @@ def __init__(self, **kwargs): self.length_penalty = kwargs.pop("length_penalty", 1.0) self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) self.bad_words_ids = kwargs.pop("bad_words_ids", None) - self.force_word_ids = kwargs.pop("force_word_ids", None) + self.force_words_ids = kwargs.pop("force_words_ids", None) + self.renormalize_logits = kwargs.pop("renormalize_logits", False) + self.constraints = kwargs.pop("constraints", None) self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None) self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None) self.remove_invalid_values = kwargs.pop("remove_invalid_values", False) @@ -230,12 +256,20 @@ def __init__(self, **kwargs): # Wild card self.generation_kwargs = kwargs.pop("generation_kwargs", {}) - # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the the hub interface. + # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the the hub + # interface. + self._from_model_config = kwargs.pop("_from_model_config", False) self._commit_hash = kwargs.pop("_commit_hash", None) self.transformers_version = kwargs.pop("transformers_version", __version__) def __eq__(self, other): - return self.__dict__ == other.__dict__ + self_dict = self.__dict__.copy() + other_dict = other.__dict__.copy() + # ignore metadata + for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"): + self_dict.pop(metadata_field, None) + other_dict.pop(metadata_field, None) + return self_dict == other_dict def __repr__(self): return f"{self.__class__.__name__} {self.to_json_string()}" @@ -484,18 +518,11 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": kwargs["_commit_hash"] = config_dict["_commit_hash"] config = cls(**config_dict) - - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) + unused_kwargs = config.update(**kwargs) logger.info(f"Generate config {config}") if return_unused_kwargs: - return config, kwargs + return config, unused_kwargs else: return config @@ -568,3 +595,54 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = """ with open(json_file_path, "w", encoding="utf-8") as writer: writer.write(self.to_json_string(use_diff=use_diff)) + + @classmethod + def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig": + """ + Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy + [`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`]. + + Args: + model_config (`PretrainedConfig`): + The model config that will be used to instantiate the generation config. + + Returns: + [`GenerationConfig`]: The configuration object instantiated from those parameters. 
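+
+        Example (a minimal sketch, assuming the `gpt2` checkpoint, whose model config still carries legacy
+        generation defaults):
+
+        ```python
+        >>> from transformers import AutoConfig, GenerationConfig
+
+        >>> # Build a stand-alone generation config from the generation attributes stored in the model config
+        >>> model_config = AutoConfig.from_pretrained("gpt2")
+        >>> generation_config = GenerationConfig.from_model_config(model_config)
+        ```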
+ """ + config_dict = model_config.to_dict() + config = cls.from_dict(config_dict, return_unused_kwargs=False) + + # Special case: some models have generation attributes set in the decoder. Use them if still unset in the + # generation config. + for decoder_name in ("decoder", "generator"): + if decoder_name in config_dict: + default_generation_config = GenerationConfig() + decoder_config = config_dict[decoder_name] + for attr in config.to_dict().keys(): + if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr): + setattr(config, attr, decoder_config[attr]) + + config._from_model_config = True + return config + + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3d945b2be37a74..03ad4a25a1d944 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -14,10 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import warnings from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -33,8 +34,9 @@ ) from ..pytorch_utils import torch_int_div from ..utils import ModelOutput, logging -from .beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint +from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer +from .configuration_utils import GenerationConfig from .logits_process import ( EncoderNoRepeatNGramLogitsProcessor, ExponentialDecayLengthPenalty, @@ -478,6 +480,11 @@ class GenerationMixin: `constraints!=None` or `force_words_ids!=None`. """ + def prepare_inputs_for_generation(self, *args, **kwargs): + raise NotImplementedError( + "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`." 
+ ) + def _prepare_model_inputs( self, inputs: Optional[torch.Tensor] = None, @@ -620,26 +627,16 @@ def _prepare_decoder_input_ids_for_generation( def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + decoder_start_token_id + if decoder_start_token_id is not None + else self.generation_config.decoder_start_token_id ) - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "decoder_start_token_id") - and self.config.decoder.decoder_start_token_id is not None - ): - return self.config.decoder.decoder_start_token_id elif bos_token_id is not None: return bos_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "bos_token_id") - and self.config.decoder.bos_token_id is not None - ): - return self.config.decoder.bos_token_id raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) @@ -722,166 +719,149 @@ def _reorder_cache(self, past, beam_idx): def _get_logits_warper( self, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - typical_p: Optional[float] = None, - temperature: Optional[float] = None, - num_beams: Optional[int] = None, - renormalize_logits: Optional[bool] = None, + generation_config: GenerationConfig, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances used for multinomial sampling. 
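+
+        As a rough sketch (assuming a model instance `model` whose class inherits from this mixin), a sampling
+        configuration maps onto warpers as follows:
+
+        ```python
+        >>> from transformers import GenerationConfig
+
+        >>> generation_config = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=0.9)
+        >>> warpers = model._get_logits_warper(generation_config)
+        >>> # warpers now holds a TemperatureLogitsWarper, a TopKLogitsWarper and a TopPLogitsWarper, in that order
+        ```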
""" - # init warp parameters - top_k = top_k if top_k is not None else self.config.top_k - top_p = top_p if top_p is not None else self.config.top_p - typical_p = typical_p if typical_p is not None else self.config.typical_p - temperature = temperature if temperature is not None else self.config.temperature # instantiate warpers list warpers = LogitsProcessorList() # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` - if temperature is not None and temperature != 1.0: - warpers.append(TemperatureLogitsWarper(temperature)) - if top_k is not None and top_k != 0: - warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) - if top_p is not None and top_p < 1.0: - warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) - if typical_p is not None and typical_p < 1.0: - warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if generation_config.temperature is not None and generation_config.temperature != 1.0: + warpers.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + warpers.append( + TopKLogitsWarper( + top_k=generation_config.top_k, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1) + ) + ) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + warpers.append( + TopPLogitsWarper( + top_p=generation_config.top_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1) + ) + ) + if generation_config.typical_p is not None and generation_config.typical_p < 1.0: + warpers.append( + TypicalLogitsWarper( + mass=generation_config.typical_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1) + ) + ) # `LogitNormalization` should always be the last logit processor, when present - if renormalize_logits is True: + if generation_config.renormalize_logits is True: warpers.append(LogitNormalization()) return warpers def _get_logits_processor( self, - repetition_penalty: float, - no_repeat_ngram_size: int, - encoder_no_repeat_ngram_size: int, + generation_config: GenerationConfig, input_ids_seq_length: int, encoder_input_ids: torch.LongTensor, - bad_words_ids: List[List[int]], - min_length: int, - max_length: int, - eos_token_id: int, - forced_bos_token_id: int, - forced_eos_token_id: int, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], - num_beams: int, - num_beam_groups: int, - diversity_penalty: float, - remove_invalid_values: bool, - exponential_decay_length_penalty: Tuple, logits_processor: Optional[LogitsProcessorList], - renormalize_logits: Optional[bool], - suppress_tokens: Optional[List[int]] = None, - begin_suppress_tokens: Optional[List[int]] = None, - forced_decoder_ids: Optional[List[List[int]]] = None, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] instances used to modify the scores of the language model head. 
""" - processors = LogitsProcessorList() - - # init warp parameters - repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty - no_repeat_ngram_size = ( - no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size - ) - encoder_no_repeat_ngram_size = ( - encoder_no_repeat_ngram_size - if encoder_no_repeat_ngram_size is not None - else self.config.encoder_no_repeat_ngram_size - ) - bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty - forced_bos_token_id = ( - forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id - ) - forced_eos_token_id = ( - forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id - ) - remove_invalid_values = ( - remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values - ) - exponential_decay_length_penalty = ( - exponential_decay_length_penalty - if exponential_decay_length_penalty is not None - else self.config.exponential_decay_length_penalty - ) - suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens - begin_suppress_tokens = ( - begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens - ) - if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"): - forced_decoder_ids = self.config.forced_decoder_ids # instantiate processors list + processors = LogitsProcessorList() # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` - if diversity_penalty is not None and diversity_penalty > 0.0: + if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0: processors.append( HammingDiversityLogitsProcessor( - diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups + diversity_penalty=generation_config.diversity_penalty, + num_beams=generation_config.num_beams, + num_beam_groups=generation_config.num_beam_groups, ) ) - if repetition_penalty is not None and repetition_penalty != 1.0: - processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) - if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: - processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) - if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: + if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: + processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) + if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: + processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) + if ( + generation_config.encoder_no_repeat_ngram_size is not None + and generation_config.encoder_no_repeat_ngram_size > 0 + ): if self.config.is_encoder_decoder: - processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids)) + processors.append( + EncoderNoRepeatNGramLogitsProcessor( + generation_config.encoder_no_repeat_ngram_size, 
encoder_input_ids + ) + ) else: raise ValueError( "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" ) - if bad_words_ids is not None: - processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) - if min_length is not None and eos_token_id is not None and min_length > 0: - processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) + if generation_config.bad_words_ids is not None: + processors.append( + NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) + ) + if ( + generation_config.min_length is not None + and generation_config.eos_token_id is not None + and generation_config.min_length > 0 + ): + processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) if prefix_allowed_tokens_fn is not None: - processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) - if forced_bos_token_id is not None: - processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) - if forced_eos_token_id is not None: - processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) - if remove_invalid_values is True: + processors.append( + PrefixConstrainedLogitsProcessor( + prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups + ) + ) + if generation_config.forced_bos_token_id is not None: + processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) + if generation_config.forced_eos_token_id is not None: + processors.append( + ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) + ) + if generation_config.remove_invalid_values is True: processors.append(InfNanRemoveLogitsProcessor()) - if exponential_decay_length_penalty is not None: + if generation_config.exponential_decay_length_penalty is not None: processors.append( - ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length) + ExponentialDecayLengthPenalty( + generation_config.exponential_decay_length_penalty, + generation_config.eos_token_id, + generation_config.input_ids_seq_length, + ) ) - if suppress_tokens is not None: - processors.append(SuppressTokensLogitsProcessor(suppress_tokens)) - if begin_suppress_tokens is not None: + if generation_config.suppress_tokens is not None: + processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens)) + if generation_config.begin_suppress_tokens is not None: begin_index = input_ids_seq_length - begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1 - if forced_decoder_ids is not None: - begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced - processors.append(SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)) - if forced_decoder_ids is not None: - processors.append(ForceTokensLogitsProcessor(forced_decoder_ids)) + begin_index = ( + begin_index + if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) + else begin_index + 1 + ) + if generation_config.forced_decoder_ids is not None: + # generation starts after the last token that is forced + begin_index += generation_config.forced_decoder_ids[-1][0] + processors.append( + SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) + ) + if generation_config.forced_decoder_ids is not None: + 
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) processors = self._merge_criteria_processor_list(processors, logits_processor) # `LogitNormalization` should always be the last logit processor, when present - if renormalize_logits is True: + if generation_config.renormalize_logits is True: processors.append(LogitNormalization()) return processors def _get_stopping_criteria( - self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList] + self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList] ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() - if max_length is not None: - criteria.append(MaxLengthCriteria(max_length=max_length)) - if max_time is not None: - criteria.append(MaxTimeCriteria(max_time=max_time)) + if generation_config.max_length is not None: + criteria.append(MaxLengthCriteria(max_length=generation_config.max_length)) + if generation_config.max_time is not None: + criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria @@ -951,7 +931,7 @@ def _validate_model_class(self): Confirms that the model class is compatible with generation. If not, raises an exception that points to the right class to use. """ - if not hasattr(self, "prepare_inputs_for_generation"): + if not self.can_generate(): generate_compatible_mappings = [ MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, @@ -999,81 +979,27 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): def generate( self, inputs: Optional[torch.Tensor] = None, - max_length: Optional[int] = None, - min_length: Optional[int] = None, - do_sample: Optional[bool] = None, - early_stopping: Optional[bool] = None, - num_beams: Optional[int] = None, - temperature: Optional[float] = None, - penalty_alpha: Optional[float] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - typical_p: Optional[float] = None, - repetition_penalty: Optional[float] = None, - bad_words_ids: Optional[Iterable[int]] = None, - force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, - bos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - length_penalty: Optional[float] = None, - no_repeat_ngram_size: Optional[int] = None, - encoder_no_repeat_ngram_size: Optional[int] = None, - num_return_sequences: Optional[int] = None, - max_time: Optional[float] = None, - max_new_tokens: Optional[int] = None, - decoder_start_token_id: Optional[int] = None, - use_cache: Optional[bool] = None, - num_beam_groups: Optional[int] = None, - diversity_penalty: Optional[float] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, - renormalize_logits: Optional[bool] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - constraints: Optional[List[Constraint]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - forced_bos_token_id: Optional[int] = None, - forced_eos_token_id: Optional[int] = None, - remove_invalid_values: Optional[bool] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 
synced_gpus: Optional[bool] = False, - exponential_decay_length_penalty: Optional[Tuple[int, float]] = None, - suppress_tokens: Optional[List[int]] = None, - begin_suppress_tokens: Optional[List[int]] = None, - forced_decoder_ids: Optional[List[List[int]]] = None, - **model_kwargs, + **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" - Generates sequences of token ids for models with a language modeling head. The method supports the following - generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: - - - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and - `do_sample=False`. - - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.` - and `top_k>1` - - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and - `do_sample=True`. - - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and - `do_sample=False`. - - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if - `num_beams>1` and `do_sample=True`. - - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if - `num_beams>1` and `num_beam_groups>1`. - - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if - `constraints!=None` or `force_words_ids!=None`. + Generates sequences of token ids for models with a language modeling head. - Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as - defined in the model's config (`config.json`) which in turn defaults to the - [`~modeling_utils.PretrainedConfig`] of the model. + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - + For a complete overview of generate, check the [following + guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). - Most of these parameters are explained in more detail in [this blog - post](https://huggingface.co/blog/how-to-generate). + Parameters: inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): @@ -1081,81 +1007,21 @@ def generate( method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. - max_length (`int`, *optional*, defaults to `model.config.max_length`): - The maximum length the generated tokens can have. Corresponds to the length of the input prompt + - `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in - the prompt. - max_new_tokens (`int`, *optional*): - The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. - min_length (`int`, *optional*, defaults to `model.config.min_length` or 10 if the config does not set any value): - The minimum length of the sequence to be generated. - do_sample (`bool`, *optional*, defaults to `model.config.do_sample` or `False` if the config does not set any value): - Whether or not to use sampling ; use greedy decoding otherwise. 
- early_stopping (`bool`, *optional*, defaults to `False`): - Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. - num_beams (`int`, *optional*, defaults to `model.config.num_beams` or 1 if the config does not set any value): - Number of beams for beam search. 1 means no beam search. - temperature (`float`, *optional*, defaults to `model.config.temperature` or 1.0 if the config does not set any value): - The value used to module the next token probabilities. - penalty_alpha (`float`, *optional*, defaults to `model.config.penalty_alpha` or None if the config does not set any value): - The values balance the model confidence and the degeneration penalty in contrastive search decoding. - top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value): - The number of highest probability vocabulary tokens to keep for top-k-filtering. - top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value): - If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to - `top_p` or higher are kept for generation. - typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value): - The amount of probability mass from the original distribution to be considered in typical decoding. If - set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details. - repetition_penalty (`float`, *optional*, defaults to `model.config.repetition_penalty` or 1.0 if the config does not set any value): - The parameter for repetition penalty. 1.0 means no penalty. See [this - paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. - pad_token_id (`int`, *optional*, defaults to `model.config.pad_token_id`): - The id of the *padding* token. - bos_token_id (`int`, *optional*, defaults to `model.config.bos_token_id`): - The id of the *beginning-of-sequence* token. - eos_token_id (`int`, *optional*, defaults to `model.config.eos_token_id`): - The id of the *end-of-sequence* token. - length_penalty (`float`, *optional*, defaults to `model.config.length_penalty` or 1.0 if the config does not set any value): - Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent - to the sequence length, which in turn is used to divide the score of the sequence. Since the score is - the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, - while `length_penalty` < 0.0 encourages shorter sequences. - no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.no_repeat_ngram_size` or 0 if the config does not set any value): - If set to int > 0, all ngrams of that size can only occur once. - encoder_no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.encoder_no_repeat_ngram_size` or 0 if the config does not set any value): - If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the - `decoder_input_ids`. - bad_words_ids(`List[List[int]]`, *optional*, defaults to `model.config.bad_words_ids`): - List of token ids that are not allowed to be generated. In order to get the token ids of the words that - should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, - add_special_tokens=False).input_ids`. 
- force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): - List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple - list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, - this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), - where one can allow different forms of each word. - num_return_sequences(`int`, *optional*, defaults to `model.config.num_return_sequences` or 1 if the config does not set any value): - The number of independently computed returned sequences for each element in the batch. - max_time(`float`, *optional*): - The maximum amount of time you allow the computation to run for in seconds. generation will still - finish the current pass after allocated time has been passed. - attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens - that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape - as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask) - decoder_start_token_id (`int`, *optional*): - If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should use the past last key/values attentions (if applicable to the model) to - speed up decoding. - num_beam_groups (`int`, *optional*, defaults to `model.config.num_beam_groups` or 1 if the config does not set any value): - Number of groups to divide `num_beams` into in order to ensure diversity among different groups of - beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. - diversity_penalty (`float`, *optional*, defaults to `model.config.diversity_penalty` or 0.0 if the config does not set any value): - This value is subtracted from a beam's score if it generates a token same as any beam from other group - at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is - enabled. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. 
This feature is intended for advanced users. prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and @@ -1163,60 +1029,12 @@ def generate( on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and a - model's config. If a logit processor is passed that is already created with the arguments or a model's - config an error is thrown. This feature is intended for advanced users. - renormalize_logits (`bool`, *optional*, defaults to `False`): - Whether to renormalize the logits after applying all the logits processors or warpers (including the - custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the - score logits are normalized but some logit processors or warpers break the normalization. - stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a - model's config. If a stopping criteria is passed that is already created with the arguments or a - model's config an error is thrown. This feature is intended for advanced users. - constraints (`List[Constraint]`, *optional*): - Custom constraints that can be added to the generation to ensure that the output will contain the use - of certain tokens as defined by `Constraint` objects, in the most sensible way possible. - output_attentions (`bool`, *optional*, defaults to `model.config.output_attentions` or `False` if the config does not set any value): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `model.config.output_hidden_states` or `False` if the config does not set any value): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `model.config.output_scores` or `False` if the config does not set any value): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `model.config.return_dict_in_generate` or `False` if the config does not set any value): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`): - The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful - for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be - the target language token. - forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`): - The id of the token to force as the last generated token when `max_length` is reached. 
- remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`): - Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to - crash. Note that using `remove_invalid_values` can slow down generation. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - exponential_decay_length_penalty (`tuple(int, float)`, *optional*, defaults to `model.config.exponential_decay_length_penalty`): - This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been - generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates - where penalty starts and `decay_factor` represents the factor of exponential decay - suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`): - A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set - their log probs to `-inf` so that they are not sampled. - begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`): - A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens` - logit processor will set their log probs to `-inf` so that they are not sampled. - forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`): - A list of pairs of integers which indicates a mapping from generation indices to token indices that - will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always - be a token of index 123. - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model - is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs - should be prefixed with *decoder_*. + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. 
Return: [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` @@ -1240,7 +1058,7 @@ def generate( Examples: - Greedy Decoding: + Greedy decoding, using the default generation configuration and ad hoc modifications: ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM @@ -1251,16 +1069,16 @@ def generate( >>> prompt = "Today I believe we can finally" >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids - >>> # generate up to 30 tokens + >>> # Generate up to 30 tokens >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] ``` - Multinomial Sampling: + Multinomial sampling, modifying an existing generation configuration: ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") @@ -1269,17 +1087,20 @@ def generate( >>> prompt = "Today I believe we can finally" >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids - >>> # sample up to 30 tokens + >>> # Sample up to 30 tokens >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT - >>> outputs = model.generate(input_ids, do_sample=True, max_length=30) + >>> generation_config = GenerationConfig.from_pretrained("gpt2") + >>> generation_config.max_length = 30 + >>> generation_config.do_sample = True + >>> outputs = model.generate(input_ids, generation_config=generation_config) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] ``` - Beam-search decoding: + Beam-search decoding, using a freshly initialized generation configuration: ```python - >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") @@ -1287,75 +1108,86 @@ def generate( >>> sentence = "Paris is one of the densest populated areas in Europe." >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids - >>> outputs = model.generate(input_ids, num_beams=5) + >>> generation_config = GenerationConfig( + ... max_length=64, + ... num_beams=5, + ... bos_token_id=0, + ... eos_token_id=0, + ... decoder_start_token_id=58100, + ... pad_token_id=58100, + ... bad_words_ids=[[58100]], + ... ) + >>> outputs = model.generate(input_ids, generation_config=generation_config) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" - # 0. Validate the `.generate()` call + # 1. 
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = GenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs self._validate_model_kwargs(model_kwargs.copy()) - # 1. Set generation parameters if not already defined - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id - num_beams = num_beams if num_beams is not None else self.config.num_beams - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups - do_sample = do_sample if do_sample is not None else self.config.do_sample - num_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences - ) + # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - - if eos_token_id is None and hasattr(self.config, "decoder"): - eos_token_id = self.config.decoder.eos_token_id - - if pad_token_id is None and eos_token_id is not None: + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: if model_kwargs.get("attention_mask", None) is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." 
) - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - pad_token_id = eos_token_id - - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{generation_config.eos_token_id} for open-end generation." + ) + generation_config.pad_token_id = generation_config.eos_token_id - # 2. Define model inputs + # 3. Define model inputs # inputs_tensor has to be defined # model_input_name is defined if model-specific keyword input is passed # otherwise model_input_name is None # all model-specific keyword inputs are removed from `model_kwargs` - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) batch_size = inputs_tensor.shape[0] - # 3. Define other model kwargs - model_kwargs["output_attentions"] = output_attentions - model_kwargs["output_hidden_states"] = output_hidden_states - model_kwargs["use_cache"] = use_cache + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, pad_token_id, eos_token_id + inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ) # decoder-only models should use left-padding for generation if not self.config.is_encoder_decoder: - if pad_token_id is not None and torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0: + if ( + generation_config.pad_token_id is not None + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 + ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " "generation results, please set `padding_side='left'` when initializing the tokenizer." @@ -1368,12 +1200,12 @@ def generate( inputs_tensor, model_kwargs, model_input_name ) - # 4. Prepare `input_ids` which will be used for auto-regressive generation + # 5. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: input_ids = self._prepare_decoder_input_ids_for_generation( batch_size, - decoder_start_token_id=decoder_start_token_id, - bos_token_id=bos_token_id, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, model_kwargs=model_kwargs, device=inputs_tensor.device, ) @@ -1381,87 +1213,91 @@ def generate( # if decoder-only then inputs_tensor has to be `input_ids` input_ids = inputs_tensor - # 5. 
Prepare `max_length` depending on other stopping criteria. + # 6. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] - if max_length is None and max_new_tokens is None: + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length == 20 + if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( - "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to " - f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is " - "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend " - "using `max_new_tokens` to control the maximum length of the generation.", + "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) - elif max_length is None and max_new_tokens is not None: - max_length = max_new_tokens + input_ids_seq_length - elif max_length is not None and max_new_tokens is not None: + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + elif not has_default_max_length and generation_config.max_new_tokens is not None: raise ValueError( "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" " limit to the generated output length. Remove one of those arguments. Please refer to the" " documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) - # default to config if still None - max_length = max_length if max_length is not None else self.config.max_length - min_length = min_length if min_length is not None else self.config.min_length - if min_length is not None and min_length > max_length: + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( - f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum " - f"length ({max_length})" + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" ) - if input_ids_seq_length >= max_length: + if input_ids_seq_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {max_length}. This can lead to unexpected behavior. You should consider increasing " - "`max_new_tokens`." + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." ) - # 6. determine generation mode - is_constraint_gen_mode = constraints is not None or force_words_ids is not None + # 7. 
determine generation mode + is_constraint_gen_mode = ( + generation_config.constraints is not None or generation_config.force_words_ids is not None + ) is_contrastive_search_gen_mode = ( - top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0 + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 ) is_greedy_gen_mode = ( - (num_beams == 1) - and (num_beam_groups == 1) - and do_sample is False + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_sample_gen_mode = ( - (num_beams == 1) - and (num_beam_groups == 1) - and do_sample is True + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_gen_mode = ( - (num_beams > 1) - and (num_beam_groups == 1) - and do_sample is False + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_sample_gen_mode = ( - (num_beams > 1) - and (num_beam_groups == 1) - and do_sample is True + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_group_beam_gen_mode = ( - (num_beams > 1) - and (num_beam_groups > 1) + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups > 1) and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) - if num_beam_groups > num_beams: + if generation_config.num_beam_groups > generation_config.num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") - if is_group_beam_gen_mode and do_sample is True: + if is_group_beam_gen_mode and generation_config.do_sample is True: raise ValueError( "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." ) @@ -1477,269 +1313,244 @@ def generate( UserWarning, ) - # 7. prepare distribution pre_processing samplers + # 8. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, encoder_input_ids=inputs_tensor, - bad_words_ids=bad_words_ids, - min_length=min_length, - max_length=max_length, - eos_token_id=eos_token_id, - forced_bos_token_id=forced_bos_token_id, - forced_eos_token_id=forced_eos_token_id, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - num_beams=num_beams, - num_beam_groups=num_beam_groups, - diversity_penalty=diversity_penalty, - remove_invalid_values=remove_invalid_values, - exponential_decay_length_penalty=exponential_decay_length_penalty, logits_processor=logits_processor, - renormalize_logits=renormalize_logits, - suppress_tokens=suppress_tokens, - begin_suppress_tokens=begin_suppress_tokens, - forced_decoder_ids=forced_decoder_ids, ) - # 8. prepare stopping criteria + # 9. 
prepare stopping criteria stopping_criteria = self._get_stopping_criteria( - max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria + generation_config=generation_config, stopping_criteria=stopping_criteria ) - # 9. go into different generation modes + # 10. go into different generation modes if is_greedy_gen_mode: - if num_return_sequences > 1: + if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." ) - # 10. run greedy search + # 11. run greedy search return self.greedy_search( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_contrastive_search_gen_mode: - if num_return_sequences > 1: + if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search." + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." ) return self.contrastive_search( input_ids, - top_k=top_k, - penalty_alpha=penalty_alpha, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, logits_processor=logits_processor, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_sample_gen_mode: - # 10. prepare logits warper - logits_warper = self._get_logits_warper( - top_k=top_k, - top_p=top_p, - typical_p=typical_p, - temperature=temperature, - num_beams=num_beams, - renormalize_logits=renormalize_logits, - ) + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) - # 11. expand input_ids with `num_return_sequences` additional sequences per batch + # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_return_sequences, + expand_size=generation_config.num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 12. run sample + # 13. 
run sample return self.sample( input_ids, logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_beam_gen_mode: - if num_return_sequences > num_beams: + if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - # 10. prepare beam search scorer + # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, - num_beams=num_beams, + num_beams=generation_config.num_beams, device=inputs_tensor.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, ) - # 11. interleave input_ids with `num_beams` additional sequences per batch + # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_beams, + expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 12. run beam search + # 13. run beam search return self.beam_search( input_ids, beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_beam_sample_gen_mode: - # 10. prepare logits warper - logits_warper = self._get_logits_warper( - top_k=top_k, - top_p=top_p, - typical_p=typical_p, - temperature=temperature, - num_beams=num_beams, - renormalize_logits=renormalize_logits, - ) + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - # 11. prepare beam search scorer + # 12. prepare beam search scorer beam_scorer = BeamSearchScorer( - batch_size=batch_size * num_return_sequences, - num_beams=num_beams, + batch_size=batch_size * generation_config.num_return_sequences, + num_beams=generation_config.num_beams, device=inputs_tensor.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, ) - # 12. interleave input_ids with `num_beams` additional sequences per batch + # 13. 
interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_beams * num_return_sequences, + expand_size=generation_config.num_beams * generation_config.num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 13. run beam sample + # 14. run beam sample return self.beam_sample( input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_group_beam_gen_mode: - if num_return_sequences > num_beams: + if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") - if num_beams % num_beam_groups != 0: + if generation_config.num_beams % generation_config.num_beam_groups != 0: raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - if typical_p is not None: + has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 + if not has_default_typical_p: raise ValueError("Decoder argument `typical_p` is not supported with beam groups.") - # 10. prepare beam search scorer + # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, - num_beams=num_beams, + num_beams=generation_config.num_beams, max_length=stopping_criteria.max_length, device=inputs_tensor.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, - num_beam_groups=num_beam_groups, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, ) - # 11. interleave input_ids with `num_beams` additional sequences per batch + # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_beams, + expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 12. run beam search + # 13. 
run beam search return self.group_beam_search( input_ids, beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_constraint_gen_mode: - if num_return_sequences > num_beams: + if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - if num_beams <= 1: + if generation_config.num_beams <= 1: raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.") - if do_sample: + if generation_config.do_sample: raise ValueError("`do_sample` needs to be false for constrained generation.") - if num_beam_groups is not None and num_beam_groups > 1: + if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1: raise ValueError("`num_beam_groups` not supported yet for constrained generation.") final_constraints = [] - if constraints is not None: - final_constraints = constraints + if generation_config.constraints is not None: + final_constraints = generation_config.constraints - if force_words_ids is not None: + if generation_config.force_words_ids is not None: def typeerror(): raise ValueError( "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" - f"of positive integers, but is {force_words_ids}." + f"of positive integers, but is {generation_config.force_words_ids}." ) - if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): typeerror() - for word_ids in force_words_ids: + for word_ids in generation_config.force_words_ids: if isinstance(word_ids[0], list): if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() @@ -1761,33 +1572,33 @@ def typeerror(): constraint = PhrasalConstraint(word_ids) final_constraints.append(constraint) - # 10. prepare beam search scorer + # 11. prepare beam search scorer constrained_beam_scorer = ConstrainedBeamSearchScorer( constraints=final_constraints, batch_size=batch_size, - num_beams=num_beams, + num_beams=generation_config.num_beams, device=inputs_tensor.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, ) - # 11. interleave input_ids with `num_beams` additional sequences per batch + # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_beams, + expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 12. run beam search + # 13. 
run beam search return self.constrained_beam_search( input_ids, constrained_beam_scorer=constrained_beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) @@ -1884,15 +1695,19 @@ def contrastive_search( logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2239,15 +2054,19 @@ def greedy_search( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else 
self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2487,15 +2306,19 @@ def sample( ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2739,15 +2562,19 @@ def beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -3059,15 +2886,19 @@ def beam_sample( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else 
self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -3368,15 +3199,19 @@ def group_beam_search( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -3737,15 +3572,19 @@ def constrained_beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else 
self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 48e437fd768478..6780f9b19f1431 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -39,7 +39,7 @@ from .configuration_utils import PretrainedConfig from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled from .dynamic_module_utils import custom_object_save -from .generation import GenerationMixin +from .generation import GenerationConfig, GenerationMixin from .pytorch_utils import ( # noqa: F401 Conv1D, apply_chunking_to_forward, @@ -1024,6 +1024,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): self.config = config self.name_or_path = config.name_or_path self.warnings_issued = {} + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None def post_init(self): """ @@ -1106,6 +1107,18 @@ def base_model(self) -> nn.Module: """ return getattr(self, self.base_model_prefix, self) + def can_generate(self) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. + + Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation + if "GenerationMixin" in str(self.prepare_inputs_for_generation): + return False + return True + def get_input_embeddings(self) -> nn.Module: """ Returns the model's input embeddings. @@ -2477,6 +2490,29 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Set model in evaluation mode to deactivate DropOut modules by default model.eval() + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." 
+ ) + pass + # Dispatch model with hooks on all devices if necessary if device_map is not None: dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index c4b102a204f683..461e06ec4f7535 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -14,6 +14,7 @@ # limitations under the License. """RAG model implementation.""" +import copy from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union @@ -21,7 +22,7 @@ from torch import nn from ...configuration_utils import PretrainedConfig -from ...generation import BeamSearchScorer, LogitsProcessorList, StoppingCriteriaList +from ...generation import BeamSearchScorer, GenerationConfig, LogitsProcessorList, StoppingCriteriaList from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings @@ -1384,33 +1385,12 @@ def generate( context_input_ids: Optional[torch.LongTensor] = None, context_attention_mask: Optional[torch.LongTensor] = None, doc_scores: Optional[torch.FloatTensor] = None, - max_length: Optional[int] = None, - min_length: Optional[int] = None, - early_stopping: Optional[bool] = None, - use_cache: Optional[bool] = None, - num_beams: Optional[int] = None, - num_beam_groups: Optional[int] = None, - diversity_penalty: Optional[float] = None, - bos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - length_penalty: Optional[float] = None, - no_repeat_ngram_size: Optional[int] = None, - encoder_no_repeat_ngram_size: Optional[int] = None, - repetition_penalty: Optional[float] = None, - bad_words_ids: Optional[List[List[int]]] = None, - num_return_sequences: Optional[int] = None, - decoder_start_token_id: Optional[int] = None, n_docs: Optional[int] = None, + generation_config: Optional[GenerationConfig] = None, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None, logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), - renormalize_logits: Optional[bool] = None, stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), - forced_bos_token_id: Optional[int] = None, - forced_eos_token_id: Optional[int] = None, - remove_invalid_values: Optional[bool] = None, - exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, - **model_kwargs + **kwargs ) -> torch.LongTensor: """ Implements RAG token decoding. @@ -1444,51 +1424,15 @@ def generate( If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. - max_length (`int`, *optional*, defaults to 20): - The maximum length of the sequence to be generated. - min_length (`int`, *optional*, defaults to 10): - The minimum length of the sequence to be generated. - early_stopping (`bool`, *optional*, defaults to `False`): - Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or - not. - use_cache: (`bool`, *optional*, defaults to `True`): - Whether or not the model should use the past last key/values attentions (if applicable to the model) to - speed up decoding. - pad_token_id (`int`, *optional*): - The id of the *padding* token. 
- bos_token_id (`int`, *optional*): - The id of the *beginning-of-sequence* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - length_penalty (`float`, *optional*, defaults to 1.0): - Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent - to the sequence length, which in turn is used to divide the score of the sequence. Since the score is - the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, - while `length_penalty` < 0.0 encourages shorter sequences. - no_repeat_ngram_size (`int`, *optional*, defaults to 0): - If set to int > 0, all ngrams of that size can only occur once. - encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): - If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the - `decoder_input_ids`. - bad_words_ids(`List[int]`, *optional*): - List of token ids that are not allowed to be generated. In order to get the tokens of the words that - should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. - num_beams (`int`, *optional*, defaults to 1): - Number of beams for beam search. 1 means no beam search. - num_beam_groups (`int`, *optional*, defaults to 1): - Number of groups to divide `num_beams` into in order to ensure diversity among different groups of - beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. - diversity_penalty (`float`, *optional*, defaults to 0.0): - This value is subtracted from a beam's score if it generates a token same as any beam from other group - at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is - enabled. - num_return_sequences(`int`, *optional*, defaults to 1): - The number of independently computed returned sequences for each element in the batch. Note that this - is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where - we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an - encoder-decoder model starts decoding with a different token than *bos*, the id of that token. n_docs (`int`, *optional*, defaults to `config.n_docs`) Number of documents to retrieve and/or number of documents for which to generate an answer. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. prefix_allowed_tokens_fn: (`Callable[[int, torch.Tensor], List[int]]`, *optional*): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments `inputs_ids` and the batch ID @@ -1497,53 +1441,30 @@ def generate( constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). 
logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and a - model's config. If a logit processor is passed that is already created with the arguments or a model's - config an error is thrown. + Custom logits processors that complement the default logits processors built from arguments and a + model's config. If a logit processor is passed that is already created with the arguments or a model's + config an error is thrown. stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a - model's config. If a stopping criteria is passed that is already created with the arguments or a - model's config an error is thrown. - forced_bos_token_id (`int`, *optional*): - The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful - for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be - the target language token. - forced_eos_token_id (`int`, *optional*): - The id of the token to force as the last generated token when `max_length` is reached. - remove_invalid_values (`bool`, *optional*): - Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to - crash. Note that using `remove_invalid_values` can slow down generation. + Custom stopping criteria that complement the default stopping criteria built from arguments and a + model's config. If a stopping criteria is passed that is already created with the arguments or a + model's config an error is thrown. + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. Return: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. 
""" + # Handle `generation_config` and kwargs that might update it + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + # set default parameters n_docs = n_docs if n_docs is not None else self.config.n_docs - num_beams = num_beams if num_beams is not None else self.config.num_beams - num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups - max_length = max_length if max_length is not None else self.config.max_length - num_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences - ) - bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id - pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id - use_cache = use_cache if use_cache is not None else self.config.use_cache - decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.config.generator.decoder_start_token_id - ) - remove_invalid_values = ( - remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values - ) - exponential_decay_length_penalty = ( - exponential_decay_length_penalty - if exponential_decay_length_penalty is not None - else self.config.exponential_decay_length_penalty - ) # retrieve docs if self.retriever is not None and context_input_ids is None: @@ -1583,8 +1504,8 @@ def generate( encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True) input_ids = torch.full( - (batch_size * num_beams, 1), - decoder_start_token_id, + (batch_size * generation_config.num_beams, 1), + generation_config.decoder_start_token_id, dtype=torch.long, device=next(self.parameters()).device, ) @@ -1600,10 +1521,12 @@ def extend_enc_output(tensor, num_beams=None): return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:]) # correctly extend last_hidden_state and attention mask - context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams) - encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams) + context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams) + encoder_outputs["last_hidden_state"] = extend_enc_output( + last_hidden_state, num_beams=generation_config.num_beams + ) - doc_scores = doc_scores.repeat_interleave(num_beams, dim=0) + doc_scores = doc_scores.repeat_interleave(generation_config.num_beams, dim=0) # define start_len & additional parameters model_kwargs["doc_scores"] = doc_scores @@ -1612,64 +1535,51 @@ def extend_enc_output(tensor, num_beams=None): model_kwargs["n_docs"] = n_docs pre_processor = self._get_logits_processor( - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, encoder_input_ids=context_input_ids, - bad_words_ids=bad_words_ids, - min_length=min_length, - max_length=max_length, - eos_token_id=eos_token_id, - forced_bos_token_id=forced_bos_token_id, - forced_eos_token_id=forced_eos_token_id, 
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - num_beams=num_beams, - num_beam_groups=num_beam_groups, - diversity_penalty=diversity_penalty, - remove_invalid_values=remove_invalid_values, - exponential_decay_length_penalty=exponential_decay_length_penalty, logits_processor=logits_processor, - renormalize_logits=renormalize_logits, ) - if num_beams == 1: - if num_return_sequences > 1: + if generation_config.num_beams == 1: + if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." ) return self.greedy_search( input_ids, logits_processor=pre_processor, - max_length=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, + max_length=generation_config.max_length, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, **model_kwargs, ) - elif num_beams > 1: - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - if num_return_sequences > num_beams: + elif generation_config.num_beams > 1: + if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") beam_scorer = BeamSearchScorer( batch_size=batch_size, - num_beams=num_beams, + num_beams=generation_config.num_beams, device=self.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, ) return self.beam_search( input_ids, beam_scorer, logits_processor=pre_processor, - max_length=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, + max_length=generation_config.max_length, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, **model_kwargs, ) else: - raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}") + raise ValueError( + f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}" + ) def get_input_embeddings(self): return self.rag.generator.get_input_embeddings() diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index 5cfe0995655f14..004720e110b973 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import tempfile import unittest from parameterized import parameterized -from transformers.generation import GenerationConfig +from transformers import AutoConfig, GenerationConfig class LogitsProcessorTest(unittest.TestCase): @@ -43,3 +44,33 @@ def test_save_load_config(self, config_name): self.assertEqual(loaded_config.top_k, 50) self.assertEqual(loaded_config.max_length, 20) self.assertEqual(loaded_config.max_time, None) + + def test_from_model_config(self): + model_config = AutoConfig.from_pretrained("gpt2") + generation_config_from_model = GenerationConfig.from_model_config(model_config) + default_generation_config = GenerationConfig() + + # The generation config has loaded a few non-default parameters from the model config + self.assertNotEqual(generation_config_from_model, default_generation_config) + + # One of those parameters is eos_token_id -- check if it matches + self.assertNotEqual(generation_config_from_model.eos_token_id, default_generation_config.eos_token_id) + self.assertEqual(generation_config_from_model.eos_token_id, model_config.eos_token_id) + + def test_update(self): + generation_config = GenerationConfig() + update_kwargs = { + "max_new_tokens": 1024, + "foo": "bar", + } + update_kwargs_copy = copy.deepcopy(update_kwargs) + unused_kwargs = generation_config.update(**update_kwargs) + + # update_kwargs was not modified (no side effects) + self.assertEqual(update_kwargs, update_kwargs_copy) + + # update_kwargs was used to update the config on valid attributes + self.assertEqual(generation_config.max_new_tokens, 1024) + + # `.update()` returns a dictionary of unused kwargs + self.assertEqual(unused_kwargs, {"foo": "bar"})
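
With the `modeling_utils.py` changes in this patch, any model that can generate exposes a `generation_config` attribute as soon as it is instantiated: `from_pretrained` tries to load a `generation_config.json`, and on `OSError` the attribute falls back to `GenerationConfig.from_model_config(config)`. Below is a minimal test-style sketch of that behavior; the `gpt2` checkpoint name is purely illustrative and not part of the patch, and the test assumes the checkpoint is reachable.

```python
import unittest

from transformers import AutoModelForCausalLM, GenerationConfig


class GenerationConfigAtLoadTimeTest(unittest.TestCase):
    def test_generation_config_is_populated_on_load(self):
        model = AutoModelForCausalLM.from_pretrained("gpt2")

        # `can_generate()` is True because GPT-2 overrides `prepare_inputs_for_generation`
        self.assertTrue(model.can_generate())

        # Whether loaded from a `generation_config.json` file or built from the model config,
        # the default generation config is attached to the model and shares attributes such
        # as `eos_token_id` with the model config for this checkpoint
        self.assertIsInstance(model.generation_config, GenerationConfig)
        self.assertEqual(model.generation_config.eos_token_id, model.config.eos_token_id)


if __name__ == "__main__":
    unittest.main()
```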
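Because the new `generate()` entry point deep-copies the configuration before applying ad hoc keyword overrides (`generation_config = copy.deepcopy(generation_config)` followed by `generation_config.update(**kwargs)`), a caller-supplied `GenerationConfig` is never mutated by a call. The sketch below illustrates that guarantee under the same assumption that `gpt2` is just an example checkpoint.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Configuration files are easy to use", return_tensors="pt")

generation_config = GenerationConfig(max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)

# The ad hoc `max_new_tokens=10` overrides the passed config for this call only
outputs = model.generate(**inputs, generation_config=generation_config, max_new_tokens=10)
assert outputs.shape[-1] <= inputs["input_ids"].shape[-1] + 10

# The caller's object keeps its original value, thanks to the deep copy inside `generate()`
assert generation_config.max_new_tokens == 5
```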
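Finally, the legacy branch at the top of the new `generate()` keeps direct edits to `model.config` working when the default generation config was derived from the model config, at the cost of a deprecation warning. A rough sketch of that path follows; it assumes a checkpoint (again `gpt2`, illustrative only) whose default generation config was built with `GenerationConfig.from_model_config`, i.e. the private `_from_model_config` flag is set, rather than loaded from a `generation_config.json` file.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Legacy pattern: tweak the model config directly instead of using a GenerationConfig
model.config.max_length = 30

# If `model.generation_config._from_model_config` is True, `generate()` notices the edit,
# rebuilds the generation config from `model.config`, and emits a deprecation warning
outputs = model.generate(**inputs)
assert outputs.shape[-1] <= 30
```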