Generate: validate arguments #18218
```diff
@@ -572,14 +572,14 @@ def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_to
     @staticmethod
     def _expand_inputs_for_generation(
         input_ids: torch.LongTensor,
-        expand_size: int = 1,
+        num_return_sequences: int = 1,
         is_encoder_decoder: bool = False,
         attention_mask: Optional[torch.LongTensor] = None,
         encoder_outputs: Optional[ModelOutput] = None,
         **model_kwargs,
     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
         expanded_return_idx = (
-            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
+            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_return_sequences).view(-1).to(input_ids.device)
         )
         input_ids = input_ids.index_select(0, expanded_return_idx)
```

Review comment on the `expand_size` → `num_return_sequences` rename: Renamed this arg as it is a private method, it is more readable, and it is useful for name matching (as you'll see below).
```diff
@@ -841,6 +841,80 @@ def compute_transition_beam_scores(

         return transition_scores

+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+        # Excludes arguments that are handled before calling any model function
+        if self.config.is_encoder_decoder:
+            for key in ["decoder_input_ids"]:
+                model_kwargs.pop(key, None)
```

Comment on lines +846 to +849: If …
```diff
+        # Transfo_XL does not have "attention_mask" as an argument, and it is harmless (it is being passed in the
+        # tests, though)
+        if "transfoxl" in str(self).lower():
+            model_kwargs.pop("attention_mask", None)
```

Comment on lines +851 to +854: Ugly ad hoc exception, but the alternative would be to rewrite most GenerationMixin tests for this particular model (most pass …
```diff
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.forward).parameters)
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
```
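The check above relies only on `inspect.signature`, so the pattern is easy to reproduce outside the library. Below is a minimal, standalone sketch of the same idea; `fake_forward` and `check_kwargs` are hypothetical names used for illustration, not part of transformers:

```python
import inspect


def fake_forward(input_ids=None, attention_mask=None, use_cache=None):
    """Stand-in for a model's forward pass."""


def check_kwargs(fn, **kwargs):
    # Accept exactly the parameter names in `fn`'s signature; flag the rest
    accepted = set(inspect.signature(fn).parameters)
    unused = [k for k, v in kwargs.items() if v is not None and k not in accepted]
    if unused:
        raise ValueError(f"The following kwargs are not used: {unused}")


try:
    check_kwargs(fake_forward, input_ids=[1, 2], attention_masks=[1, 1])  # note the typo
except ValueError as e:
    print(e)  # The following kwargs are not used: ['attention_masks']
```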
Comment on lines +866 to +870: Here is an example of the output for:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
prompt = tokenizer(["hello world"], return_tensors="pt")
model.generate(**prompt, do_samples=True, foo="bar")
```
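Given the error template in the diff, this call would presumably fail with something like the following, since both the typo `do_samples` and the stray `foo` fall through into `model_kwargs`:

```
ValueError: The following `model_kwargs` are not used by the model: ['do_samples', 'foo'] (note: typos in the generate arguments will also show up in this list)
```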
```diff
+    def _validate_generation_inputs(
+        self,
+        generation_inputs: Dict[str, Any],
+        generation_method_name: str,
+        supporting_objects: List[Callable],
```

Comment on lines +875 to +876: Each generation submethod has a set of preparation functions/classes (hence …
```diff
+    ):
+        """
+        Validates the generation inputs for the generation submethod. If there are arguments that were passed but will
+        not be used by the generation method or its supporting objects, an informative exception will be thrown. A
+        supporting object is a class or function that precedes the generation submethod itself.
+        """
+        # Excludes arguments that are handled before calling the submethods
+        unused_keys = ["self"]
+        keys_processed_before_method = [
+            "inputs",
+            "bos_token_id",
+            "max_new_tokens",
+            "decoder_start_token_id",
+            "do_sample",
+            "model_kwargs",
+        ]
+        keys_passed_to_model_kwargs = ["use_cache", "output_attentions", "output_hidden_states"]
+        for key in unused_keys + keys_processed_before_method + keys_passed_to_model_kwargs:
+            generation_inputs.pop(key)
```
```diff
+
+        # Validates that the remaining arguments are used by the generation submethod
+        unused_args = []
+        generation_method = getattr(self, generation_method_name)
+        generation_method_args = set(inspect.signature(generation_method).parameters)
+        for supporting_object in supporting_objects:
+            generation_method_args |= set(inspect.signature(supporting_object).parameters)
+        # Ad hoc replacement of supporting object argument to match the corresponding generation argument
+        if "do_early_stopping" in generation_method_args:
+            generation_method_args.remove("do_early_stopping")
+            generation_method_args.add("early_stopping")
```

Comment on lines +904 to +906: The …
```diff
+        for key, value in generation_inputs.items():
+            if value is not None and key not in generation_method_args:
+                unused_args.append(key)
+
+        if unused_args:
+            raise ValueError(
+                f"From the generation arguments, `{generation_method_name}` was triggered. The following arguments are"
+                f" not used by `{generation_method_name}`: {unused_args}. Please remove them from the generation"
+                " arguments or check the documentation for more information."
+            )
```
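The set of accepted argument names is thus a union over signatures. Here is a small sketch of that union with hypothetical stand-ins (`sub_method` and `Scorer` are illustrative only; note that `inspect.signature` on a class inspects its `__init__`, which is how `BeamSearchScorer` contributes its parameters):

```python
import inspect


def sub_method(input_ids, max_length=None, pad_token_id=None):
    """Stand-in for a generation submethod such as beam_search."""


class Scorer:
    """Stand-in for a supporting class such as BeamSearchScorer."""

    def __init__(self, num_beams=None, do_early_stopping=None):
        pass


accepted = set(inspect.signature(sub_method).parameters)
accepted |= set(inspect.signature(Scorer).parameters)

# Mirror the ad hoc aliasing above: the scorer takes `do_early_stopping`,
# but the user-facing generate argument is `early_stopping`
if "do_early_stopping" in accepted:
    accepted.remove("do_early_stopping")
    accepted.add("early_stopping")

print(sorted(accepted))
# ['early_stopping', 'input_ids', 'max_length', 'num_beams', 'pad_token_id']
```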
Comment on lines +912 to +916: Here is an example of the output for:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
prompt = tokenizer(["hello world"], return_tensors="pt")
model.generate(**prompt, num_return_sequences=2, temperature=2.0)
```
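Since neither `do_sample` nor `num_beams` is set here, the call routes to greedy search, so the new check would presumably raise a `ValueError` along the lines of "From the generation arguments, `greedy_search` was triggered. The following arguments are not used by `greedy_search`: …", flagging both `temperature` and `num_return_sequences`.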
```diff

     @torch.no_grad()
     def generate(
         self,
@@ -1119,6 +1193,10 @@ def generate(
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
+        # 0. Store generation inputs for later submethod validation and validate model kwargs
+        generation_inputs = locals().copy()
+        self._validate_model_kwargs(model_kwargs.copy())
+
         # 1. Set generation parameters if not already defined
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         num_beams = num_beams if num_beams is not None else self.config.num_beams
```
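One detail worth noting: `locals().copy()` on the first line of `generate` captures every argument by name, defaults included, before anything is reassigned, which is what makes the per-submethod validation possible. A tiny self-contained illustration (`generate_stub` is a hypothetical name):

```python
def generate_stub(inputs=None, do_sample=False, temperature=None, **model_kwargs):
    # On the first line of the function, locals() holds exactly the call arguments
    generation_inputs = locals().copy()
    return generation_inputs


print(generate_stub(inputs="hi", temperature=0.7, foo="bar"))
# {'inputs': 'hi', 'do_sample': False, 'temperature': 0.7, 'model_kwargs': {'foo': 'bar'}}
```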
```diff
@@ -1244,10 +1322,6 @@ def generate(

         if num_beam_groups > num_beams:
             raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
-        if is_group_beam_gen_mode and do_sample is True:
-            raise ValueError(
-                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
-            )
```

Comment on lines -1247 to -1250: moved to the `is_group_beam_gen_mode` branch

```diff

         # 7. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
```
```diff
@@ -1279,10 +1353,11 @@ def generate(

         # 9. go into different generation modes
         if is_greedy_gen_mode:
-            if num_return_sequences > 1:
-                raise ValueError(
-                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
-                )
```

Comment on lines -1282 to -1285: redundant with the new checks

```diff
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="greedy_search",
+                supporting_objects=[self._get_logits_processor, self._get_stopping_criteria],
+            )

             # 10. run greedy search
             return self.greedy_search(
```
```diff
@@ -1298,6 +1373,17 @@ def generate(
             )

         elif is_sample_gen_mode:
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="sample",
+                supporting_objects=[
+                    self._get_logits_processor,
+                    self._get_stopping_criteria,
+                    self._get_logits_warper,
+                    self._expand_inputs_for_generation,
+                ],
+            )
+
             # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
                 top_k=top_k,
```
```diff
@@ -1311,7 +1397,7 @@ def generate(
             # 11. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids,
-                expand_size=num_return_sequences,
+                num_return_sequences=num_return_sequences,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 **model_kwargs,
             )
```
```diff
@@ -1331,6 +1417,17 @@ def generate(
             )

         elif is_beam_gen_mode:
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="beam_search",
+                supporting_objects=[
+                    self._get_logits_processor,
+                    self._get_stopping_criteria,
+                    BeamSearchScorer,
+                    self._expand_inputs_for_generation,
+                ],
+            )
+
             if num_return_sequences > num_beams:
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
```
```diff
@@ -1348,7 +1445,10 @@ def generate(
             )
             # 11. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                input_ids,
+                num_return_sequences=num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
             )
             # 12. run beam search
             return self.beam_search(
```
```diff
@@ -1365,6 +1465,18 @@ def generate(
             )

         elif is_beam_sample_gen_mode:
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="beam_sample",
+                supporting_objects=[
+                    self._get_logits_processor,
+                    self._get_stopping_criteria,
+                    self._get_logits_warper,
+                    BeamSearchScorer,
+                    self._expand_inputs_for_generation,
+                ],
+            )
+
             # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
                 top_k=top_k,
```
```diff
@@ -1389,7 +1501,7 @@ def generate(
             # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids,
-                expand_size=num_beams * num_return_sequences,
+                num_return_sequences=num_beams * num_return_sequences,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 **model_kwargs,
             )
```
```diff
@@ -1410,6 +1522,23 @@ def generate(
             )

         elif is_group_beam_gen_mode:
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="group_beam_search",
+                supporting_objects=[
+                    self._get_logits_processor,
+                    self._get_stopping_criteria,
+                    BeamSearchScorer,
+                    self._expand_inputs_for_generation,
+                ],
+            )
+
+            if do_sample is True:
+                raise ValueError(
+                    "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to"
+                    " `False`."
+                )
+
             if num_return_sequences > num_beams:
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
```
```diff
@@ -1432,7 +1561,10 @@ def generate(
             )
             # 11. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                input_ids,
+                num_return_sequences=num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
             )
             # 12. run beam search
             return self.group_beam_search(
```
```diff
@@ -1449,6 +1581,17 @@ def generate(
             )

         elif is_constraint_gen_mode:
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="constrained_beam_search",
+                supporting_objects=[
+                    self._get_logits_processor,
+                    self._get_stopping_criteria,
+                    ConstrainedBeamSearchScorer,
+                    self._expand_inputs_for_generation,
+                ],
+            )
+
             if num_return_sequences > num_beams:
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
```
```diff
@@ -1461,9 +1604,6 @@ def generate(
             if do_sample:
                 raise ValueError("`do_sample` needs to be false for constrained generation.")

-            if num_beam_groups is not None and num_beam_groups > 1:
-                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
-
```

Comment on lines -1464 to -1466: redundant check

```diff
             final_constraints = []
             if constraints is not None:
                 final_constraints = constraints
```
```diff
@@ -1513,7 +1653,10 @@ def typeerror():
             )
             # 11. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                input_ids,
+                num_return_sequences=num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
             )
             # 12. run beam search
             return self.constrained_beam_search(
```
```diff
@@ -2052,7 +2052,7 @@ def test_max_new_tokens_decoder_only(self):

         # max_new_tokens and max_length serve the same purpose and should not be used together.
         with self.assertWarns(UserWarning):
-            gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
+            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
```

Review comment: gpt2 is decoder-only

```diff

     def test_encoder_decoder_generate_with_inputs_embeds(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
```
```diff
@@ -2699,3 +2699,26 @@ def test_constrained_beam_search_mixin_type_checks(self):

         with self.assertRaises(ValueError):
             model.generate(input_ids, force_words_ids=[[[-1]]])
+
+    def test_validate_generation_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
+        model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
+
+        encoder_input_str = "Hello world"
+        input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
+
+        # typos are quickly detected (the correct argument is `do_sample`)
+        with self.assertRaises(ValueError):
+            model.generate(
+                input_ids,
+                do_samples=True,
+            )
+
+        # arbitrary arguments that will not be used anywhere are also not accepted
+        with self.assertRaises(ValueError):
+            fake_model_kwargs = {"foo": "bar"}
+            model.generate(input_ids, **fake_model_kwargs)
+
+        # valid args that are not used by the generation submethod (greedy search in this case) also raise an exception
+        with self.assertRaises(ValueError):
+            model.generate(input_ids, temperature=2.0)
```
```diff
@@ -360,7 +360,7 @@ def test_generate_fp16(self):
         if torch_device == "cuda":
             model.half()
             model.generate(**input_dict)
-            model.generate(**input_dict, do_sample=True, early_stopping=False, num_return_sequences=3)
+            model.generate(**input_dict, do_sample=True, num_return_sequences=3)
```

Review comment: greedy_search doesn't accept early_stopping

```diff

     @slow
     def test_batched_forward_original_full(self):
```