diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx
index 78bef8bd5a2537..1d00406ac1e584 100644
--- a/docs/source/en/main_classes/text_generation.mdx
+++ b/docs/source/en/main_classes/text_generation.mdx
@@ -18,12 +18,79 @@ Each framework has a generate method for auto-regressive text generation impleme
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
-
+Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`]
+class instance. Please refer to this class for the complete list of generation parameters, which control the behavior
+of the generation method.
+
+All models have a default generation configuration that will be used if you don't provide one. If you have a loaded
+model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd
+like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and
+store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty.
+
+```python
+from transformers import AutoModelForCausalLM, GenerationConfig
+
+model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
+
+# Inspect the default generation configuration
+print(model.generation_config)
+
+# Set a new default generation configuration
+generation_config = GenerationConfig(
+ max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
+)
+generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
+```
+
+
+
+If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that
+default values are omitted. Some attributes, like `max_length`, have a conservative default value to avoid running
+into resource limitations. Make sure you double-check the defaults in the documentation.
+
+
+
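+As a small sketch of the behavior described above (the exact fields printed may vary across versions), only the
+attributes that differ from the defaults, plus some metadata, are shown or serialized:
+
+```python
+from transformers import GenerationConfig
+
+generation_config = GenerationConfig(num_beams=4, max_new_tokens=64)
+# Attributes left at their default values are omitted from the printed/serialized representation
+print(generation_config)
+```
+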
+You can also store several generation parametrizations in a single directory, making use of the `config_file_name`
+argument in `save_pretrained`. You can later instantiate them with `from_pretrained`. This is useful if you want to
+store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
+another for summarization with beam search).
+
+```python
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
+
+tokenizer = AutoTokenizer.from_pretrained("t5-small")
+model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
+
+translation_generation_config = GenerationConfig(
+ num_beams=4,
+ early_stopping=True,
+ decoder_start_token_id=0,
+ eos_token_id=model.config.eos_token_id,
+ pad_token_id=model.config.pad_token_id,
+)
+# If you were working on a model for which you had the right Hub permissions, you could store a named generation
+# config as follows
+translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)
+
+# You could then use the named generation config file to parameterize generation
+generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
+inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
+outputs = model.generate(**inputs, generation_config=generation_config)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+# ['Les fichiers de configuration sont faciles à utiliser !']
+```
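+
+Continuing the example above, a second named configuration, e.g. for summarization, could be stored alongside it in
+the same way (the file name and parameter values below are illustrative):
+
+```python
+summarization_generation_config = GenerationConfig(num_beams=4, max_new_tokens=64, eos_token_id=model.config.eos_token_id)
+summarization_generation_config.save_pretrained("t5-small", "summarization_generation_config.json", push_to_hub=True)
+
+# Later, load whichever parametrization fits the task at hand
+generation_config = GenerationConfig.from_pretrained("t5-small", "summarization_generation_config.json")
+```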
+
+Finally, you can make ad hoc modifications to the generation configuration in use by passing the attributes you
+wish to override directly to the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each
+framework's `generate` method docstring (available below) has a few illustrative examples of the different strategies
+to parameterize it.
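+
+For example, a minimal sketch of such an ad hoc override ("gpt2" is used as an illustrative checkpoint):
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+model = AutoModelForCausalLM.from_pretrained("gpt2")
+inputs = tokenizer("Today I believe we can finally", return_tensors="pt")
+
+# `max_new_tokens` overrides the corresponding attribute of the generation configuration for this
+# call only; the stored configuration is left untouched.
+outputs = model.generate(**inputs, max_new_tokens=20)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+```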
+
## GenerationConfig
[[autodoc]] generation.GenerationConfig
- from_pretrained
+ - from_model_config
- save_pretrained
## GenerationMixin
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 07a97c7f2522f8..a477ebe4203c43 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -20,6 +20,7 @@
from typing import Any, Dict, Optional, Union
from .. import __version__
+from ..configuration_utils import PretrainedConfig
from ..utils import (
GENERATION_CONFIG_NAME,
PushToHubMixin,
@@ -36,7 +37,23 @@
class GenerationConfig(PushToHubMixin):
r"""
- Class that holds a configuration for a generation task.
+ Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
+ for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
+
+ - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+ `do_sample=False`.
+ - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
+ and `top_k>1`.
+ - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+ `do_sample=True`.
+ - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+ `do_sample=False`.
+ - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
+ `num_beams>1` and `do_sample=True`.
+ - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
+ `num_beams>1` and `num_beam_groups>1`.
+ - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
+ `constraints!=None` or `force_words_ids!=None`.
@@ -45,6 +62,9 @@ class GenerationConfig(PushToHubMixin):
+ Most of these parameters are explained in more detail in [this blog
+ post](https://huggingface.co/blog/how-to-generate).
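+
+ As an illustrative sketch of how the flags above select a decoding method (the values are arbitrary):
+
+ ```python
+ >>> from transformers import GenerationConfig
+
+ >>> # when passed to `generate`, selects multinomial sampling
+ >>> sampling_config = GenerationConfig(do_sample=True, num_beams=1)
+ >>> # when passed to `generate`, selects beam-search decoding
+ >>> beam_search_config = GenerationConfig(do_sample=False, num_beams=4)
+ ```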
+
Arg:
> Parameters that control the length of the output
@@ -73,6 +93,9 @@ class GenerationConfig(PushToHubMixin):
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
penalty_alpha (`float`, *optional*):
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should use the past last key/values attentions (if applicable to the model) to
+ speed up decoding.
> Parameters for manipulation of the model output logits
@@ -108,13 +131,13 @@ class GenerationConfig(PushToHubMixin):
words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
can allow different forms of each word.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should use the past last key/values attentions (if applicable to the model) to
- speed up decoding.
renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the custom
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
are normalized but some logit processors or warpers break the normalization.
+ constraints (`List[Constraint]`, *optional*):
+ Custom constraints that can be added to the generation to ensure that the output will contain the use of
+ certain tokens as defined by `Constraint` objects, in the most sensible way possible.
forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
@@ -191,6 +214,7 @@ def __init__(self, **kwargs):
self.num_beams = kwargs.pop("num_beams", 1)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
+ self.use_cache = kwargs.pop("use_cache", True)
# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)
@@ -202,7 +226,9 @@ def __init__(self, **kwargs):
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
- self.force_word_ids = kwargs.pop("force_word_ids", None)
+ self.force_words_ids = kwargs.pop("force_words_ids", None)
+ self.renormalize_logits = kwargs.pop("renormalize_logits", False)
+ self.constraints = kwargs.pop("constraints", None)
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
@@ -230,12 +256,20 @@ def __init__(self, **kwargs):
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
- # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the the hub interface.
+ # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
+ # interface.
+ self._from_model_config = kwargs.pop("_from_model_config", False)
self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__)
def __eq__(self, other):
- return self.__dict__ == other.__dict__
+ self_dict = self.__dict__.copy()
+ other_dict = other.__dict__.copy()
+ # ignore metadata
+ for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"):
+ self_dict.pop(metadata_field, None)
+ other_dict.pop(metadata_field, None)
+ return self_dict == other_dict
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
@@ -484,18 +518,11 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
kwargs["_commit_hash"] = config_dict["_commit_hash"]
config = cls(**config_dict)
-
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(config, key):
- setattr(config, key, value)
- to_remove.append(key)
- for key in to_remove:
- kwargs.pop(key, None)
+ unused_kwargs = config.update(**kwargs)
logger.info(f"Generate config {config}")
if return_unused_kwargs:
- return config, kwargs
+ return config, unused_kwargs
else:
return config
@@ -568,3 +595,54 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool =
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string(use_diff=use_diff))
+
+ @classmethod
+ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
+ """
+ Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
+ [`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
+
+ Args:
+ model_config (`PretrainedConfig`):
+ The model config that will be used to instantiate the generation config.
+
+ Returns:
+ [`GenerationConfig`]: The configuration object instantiated from those parameters.
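+
+ Example (a usage sketch; "gpt2" stands in for any checkpoint whose model config still carries generation parameters):
+
+ ```python
+ >>> from transformers import AutoConfig, GenerationConfig
+
+ >>> model_config = AutoConfig.from_pretrained("gpt2")
+ >>> generation_config = GenerationConfig.from_model_config(model_config)
+ ```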
+ """
+ config_dict = model_config.to_dict()
+ config = cls.from_dict(config_dict, return_unused_kwargs=False)
+
+ # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
+ # generation config.
+ for decoder_name in ("decoder", "generator"):
+ if decoder_name in config_dict:
+ default_generation_config = GenerationConfig()
+ decoder_config = config_dict[decoder_name]
+ for attr in config.to_dict().keys():
+ if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
+ setattr(config, attr, decoder_config[attr])
+
+ config._from_model_config = True
+ return config
+
+ def update(self, **kwargs):
+ """
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
+ returning all the unused kwargs.
+
+ Args:
+ kwargs (`Dict[str, Any]`):
+ Dictionary of attributes to tentatively update this class.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
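+
+ Example (a usage sketch; keys that do not match an existing attribute are ignored and returned):
+
+ ```python
+ >>> from transformers import GenerationConfig
+
+ >>> generation_config = GenerationConfig()
+ >>> unused_kwargs = generation_config.update(num_beams=4, unknown_kwarg="foo")
+ >>> generation_config.num_beams
+ 4
+ >>> unused_kwargs
+ {'unknown_kwarg': 'foo'}
+ ```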
+ """
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ to_remove.append(key)
+
+ # remove all the attributes that were updated, without modifying the input dict
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
+ return unused_kwargs
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 3d945b2be37a74..03ad4a25a1d944 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import inspect
import warnings
from dataclasses import dataclass
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -33,8 +34,9 @@
)
from ..pytorch_utils import torch_int_div
from ..utils import ModelOutput, logging
-from .beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
+from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
+from .configuration_utils import GenerationConfig
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
ExponentialDecayLengthPenalty,
@@ -478,6 +480,11 @@ class GenerationMixin:
`constraints!=None` or `force_words_ids!=None`.
"""
+ def prepare_inputs_for_generation(self, *args, **kwargs):
+ raise NotImplementedError(
+ "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
+ )
+
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
@@ -620,26 +627,16 @@ def _prepare_decoder_input_ids_for_generation(
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
- decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
+ decoder_start_token_id
+ if decoder_start_token_id is not None
+ else self.generation_config.decoder_start_token_id
)
- bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+ bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
- elif (
- hasattr(self.config, "decoder")
- and hasattr(self.config.decoder, "decoder_start_token_id")
- and self.config.decoder.decoder_start_token_id is not None
- ):
- return self.config.decoder.decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
- elif (
- hasattr(self.config, "decoder")
- and hasattr(self.config.decoder, "bos_token_id")
- and self.config.decoder.bos_token_id is not None
- ):
- return self.config.decoder.bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
@@ -722,166 +719,149 @@ def _reorder_cache(self, past, beam_idx):
def _get_logits_warper(
self,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- typical_p: Optional[float] = None,
- temperature: Optional[float] = None,
- num_beams: Optional[int] = None,
- renormalize_logits: Optional[bool] = None,
+ generation_config: GenerationConfig,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
used for multinomial sampling.
"""
- # init warp parameters
- top_k = top_k if top_k is not None else self.config.top_k
- top_p = top_p if top_p is not None else self.config.top_p
- typical_p = typical_p if typical_p is not None else self.config.typical_p
- temperature = temperature if temperature is not None else self.config.temperature
# instantiate warpers list
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
- if temperature is not None and temperature != 1.0:
- warpers.append(TemperatureLogitsWarper(temperature))
- if top_k is not None and top_k != 0:
- warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
- if top_p is not None and top_p < 1.0:
- warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
- if typical_p is not None and typical_p < 1.0:
- warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
+ if generation_config.temperature is not None and generation_config.temperature != 1.0:
+ warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+ if generation_config.top_k is not None and generation_config.top_k != 0:
+ warpers.append(
+ TopKLogitsWarper(
+ top_k=generation_config.top_k, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1)
+ )
+ )
+ if generation_config.top_p is not None and generation_config.top_p < 1.0:
+ warpers.append(
+ TopPLogitsWarper(
+ top_p=generation_config.top_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1)
+ )
+ )
+ if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+ warpers.append(
+ TypicalLogitsWarper(
+ mass=generation_config.typical_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1)
+ )
+ )
# `LogitNormalization` should always be the last logit processor, when present
- if renormalize_logits is True:
+ if generation_config.renormalize_logits is True:
warpers.append(LogitNormalization())
return warpers
def _get_logits_processor(
self,
- repetition_penalty: float,
- no_repeat_ngram_size: int,
- encoder_no_repeat_ngram_size: int,
+ generation_config: GenerationConfig,
input_ids_seq_length: int,
encoder_input_ids: torch.LongTensor,
- bad_words_ids: List[List[int]],
- min_length: int,
- max_length: int,
- eos_token_id: int,
- forced_bos_token_id: int,
- forced_eos_token_id: int,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
- num_beams: int,
- num_beam_groups: int,
- diversity_penalty: float,
- remove_invalid_values: bool,
- exponential_decay_length_penalty: Tuple,
logits_processor: Optional[LogitsProcessorList],
- renormalize_logits: Optional[bool],
- suppress_tokens: Optional[List[int]] = None,
- begin_suppress_tokens: Optional[List[int]] = None,
- forced_decoder_ids: Optional[List[List[int]]] = None,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
instances used to modify the scores of the language model head.
"""
- processors = LogitsProcessorList()
-
- # init warp parameters
- repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
- no_repeat_ngram_size = (
- no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
- )
- encoder_no_repeat_ngram_size = (
- encoder_no_repeat_ngram_size
- if encoder_no_repeat_ngram_size is not None
- else self.config.encoder_no_repeat_ngram_size
- )
- bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
- diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
- forced_bos_token_id = (
- forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
- )
- forced_eos_token_id = (
- forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
- )
- remove_invalid_values = (
- remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
- )
- exponential_decay_length_penalty = (
- exponential_decay_length_penalty
- if exponential_decay_length_penalty is not None
- else self.config.exponential_decay_length_penalty
- )
- suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens
- begin_suppress_tokens = (
- begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens
- )
- if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"):
- forced_decoder_ids = self.config.forced_decoder_ids
# instantiate processors list
+ processors = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
- if diversity_penalty is not None and diversity_penalty > 0.0:
+ if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
processors.append(
HammingDiversityLogitsProcessor(
- diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
+ diversity_penalty=generation_config.diversity_penalty,
+ num_beams=generation_config.num_beams,
+ num_beam_groups=generation_config.num_beam_groups,
)
)
- if repetition_penalty is not None and repetition_penalty != 1.0:
- processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
- if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
- processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
- if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
+ if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
+ processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
+ if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
+ processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
+ if (
+ generation_config.encoder_no_repeat_ngram_size is not None
+ and generation_config.encoder_no_repeat_ngram_size > 0
+ ):
if self.config.is_encoder_decoder:
- processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
+ processors.append(
+ EncoderNoRepeatNGramLogitsProcessor(
+ generation_config.encoder_no_repeat_ngram_size, encoder_input_ids
+ )
+ )
else:
raise ValueError(
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
)
- if bad_words_ids is not None:
- processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
- if min_length is not None and eos_token_id is not None and min_length > 0:
- processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
+ if generation_config.bad_words_ids is not None:
+ processors.append(
+ NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
+ )
+ if (
+ generation_config.min_length is not None
+ and generation_config.eos_token_id is not None
+ and generation_config.min_length > 0
+ ):
+ processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
if prefix_allowed_tokens_fn is not None:
- processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
- if forced_bos_token_id is not None:
- processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
- if forced_eos_token_id is not None:
- processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
- if remove_invalid_values is True:
+ processors.append(
+ PrefixConstrainedLogitsProcessor(
+ prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups
+ )
+ )
+ if generation_config.forced_bos_token_id is not None:
+ processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
+ if generation_config.forced_eos_token_id is not None:
+ processors.append(
+ ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
+ )
+ if generation_config.remove_invalid_values is True:
processors.append(InfNanRemoveLogitsProcessor())
- if exponential_decay_length_penalty is not None:
+ if generation_config.exponential_decay_length_penalty is not None:
processors.append(
- ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
+ ExponentialDecayLengthPenalty(
+ generation_config.exponential_decay_length_penalty,
+ generation_config.eos_token_id,
+ input_ids_seq_length,
+ )
)
- if suppress_tokens is not None:
- processors.append(SuppressTokensLogitsProcessor(suppress_tokens))
- if begin_suppress_tokens is not None:
+ if generation_config.suppress_tokens is not None:
+ processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens))
+ if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
- begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1
- if forced_decoder_ids is not None:
- begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced
- processors.append(SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index))
- if forced_decoder_ids is not None:
- processors.append(ForceTokensLogitsProcessor(forced_decoder_ids))
+ begin_index = (
+ begin_index
+ if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
+ else begin_index + 1
+ )
+ if generation_config.forced_decoder_ids is not None:
+ # generation starts after the last token that is forced
+ begin_index += generation_config.forced_decoder_ids[-1][0]
+ processors.append(
+ SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
+ )
+ if generation_config.forced_decoder_ids is not None:
+ processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
- if renormalize_logits is True:
+ if generation_config.renormalize_logits is True:
processors.append(LogitNormalization())
return processors
def _get_stopping_criteria(
- self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList]
+ self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList]
) -> StoppingCriteriaList:
criteria = StoppingCriteriaList()
- if max_length is not None:
- criteria.append(MaxLengthCriteria(max_length=max_length))
- if max_time is not None:
- criteria.append(MaxTimeCriteria(max_time=max_time))
+ if generation_config.max_length is not None:
+ criteria.append(MaxLengthCriteria(max_length=generation_config.max_length))
+ if generation_config.max_time is not None:
+ criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
return criteria
@@ -951,7 +931,7 @@ def _validate_model_class(self):
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
- if not hasattr(self, "prepare_inputs_for_generation"):
+ if not self.can_generate():
generate_compatible_mappings = [
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
@@ -999,81 +979,27 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
def generate(
self,
inputs: Optional[torch.Tensor] = None,
- max_length: Optional[int] = None,
- min_length: Optional[int] = None,
- do_sample: Optional[bool] = None,
- early_stopping: Optional[bool] = None,
- num_beams: Optional[int] = None,
- temperature: Optional[float] = None,
- penalty_alpha: Optional[float] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- typical_p: Optional[float] = None,
- repetition_penalty: Optional[float] = None,
- bad_words_ids: Optional[Iterable[int]] = None,
- force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
- bos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[int] = None,
- length_penalty: Optional[float] = None,
- no_repeat_ngram_size: Optional[int] = None,
- encoder_no_repeat_ngram_size: Optional[int] = None,
- num_return_sequences: Optional[int] = None,
- max_time: Optional[float] = None,
- max_new_tokens: Optional[int] = None,
- decoder_start_token_id: Optional[int] = None,
- use_cache: Optional[bool] = None,
- num_beam_groups: Optional[int] = None,
- diversity_penalty: Optional[float] = None,
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
- renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
- constraints: Optional[List[Constraint]] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- forced_bos_token_id: Optional[int] = None,
- forced_eos_token_id: Optional[int] = None,
- remove_invalid_values: Optional[bool] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = False,
- exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
- suppress_tokens: Optional[List[int]] = None,
- begin_suppress_tokens: Optional[List[int]] = None,
- forced_decoder_ids: Optional[List[List[int]]] = None,
- **model_kwargs,
+ **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
- Generates sequences of token ids for models with a language modeling head. The method supports the following
- generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
-
- - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
- `do_sample=False`.
- - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
- and `top_k>1`
- - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
- `do_sample=True`.
- - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
- `do_sample=False`.
- - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
- `num_beams>1` and `do_sample=True`.
- - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
- `num_beams>1` and `num_beam_groups>1`.
- - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
- `constraints!=None` or `force_words_ids!=None`.
+ Generates sequences of token ids for models with a language modeling head.
- Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
- defined in the model's config (`config.json`) which in turn defaults to the
- [`~modeling_utils.PretrainedConfig`] of the model.
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` attribute by passing the
+ corresponding parameter to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
-
+ For a complete overview of generate, check the [following
+ guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
- Most of these parameters are explained in more detail in [this blog
- post](https://huggingface.co/blog/how-to-generate).
+
Parameters:
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
@@ -1081,81 +1007,21 @@ def generate(
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
- max_length (`int`, *optional*, defaults to `model.config.max_length`):
- The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
- `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
- the prompt.
- max_new_tokens (`int`, *optional*):
- The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
- min_length (`int`, *optional*, defaults to `model.config.min_length` or 10 if the config does not set any value):
- The minimum length of the sequence to be generated.
- do_sample (`bool`, *optional*, defaults to `model.config.do_sample` or `False` if the config does not set any value):
- Whether or not to use sampling ; use greedy decoding otherwise.
- early_stopping (`bool`, *optional*, defaults to `False`):
- Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
- num_beams (`int`, *optional*, defaults to `model.config.num_beams` or 1 if the config does not set any value):
- Number of beams for beam search. 1 means no beam search.
- temperature (`float`, *optional*, defaults to `model.config.temperature` or 1.0 if the config does not set any value):
- The value used to module the next token probabilities.
- penalty_alpha (`float`, *optional*, defaults to `model.config.penalty_alpha` or None if the config does not set any value):
- The values balance the model confidence and the degeneration penalty in contrastive search decoding.
- top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value):
- The number of highest probability vocabulary tokens to keep for top-k-filtering.
- top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value):
- If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
- `top_p` or higher are kept for generation.
- typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value):
- The amount of probability mass from the original distribution to be considered in typical decoding. If
- set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
- repetition_penalty (`float`, *optional*, defaults to `model.config.repetition_penalty` or 1.0 if the config does not set any value):
- The parameter for repetition penalty. 1.0 means no penalty. See [this
- paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
- pad_token_id (`int`, *optional*, defaults to `model.config.pad_token_id`):
- The id of the *padding* token.
- bos_token_id (`int`, *optional*, defaults to `model.config.bos_token_id`):
- The id of the *beginning-of-sequence* token.
- eos_token_id (`int`, *optional*, defaults to `model.config.eos_token_id`):
- The id of the *end-of-sequence* token.
- length_penalty (`float`, *optional*, defaults to `model.config.length_penalty` or 1.0 if the config does not set any value):
- Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
- to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
- the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
- while `length_penalty` < 0.0 encourages shorter sequences.
- no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.no_repeat_ngram_size` or 0 if the config does not set any value):
- If set to int > 0, all ngrams of that size can only occur once.
- encoder_no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.encoder_no_repeat_ngram_size` or 0 if the config does not set any value):
- If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
- `decoder_input_ids`.
- bad_words_ids(`List[List[int]]`, *optional*, defaults to `model.config.bad_words_ids`):
- List of token ids that are not allowed to be generated. In order to get the token ids of the words that
- should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
- add_special_tokens=False).input_ids`.
- force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
- List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
- list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
- this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
- where one can allow different forms of each word.
- num_return_sequences(`int`, *optional*, defaults to `model.config.num_return_sequences` or 1 if the config does not set any value):
- The number of independently computed returned sequences for each element in the batch.
- max_time(`float`, *optional*):
- The maximum amount of time you allow the computation to run for in seconds. generation will still
- finish the current pass after allocated time has been passed.
- attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
- that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape
- as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask)
- decoder_start_token_id (`int`, *optional*):
- If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should use the past last key/values attentions (if applicable to the model) to
- speed up decoding.
- num_beam_groups (`int`, *optional*, defaults to `model.config.num_beam_groups` or 1 if the config does not set any value):
- Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
- beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
- diversity_penalty (`float`, *optional*, defaults to `model.config.diversity_penalty` or 0.0 if the config does not set any value):
- This value is subtracted from a beam's score if it generates a token same as any beam from other group
- at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
- enabled.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config, an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criterion is passed that is already created with the arguments or a
+ generation config, an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
@@ -1163,60 +1029,12 @@ def generate(
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
- logits_processor (`LogitsProcessorList`, *optional*):
- Custom logits processors that complement the default logits processors built from arguments and a
- model's config. If a logit processor is passed that is already created with the arguments or a model's
- config an error is thrown. This feature is intended for advanced users.
- renormalize_logits (`bool`, *optional*, defaults to `False`):
- Whether to renormalize the logits after applying all the logits processors or warpers (including the
- custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
- score logits are normalized but some logit processors or warpers break the normalization.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- Custom stopping criteria that complement the default stopping criteria built from arguments and a
- model's config. If a stopping criteria is passed that is already created with the arguments or a
- model's config an error is thrown. This feature is intended for advanced users.
- constraints (`List[Constraint]`, *optional*):
- Custom constraints that can be added to the generation to ensure that the output will contain the use
- of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
- output_attentions (`bool`, *optional*, defaults to `model.config.output_attentions` or `False` if the config does not set any value):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more details.
- output_hidden_states (`bool`, *optional*, defaults to `model.config.output_hidden_states` or `False` if the config does not set any value):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more details.
- output_scores (`bool`, *optional*, defaults to `model.config.output_scores` or `False` if the config does not set any value):
- Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- return_dict_in_generate (`bool`, *optional*, defaults to `model.config.return_dict_in_generate` or `False` if the config does not set any value):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
- The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
- for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
- the target language token.
- forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`):
- The id of the token to force as the last generated token when `max_length` is reached.
- remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
- Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to
- crash. Note that using `remove_invalid_values` can slow down generation.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
- exponential_decay_length_penalty (`tuple(int, float)`, *optional*, defaults to `model.config.exponential_decay_length_penalty`):
- This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
- generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates
- where penalty starts and `decay_factor` represents the factor of exponential decay
- suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`):
- A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set
- their log probs to `-inf` so that they are not sampled.
- begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
- A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
- logit processor will set their log probs to `-inf` so that they are not sampled.
- forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
- A list of pairs of integers which indicates a mapping from generation indices to token indices that
- will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
- be a token of index 123.
- model_kwargs:
- Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
- is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
- should be prefixed with *decoder_*.
+ kwargs:
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
@@ -1240,7 +1058,7 @@ def generate(
Examples:
- Greedy Decoding:
+ Greedy decoding, using the default generation configuration and ad hoc modifications:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -1251,16 +1069,16 @@ def generate(
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
- >>> # generate up to 30 tokens
+ >>> # Generate up to 30 tokens
>>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
```
- Multinomial Sampling:
+ Multinomial sampling, modifying an existing generation configuration:
```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -1269,17 +1087,20 @@ def generate(
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
- >>> # sample up to 30 tokens
+ >>> # Sample up to 30 tokens
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
- >>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
+ >>> generation_config = GenerationConfig.from_pretrained("gpt2")
+ >>> generation_config.max_length = 30
+ >>> generation_config.do_sample = True
+ >>> outputs = model.generate(input_ids, generation_config=generation_config)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
```
- Beam-search decoding:
+ Beam-search decoding, using a freshly initialized generation configuration:
```python
- >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
@@ -1287,75 +1108,86 @@ def generate(
>>> sentence = "Paris is one of the densest populated areas in Europe."
>>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids
- >>> outputs = model.generate(input_ids, num_beams=5)
+ >>> generation_config = GenerationConfig(
+ ... max_length=64,
+ ... num_beams=5,
+ ... bos_token_id=0,
+ ... eos_token_id=0,
+ ... decoder_start_token_id=58100,
+ ... pad_token_id=58100,
+ ... bad_words_ids=[[58100]],
+ ... )
+ >>> outputs = model.generate(input_ids, generation_config=generation_config)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
- # 0. Validate the `.generate()` call
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
+
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+ if generation_config is None:
+ # legacy: users may modify the model configuration to control generation -- update the generation config
+ # model attribute accordingly, if it was created from the model config
+ if self.generation_config._from_model_config:
+ new_generation_config = GenerationConfig.from_model_config(self.config)
+ if new_generation_config != self.generation_config:
+ warnings.warn(
+ "You have modified the pretrained model configuration to control generation. This is a"
+ " deprecated strategy to control generation and will be removed soon, in a future version."
+ " Please use a generation configuration file (see"
+ " https://huggingface.co/docs/transformers/main_classes/text_generation)"
+ )
+ self.generation_config = new_generation_config
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
self._validate_model_kwargs(model_kwargs.copy())
- # 1. Set generation parameters if not already defined
- bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
- num_beams = num_beams if num_beams is not None else self.config.num_beams
- length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
- early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
- num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
- do_sample = do_sample if do_sample is not None else self.config.do_sample
- num_return_sequences = (
- num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
- )
+ # 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-
- if eos_token_id is None and hasattr(self.config, "decoder"):
- eos_token_id = self.config.decoder.eos_token_id
-
- if pad_token_id is None and eos_token_id is not None:
+ if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
- logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
- pad_token_id = eos_token_id
-
- output_scores = output_scores if output_scores is not None else self.config.output_scores
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict_in_generate = (
- return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
- )
+ logger.warning(
+ f"Setting `pad_token_id` to `eos_token_id`:{generation_config.eos_token_id} for open-end generation."
+ )
+ generation_config.pad_token_id = generation_config.eos_token_id
- # 2. Define model inputs
+ # 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
- inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
batch_size = inputs_tensor.shape[0]
- # 3. Define other model kwargs
- model_kwargs["output_attentions"] = output_attentions
- model_kwargs["output_hidden_states"] = output_hidden_states
- model_kwargs["use_cache"] = use_cache
+ # 4. Define other model kwargs
+ model_kwargs["output_attentions"] = generation_config.output_attentions
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+ model_kwargs["use_cache"] = generation_config.use_cache
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- inputs_tensor, pad_token_id, eos_token_id
+ inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)
# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
- if pad_token_id is not None and torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0:
+ if (
+ generation_config.pad_token_id is not None
+ and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+ ):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
@@ -1368,12 +1200,12 @@ def generate(
inputs_tensor, model_kwargs, model_input_name
)
- # 4. Prepare `input_ids` which will be used for auto-regressive generation
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
- decoder_start_token_id=decoder_start_token_id,
- bos_token_id=bos_token_id,
+ decoder_start_token_id=generation_config.decoder_start_token_id,
+ bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
device=inputs_tensor.device,
)
@@ -1381,87 +1213,91 @@ def generate(
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
- # 5. Prepare `max_length` depending on other stopping criteria.
+ # 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
- if max_length is None and max_new_tokens is None:
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length == 20
+ if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
- "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to "
- f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
- "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
- "using `max_new_tokens` to control the maximum length of the generation.",
+ "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
+ f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
+ " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
- elif max_length is None and max_new_tokens is not None:
- max_length = max_new_tokens + input_ids_seq_length
- elif max_length is not None and max_new_tokens is not None:
+ elif has_default_max_length and generation_config.max_new_tokens is not None:
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+ elif not has_default_max_length and generation_config.max_new_tokens is not None:
raise ValueError(
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
" limit to the generated output length. Remove one of those arguments. Please refer to the"
" documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
- # default to config if still None
- max_length = max_length if max_length is not None else self.config.max_length
- min_length = min_length if min_length is not None else self.config.min_length
- if min_length is not None and min_length > max_length:
+ if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
- f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
- f"length ({max_length})"
+ f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
+ f" the maximum length ({generation_config.max_length})"
)
- if input_ids_seq_length >= max_length:
+ if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
- f" {max_length}. This can lead to unexpected behavior. You should consider increasing "
- "`max_new_tokens`."
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
)
- # 6. determine generation mode
- is_constraint_gen_mode = constraints is not None or force_words_ids is not None
+ # 7. determine generation mode
+ is_constraint_gen_mode = (
+ generation_config.constraints is not None or generation_config.force_words_ids is not None
+ )
is_contrastive_search_gen_mode = (
- top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0
+ generation_config.top_k is not None
+ and generation_config.top_k > 1
+ and generation_config.do_sample is False
+ and generation_config.penalty_alpha is not None
+ and generation_config.penalty_alpha > 0
)
is_greedy_gen_mode = (
- (num_beams == 1)
- and (num_beam_groups == 1)
- and do_sample is False
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_sample_gen_mode = (
- (num_beams == 1)
- and (num_beam_groups == 1)
- and do_sample is True
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_gen_mode = (
- (num_beams > 1)
- and (num_beam_groups == 1)
- and do_sample is False
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_sample_gen_mode = (
- (num_beams > 1)
- and (num_beam_groups == 1)
- and do_sample is True
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_group_beam_gen_mode = (
- (num_beams > 1)
- and (num_beam_groups > 1)
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups > 1)
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
- if num_beam_groups > num_beams:
+ if generation_config.num_beam_groups > generation_config.num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
- if is_group_beam_gen_mode and do_sample is True:
+ if is_group_beam_gen_mode and generation_config.do_sample is True:
raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)
@@ -1477,269 +1313,244 @@ def generate(
UserWarning,
)
- # 7. prepare distribution pre_processing samplers
+ # 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
- repetition_penalty=repetition_penalty,
- no_repeat_ngram_size=no_repeat_ngram_size,
- encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
+ generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
- bad_words_ids=bad_words_ids,
- min_length=min_length,
- max_length=max_length,
- eos_token_id=eos_token_id,
- forced_bos_token_id=forced_bos_token_id,
- forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
- num_beams=num_beams,
- num_beam_groups=num_beam_groups,
- diversity_penalty=diversity_penalty,
- remove_invalid_values=remove_invalid_values,
- exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
- renormalize_logits=renormalize_logits,
- suppress_tokens=suppress_tokens,
- begin_suppress_tokens=begin_suppress_tokens,
- forced_decoder_ids=forced_decoder_ids,
)
- # 8. prepare stopping criteria
+ # 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
- max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
+ generation_config=generation_config, stopping_criteria=stopping_criteria
)
- # 9. go into different generation modes
+ # 10. go into different generation modes
if is_greedy_gen_mode:
- if num_return_sequences > 1:
+ if generation_config.num_return_sequences > 1:
raise ValueError(
- f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+ " greedy search."
)
- # 10. run greedy search
+ # 11. run greedy search
return self.greedy_search(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_contrastive_search_gen_mode:
- if num_return_sequences > 1:
+ if generation_config.num_return_sequences > 1:
raise ValueError(
- f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search."
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+ " contrastive search."
)
return self.contrastive_search(
input_ids,
- top_k=top_k,
- penalty_alpha=penalty_alpha,
+ top_k=generation_config.top_k,
+ penalty_alpha=generation_config.penalty_alpha,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_sample_gen_mode:
- # 10. prepare logits warper
- logits_warper = self._get_logits_warper(
- top_k=top_k,
- top_p=top_p,
- typical_p=typical_p,
- temperature=temperature,
- num_beams=num_beams,
- renormalize_logits=renormalize_logits,
- )
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
- # 11. expand input_ids with `num_return_sequences` additional sequences per batch
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
- expand_size=num_return_sequences,
+ expand_size=generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
- # 12. run sample
+ # 13. run sample
return self.sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_beam_gen_mode:
- if num_return_sequences > num_beams:
+ if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
- # 10. prepare beam search scorer
+ # 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
- num_beams=num_beams,
+ num_beams=generation_config.num_beams,
device=inputs_tensor.device,
- length_penalty=length_penalty,
- do_early_stopping=early_stopping,
- num_beam_hyps_to_keep=num_return_sequences,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
)
- # 11. interleave input_ids with `num_beams` additional sequences per batch
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
- expand_size=num_beams,
+ expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
- # 12. run beam search
+ # 13. run beam search
return self.beam_search(
input_ids,
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_beam_sample_gen_mode:
- # 10. prepare logits warper
- logits_warper = self._get_logits_warper(
- top_k=top_k,
- top_p=top_p,
- typical_p=typical_p,
- temperature=temperature,
- num_beams=num_beams,
- renormalize_logits=renormalize_logits,
- )
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
- # 11. prepare beam search scorer
+ # 12. prepare beam search scorer
beam_scorer = BeamSearchScorer(
- batch_size=batch_size * num_return_sequences,
- num_beams=num_beams,
+ batch_size=batch_size * generation_config.num_return_sequences,
+ num_beams=generation_config.num_beams,
device=inputs_tensor.device,
- length_penalty=length_penalty,
- do_early_stopping=early_stopping,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
)
- # 12. interleave input_ids with `num_beams` additional sequences per batch
+ # 13. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
- expand_size=num_beams * num_return_sequences,
+ expand_size=generation_config.num_beams * generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
- # 13. run beam sample
+ # 14. run beam sample
return self.beam_sample(
input_ids,
beam_scorer,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_group_beam_gen_mode:
- if num_return_sequences > num_beams:
+ if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
- if num_beams % num_beam_groups != 0:
+ if generation_config.num_beams % generation_config.num_beam_groups != 0:
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
- if typical_p is not None:
+ has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
+ if not has_default_typical_p:
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
- # 10. prepare beam search scorer
+ # 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
- num_beams=num_beams,
+ num_beams=generation_config.num_beams,
max_length=stopping_criteria.max_length,
device=inputs_tensor.device,
- length_penalty=length_penalty,
- do_early_stopping=early_stopping,
- num_beam_hyps_to_keep=num_return_sequences,
- num_beam_groups=num_beam_groups,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ num_beam_groups=generation_config.num_beam_groups,
)
- # 11. interleave input_ids with `num_beams` additional sequences per batch
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
- expand_size=num_beams,
+ expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
- # 12. run beam search
+ # 13. run beam search
return self.group_beam_search(
input_ids,
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_constraint_gen_mode:
- if num_return_sequences > num_beams:
+ if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
- if num_beams <= 1:
+ if generation_config.num_beams <= 1:
raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
- if do_sample:
+ if generation_config.do_sample:
raise ValueError("`do_sample` needs to be false for constrained generation.")
- if num_beam_groups is not None and num_beam_groups > 1:
+ if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
final_constraints = []
- if constraints is not None:
- final_constraints = constraints
+ if generation_config.constraints is not None:
+ final_constraints = generation_config.constraints
- if force_words_ids is not None:
+ if generation_config.force_words_ids is not None:
def typeerror():
raise ValueError(
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
- f"of positive integers, but is {force_words_ids}."
+ f"of positive integers, but is {generation_config.force_words_ids}."
)
- if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
+ if (
+ not isinstance(generation_config.force_words_ids, list)
+ or len(generation_config.force_words_ids) == 0
+ ):
typeerror()
- for word_ids in force_words_ids:
+ for word_ids in generation_config.force_words_ids:
if isinstance(word_ids[0], list):
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
@@ -1761,33 +1572,33 @@ def typeerror():
constraint = PhrasalConstraint(word_ids)
final_constraints.append(constraint)
- # 10. prepare beam search scorer
+ # 11. prepare beam search scorer
constrained_beam_scorer = ConstrainedBeamSearchScorer(
constraints=final_constraints,
batch_size=batch_size,
- num_beams=num_beams,
+ num_beams=generation_config.num_beams,
device=inputs_tensor.device,
- length_penalty=length_penalty,
- do_early_stopping=early_stopping,
- num_beam_hyps_to_keep=num_return_sequences,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
)
- # 11. interleave input_ids with `num_beams` additional sequences per batch
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
- expand_size=num_beams,
+ expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
- # 12. run beam search
+ # 13. run beam search
return self.constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
@@ -1884,15 +1695,19 @@ def contrastive_search(
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
- output_scores = output_scores if output_scores is not None else self.config.output_scores
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
+ )
output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
- return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
@@ -2239,15 +2054,19 @@ def greedy_search(
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
- output_scores = output_scores if output_scores is not None else self.config.output_scores
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
+ )
output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
- return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
@@ -2487,15 +2306,19 @@ def sample(
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
- output_scores = output_scores if output_scores is not None else self.config.output_scores
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
+ )
output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
- return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
@@ -2739,15 +2562,19 @@ def beam_search(
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
if len(stopping_criteria) == 0:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
- output_scores = output_scores if output_scores is not None else self.config.output_scores
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
+ )
output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
- return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
)
batch_size = len(beam_scorer._beam_hyps)
@@ -3059,15 +2886,19 @@ def beam_sample(
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
- output_scores = output_scores if output_scores is not None else self.config.output_scores
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
+ )
output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
- return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
)
batch_size = len(beam_scorer._beam_hyps)
@@ -3368,15 +3199,19 @@ def group_beam_search(
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
- output_scores = output_scores if output_scores is not None else self.config.output_scores
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
+ )
output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
- return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
)
batch_size = len(beam_scorer._beam_hyps)
@@ -3737,15 +3572,19 @@ def constrained_beam_search(
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
if len(stopping_criteria) == 0:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
- output_scores = output_scores if output_scores is not None else self.config.output_scores
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
+ )
output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
- return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
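With the `generate()` hunks above applied, every decoding branch (greedy, contrastive, sample, beam, group beam, constrained) reads its parameters from a single `generation_config` object, and keyword arguments passed to `generate` that match `GenerationConfig` attributes override the copy used for that call. A minimal sketch of the resulting call pattern (the `gpt2` checkpoint is used purely for illustration and is not part of this patch):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Generation parameters now live in", return_tensors="pt")

# A reusable parametrization: beam search with a bounded number of new tokens.
beam_config = GenerationConfig(num_beams=4, max_new_tokens=32, pad_token_id=model.config.eos_token_id)
beam_outputs = model.generate(**inputs, generation_config=beam_config)

# Ad hoc kwargs matching GenerationConfig attributes override the passed config,
# so this call goes down the sampling branch instead of the beam search branch.
sample_outputs = model.generate(**inputs, generation_config=beam_config, num_beams=1, do_sample=True)

print(tokenizer.decode(beam_outputs[0], skip_special_tokens=True))
print(tokenizer.decode(sample_outputs[0], skip_special_tokens=True))
```

The override semantics are the same merge visible in the RAG hunk further below (`generation_config.update(**kwargs)` applied to a deep copy of the base config).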
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 48e437fd768478..6780f9b19f1431 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -39,7 +39,7 @@
from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
-from .generation import GenerationMixin
+from .generation import GenerationConfig, GenerationMixin
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
@@ -1024,6 +1024,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path
self.warnings_issued = {}
+ self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
def post_init(self):
"""
@@ -1106,6 +1107,18 @@ def base_model(self) -> nn.Module:
"""
return getattr(self, self.base_model_prefix, self)
+ def can_generate(self) -> bool:
+ """
+ Returns whether this model can generate sequences with `.generate()`.
+
+ Returns:
+ `bool`: Whether this model can generate sequences with `.generate()`.
+ """
+ # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
+ if "GenerationMixin" in str(self.prepare_inputs_for_generation):
+ return False
+ return True
+
def get_input_embeddings(self) -> nn.Module:
"""
Returns the model's input embeddings.
@@ -2477,6 +2490,29 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
+ # If it is a model with generation capabilities, attempt to load the generation config
+ if model.can_generate():
+ try:
+ model.generation_config = GenerationConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ _from_auto=from_auto_class,
+ _from_pipeline=from_pipeline,
+ **kwargs,
+ )
+ except OSError:
+ logger.info(
+ "Generation config file not found, using a generation config created from the model config."
+ )
+ pass
+
# Dispatch model with hooks on all devices if necessary
if device_map is not None:
dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
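The `modeling_utils.py` hunks above attach a `generation_config` attribute to every loaded model: it is created from the model config at init time when the model can generate, and `from_pretrained` upgrades it from `generation_config.json` when the checkpoint ships one (the `OSError` branch keeps the config-derived default otherwise). A hedged sketch of the observable behavior, with illustrative checkpoint names:

```python
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification

# Generative architecture: prepare_inputs_for_generation is overridden, so can_generate()
# is True and generation_config is populated (from generation_config.json if the checkpoint
# ships one, otherwise via GenerationConfig.from_model_config on the model config).
gen_model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
assert gen_model.can_generate()
assert gen_model.generation_config is not None

# Non-generative architecture: can_generate() is False and generation_config stays None.
clf_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")  # illustrative
assert not clf_model.can_generate()
assert clf_model.generation_config is None
```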
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index c4b102a204f683..461e06ec4f7535 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""RAG model implementation."""
+import copy
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
@@ -21,7 +22,7 @@
from torch import nn
from ...configuration_utils import PretrainedConfig
-from ...generation import BeamSearchScorer, LogitsProcessorList, StoppingCriteriaList
+from ...generation import BeamSearchScorer, GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@@ -1384,33 +1385,12 @@ def generate(
context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores: Optional[torch.FloatTensor] = None,
- max_length: Optional[int] = None,
- min_length: Optional[int] = None,
- early_stopping: Optional[bool] = None,
- use_cache: Optional[bool] = None,
- num_beams: Optional[int] = None,
- num_beam_groups: Optional[int] = None,
- diversity_penalty: Optional[float] = None,
- bos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[int] = None,
- length_penalty: Optional[float] = None,
- no_repeat_ngram_size: Optional[int] = None,
- encoder_no_repeat_ngram_size: Optional[int] = None,
- repetition_penalty: Optional[float] = None,
- bad_words_ids: Optional[List[List[int]]] = None,
- num_return_sequences: Optional[int] = None,
- decoder_start_token_id: Optional[int] = None,
n_docs: Optional[int] = None,
+ generation_config: Optional[GenerationConfig] = None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
- renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
- forced_bos_token_id: Optional[int] = None,
- forced_eos_token_id: Optional[int] = None,
- remove_invalid_values: Optional[bool] = None,
- exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
- **model_kwargs
+ **kwargs
) -> torch.LongTensor:
"""
Implements RAG token decoding.
@@ -1444,51 +1424,15 @@ def generate(
If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
- max_length (`int`, *optional*, defaults to 20):
- The maximum length of the sequence to be generated.
- min_length (`int`, *optional*, defaults to 10):
- The minimum length of the sequence to be generated.
- early_stopping (`bool`, *optional*, defaults to `False`):
- Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or
- not.
- use_cache: (`bool`, *optional*, defaults to `True`):
- Whether or not the model should use the past last key/values attentions (if applicable to the model) to
- speed up decoding.
- pad_token_id (`int`, *optional*):
- The id of the *padding* token.
- bos_token_id (`int`, *optional*):
- The id of the *beginning-of-sequence* token.
- eos_token_id (`int`, *optional*):
- The id of the *end-of-sequence* token.
- length_penalty (`float`, *optional*, defaults to 1.0):
- Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
- to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
- the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
- while `length_penalty` < 0.0 encourages shorter sequences.
- no_repeat_ngram_size (`int`, *optional*, defaults to 0):
- If set to int > 0, all ngrams of that size can only occur once.
- encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
- If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
- `decoder_input_ids`.
- bad_words_ids(`List[int]`, *optional*):
- List of token ids that are not allowed to be generated. In order to get the tokens of the words that
- should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
- num_beams (`int`, *optional*, defaults to 1):
- Number of beams for beam search. 1 means no beam search.
- num_beam_groups (`int`, *optional*, defaults to 1):
- Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
- beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
- diversity_penalty (`float`, *optional*, defaults to 0.0):
- This value is subtracted from a beam's score if it generates a token same as any beam from other group
- at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
- enabled.
- num_return_sequences(`int`, *optional*, defaults to 1):
- The number of independently computed returned sequences for each element in the batch. Note that this
- is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
- we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
- encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
n_docs (`int`, *optional*, defaults to `config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
prefix_allowed_tokens_fn: (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments `inputs_ids` and the batch ID
@@ -1497,53 +1441,30 @@ def generate(
constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
logits_processor (`LogitsProcessorList`, *optional*):
- Custom logits processors that complement the default logits processors built from arguments and a
- model's config. If a logit processor is passed that is already created with the arguments or a model's
- config an error is thrown.
+ Custom logits processors that complement the default logits processors built from arguments and a
+ model's config. If a logit processor is passed that is already created with the arguments or a model's
+ config an error is thrown.
stopping_criteria (`StoppingCriteriaList`, *optional*):
- Custom stopping criteria that complement the default stopping criteria built from arguments and a
- model's config. If a stopping criteria is passed that is already created with the arguments or a
- model's config an error is thrown.
- forced_bos_token_id (`int`, *optional*):
- The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
- for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
- the target language token.
- forced_eos_token_id (`int`, *optional*):
- The id of the token to force as the last generated token when `max_length` is reached.
- remove_invalid_values (`bool`, *optional*):
- Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to
- crash. Note that using `remove_invalid_values` can slow down generation.
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ model's config. If a stopping criteria is passed that is already created with the arguments or a
+ model's config an error is thrown.
+ kwargs:
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model.
Return:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches
finished early due to the `eos_token_id`.
"""
+ # Handle `generation_config` and kwargs that might update it
+ if generation_config is None:
+ generation_config = self.generation_config
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+
# set default parameters
n_docs = n_docs if n_docs is not None else self.config.n_docs
- num_beams = num_beams if num_beams is not None else self.config.num_beams
- num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
- max_length = max_length if max_length is not None else self.config.max_length
- num_return_sequences = (
- num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
- )
- bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- decoder_start_token_id = (
- decoder_start_token_id
- if decoder_start_token_id is not None
- else self.config.generator.decoder_start_token_id
- )
- remove_invalid_values = (
- remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
- )
- exponential_decay_length_penalty = (
- exponential_decay_length_penalty
- if exponential_decay_length_penalty is not None
- else self.config.exponential_decay_length_penalty
- )
# retrieve docs
if self.retriever is not None and context_input_ids is None:
@@ -1583,8 +1504,8 @@ def generate(
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
input_ids = torch.full(
- (batch_size * num_beams, 1),
- decoder_start_token_id,
+ (batch_size * generation_config.num_beams, 1),
+ generation_config.decoder_start_token_id,
dtype=torch.long,
device=next(self.parameters()).device,
)
@@ -1600,10 +1521,12 @@ def extend_enc_output(tensor, num_beams=None):
return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:])
# correctly extend last_hidden_state and attention mask
- context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
- encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams)
+ context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
+ encoder_outputs["last_hidden_state"] = extend_enc_output(
+ last_hidden_state, num_beams=generation_config.num_beams
+ )
- doc_scores = doc_scores.repeat_interleave(num_beams, dim=0)
+ doc_scores = doc_scores.repeat_interleave(generation_config.num_beams, dim=0)
# define start_len & additional parameters
model_kwargs["doc_scores"] = doc_scores
@@ -1612,64 +1535,51 @@ def extend_enc_output(tensor, num_beams=None):
model_kwargs["n_docs"] = n_docs
pre_processor = self._get_logits_processor(
- repetition_penalty=repetition_penalty,
- no_repeat_ngram_size=no_repeat_ngram_size,
- encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
+ generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=context_input_ids,
- bad_words_ids=bad_words_ids,
- min_length=min_length,
- max_length=max_length,
- eos_token_id=eos_token_id,
- forced_bos_token_id=forced_bos_token_id,
- forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
- num_beams=num_beams,
- num_beam_groups=num_beam_groups,
- diversity_penalty=diversity_penalty,
- remove_invalid_values=remove_invalid_values,
- exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
- renormalize_logits=renormalize_logits,
)
- if num_beams == 1:
- if num_return_sequences > 1:
+ if generation_config.num_beams == 1:
+ if generation_config.num_return_sequences > 1:
raise ValueError(
- f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+ " greedy search."
)
return self.greedy_search(
input_ids,
logits_processor=pre_processor,
- max_length=max_length,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
+ max_length=generation_config.max_length,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
**model_kwargs,
)
- elif num_beams > 1:
- length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
- early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
- if num_return_sequences > num_beams:
+ elif generation_config.num_beams > 1:
+ if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
- num_beams=num_beams,
+ num_beams=generation_config.num_beams,
device=self.device,
- length_penalty=length_penalty,
- do_early_stopping=early_stopping,
- num_beam_hyps_to_keep=num_return_sequences,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
)
return self.beam_search(
input_ids,
beam_scorer,
logits_processor=pre_processor,
- max_length=max_length,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
+ max_length=generation_config.max_length,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
**model_kwargs,
)
else:
- raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
+ raise ValueError(
+ f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
+ )
def get_input_embeddings(self):
return self.rag.generator.get_input_embeddings()
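The RAG changes above replace the long list of per-call generation arguments with `generation_config` plus `**kwargs`, and drop the old fallbacks to `self.config.generator.*` for the special token ids. A sketch of the new call pattern; the `facebook/rag-token-nq` checkpoint and dummy retrieval index are illustrative only, the retrieval extras (`datasets`, `faiss`) must be installed, and the token ids are pinned explicitly since the removed defaults no longer apply:

```python
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")  # illustrative checkpoint
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")

# kwargs matching GenerationConfig attributes are folded into a copy of model.generation_config;
# anything else is kept as a model kwarg and forwarded to the forward pass.
outputs = model.generate(
    input_ids=inputs["input_ids"],
    num_beams=2,
    max_length=25,
    decoder_start_token_id=model.config.generator.decoder_start_token_id,
    pad_token_id=model.config.generator.pad_token_id,
    eos_token_id=model.config.generator.eos_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```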
diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py
index 5cfe0995655f14..004720e110b973 100644
--- a/tests/generation/test_configuration_utils.py
+++ b/tests/generation/test_configuration_utils.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import tempfile
import unittest
from parameterized import parameterized
-from transformers.generation import GenerationConfig
+from transformers import AutoConfig, GenerationConfig
class LogitsProcessorTest(unittest.TestCase):
@@ -43,3 +44,33 @@ def test_save_load_config(self, config_name):
self.assertEqual(loaded_config.top_k, 50)
self.assertEqual(loaded_config.max_length, 20)
self.assertEqual(loaded_config.max_time, None)
+
+ def test_from_model_config(self):
+ model_config = AutoConfig.from_pretrained("gpt2")
+ generation_config_from_model = GenerationConfig.from_model_config(model_config)
+ default_generation_config = GenerationConfig()
+
+ # The generation config has loaded a few non-default parameters from the model config
+ self.assertNotEqual(generation_config_from_model, default_generation_config)
+
+ # One of those parameters is eos_token_id -- check if it matches
+ self.assertNotEqual(generation_config_from_model.eos_token_id, default_generation_config.eos_token_id)
+ self.assertEqual(generation_config_from_model.eos_token_id, model_config.eos_token_id)
+
+ def test_update(self):
+ generation_config = GenerationConfig()
+ update_kwargs = {
+ "max_new_tokens": 1024,
+ "foo": "bar",
+ }
+ update_kwargs_copy = copy.deepcopy(update_kwargs)
+ unused_kwargs = generation_config.update(**update_kwargs)
+
+ # update_kwargs was not modified (no side effects)
+ self.assertEqual(update_kwargs, update_kwargs_copy)
+
+ # update_kwargs was used to update the config on valid attributes
+ self.assertEqual(generation_config.max_new_tokens, 1024)
+
+ # `.update()` returns a dictionary of unused kwargs
+ self.assertEqual(unused_kwargs, {"foo": "bar"})
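The contract exercised by `test_update` is the one the refactored `generate` methods rely on internally (see the `copy.deepcopy` + `generation_config.update(**kwargs)` merge in the RAG hunk): recognized keys update the config in place and everything else is handed back as model kwargs. A minimal standalone sketch:

```python
import copy

from transformers import GenerationConfig

default_config = GenerationConfig(num_beams=4)

# generate() works on a copy, so the model-level default is never mutated.
generation_config = copy.deepcopy(default_config)
model_kwargs = generation_config.update(max_new_tokens=64, attention_mask=None)

assert generation_config.max_new_tokens == 64    # recognized attribute absorbed into the config
assert default_config.max_new_tokens is None     # the original default is untouched
assert model_kwargs == {"attention_mask": None}  # unrecognized kwargs are returned as model kwargs
```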