Generate: validate arguments #18218

Closed · wants to merge 6 commits
179 changes: 161 additions & 18 deletions src/transformers/generation_utils.py
@@ -572,14 +572,14 @@ def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_to
@staticmethod
def _expand_inputs_for_generation(
input_ids: torch.LongTensor,
expand_size: int = 1,
num_return_sequences: int = 1,
Member Author:
Renamed this arg: it is a private method, the new name is more readable, and it is useful for name matching (as you'll see below).
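For readers skimming the diff, here is a minimal standalone sketch of what this helper does with the renamed argument (it mirrors the index construction below, not the full method):

```python
import torch

# Each batch row is repeated `num_return_sequences` times, matching the
# `expanded_return_idx` construction in `_expand_inputs_for_generation`.
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])  # batch of 2 sequences
num_return_sequences = 3

expanded_return_idx = (
    torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_return_sequences).view(-1)
)
print(expanded_return_idx)  # tensor([0, 0, 0, 1, 1, 1])
print(input_ids.index_select(0, expanded_return_idx).shape)  # torch.Size([6, 3])
```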

is_encoder_decoder: bool = False,
attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[ModelOutput] = None,
**model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
expanded_return_idx = (
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_return_sequences).view(-1).to(input_ids.device)
)
input_ids = input_ids.index_select(0, expanded_return_idx)

@@ -841,6 +841,80 @@ def compute_transition_beam_scores(

return transition_scores

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
if self.config.is_encoder_decoder:
for key in ["decoder_input_ids"]:
model_kwargs.pop(key, None)
Member Author (comment on lines +846 to +849):
If `decoder_input_ids` is present, it will be converted to `input_ids`.
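To illustrate (a hedged sketch; "t5-small" is used here only as a convenient encoder-decoder checkpoint and is not part of this PR), a call like the one below passes `decoder_input_ids` through `generate`, which consumes it before the remaining model kwargs are checked:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello world", return_tensors="pt")
# For encoder-decoder models, `decoder_input_ids` becomes the decoder prompt and is
# handled by `generate` itself, so the validation must not report it as unused.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
outputs = model.generate(**inputs, decoder_input_ids=decoder_input_ids, max_length=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```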


# Transfo_XL does not have "attention_mask" as an argument, and it is harmless (it is being passed in the
# tests, though)
if "transfoxl" in str(self).lower():
model_kwargs.pop("attention_mask", None)
Member Author (comment on lines +851 to +854):
Ugly ad hoc exception, but the alternative would be to rewrite most `GenerationMixin` tests for this particular model (most pass `attention_mask=attention_mask` when calling `generate`).
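A quick way to see why the exception is needed (a sketch that only inspects the class; no weights are loaded):

```python
import inspect
from transformers import TransfoXLLMHeadModel

# Transfo-XL's forward signature has no `attention_mask` parameter, so an
# `attention_mask` passed to `generate` would otherwise be reported as unused.
print("attention_mask" in inspect.signature(TransfoXLLMHeadModel.forward).parameters)  # False
```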


unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
model_args |= set(inspect.signature(self.forward).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)

if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)
Member Author (comment on lines +866 to +870):
Here is an example of the output for the following call:

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

prompt = tokenizer(["hello world"], return_tensors="pt")
model.generate(**prompt, do_samples=True, foo="bar")

(Screenshot of the resulting `ValueError`, 2022-07-21 14:20)
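Going by the error string introduced in this diff, the raised exception would read roughly as follows (a sketch of the message, not captured output): ValueError: The following `model_kwargs` are not used by the model: ['do_samples', 'foo'] (note: typos in the generate arguments will also show up in this list).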


def _validate_generation_inputs(
self,
generation_inputs: Dict[str, Any],
generation_method_name: str,
supporting_objects: List[Callable],
Member Author (comment on lines +875 to +876):
Each generation submethod has a set of preparation functions/classes (hence "objects"), and some input arguments are consumed there; we will need their signatures to do the correct detection.

):
"""
Validates the generation inputs for the generation submethod. If there are arguments that were passed but will
not be used by the generation method or its supporting objects, an informative exception will be thrown. A
supporting object is a class or function that precedes the generation submethod itself.
"""
# Excludes arguments that are handled before calling the submethods
unused_keys = ["self"]
keys_processed_before_method = [
"inputs",
"bos_token_id",
"max_new_tokens",
"decoder_start_token_id",
"do_sample",
"model_kwargs",
]
keys_passed_to_model_kwargs = ["use_cache", "output_attentions", "output_hidden_states"]
for key in unused_keys + keys_processed_before_method + keys_passed_to_model_kwargs:
generation_inputs.pop(key)

# Validates that the remaining arguments are used by the generation submethod
unused_args = []
generation_method = getattr(self, generation_method_name)
generation_method_args = set(inspect.signature(generation_method).parameters)
for supporting_object in supporting_objects:
generation_method_args |= set(inspect.signature(supporting_object).parameters)
# Ad hoc replacement of supporting object argument to match the corresponding generation argument
if "do_early_stopping" in generation_method_args:
generation_method_args.remove("do_early_stopping")
generation_method_args.add("early_stopping")
Member Author (comment on lines +904 to +906):
`early_stopping` is consumed by a class that has `do_early_stopping` as an argument. Since the class is public, I can't touch it :( At most I could add a new argument doing the same, but it's probably not worth it.
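A minimal check of the mismatch being papered over here (a sketch; it only inspects the public scorer class mentioned above):

```python
import inspect
from transformers import BeamSearchScorer

# The public `generate()` kwarg is `early_stopping`, but the scorer that consumes it
# exposes the parameter as `do_early_stopping`, hence the ad hoc rename in the signature set.
scorer_params = set(inspect.signature(BeamSearchScorer).parameters)
print("do_early_stopping" in scorer_params)  # True
print("early_stopping" in scorer_params)  # False
```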

for key, value in generation_inputs.items():
if value is not None and key not in generation_method_args:
unused_args.append(key)

if unused_args:
raise ValueError(
f"From the generation arguments, `{generation_method_name}` was triggered. The following arguments are"
f" not used by `{generation_method_name}`: {unused_args}. Please remove them from the generation"
" arguments or check the documentation for more information."
)
Member Author (comment on lines +912 to +916):
Here is an example of the output for the following call:

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

prompt = tokenizer(["hello world"], return_tensors="pt")
model.generate(**prompt, num_return_sequences=2, temperature=2.0)

(Screenshot of the resulting `ValueError`, 2022-07-21 14:32)
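Going by the error string introduced above, the exception would read roughly as follows (a sketch, not captured output; the exact list depends on the call): ValueError: From the generation arguments, `greedy_search` was triggered. The following arguments are not used by `greedy_search`: ['num_return_sequences', 'temperature']. Please remove them from the generation arguments or check the documentation for more information.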


@torch.no_grad()
def generate(
self,
@@ -1119,6 +1193,10 @@ def generate(
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# 0. Store generation inputs for posterior submethod validation and validate model kwargs
generation_inputs = locals().copy()
self._validate_model_kwargs(model_kwargs.copy())

# 1. Set generation parameters if not already defined
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
num_beams = num_beams if num_beams is not None else self.config.num_beams
@@ -1244,10 +1322,6 @@

if num_beam_groups > num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
if is_group_beam_gen_mode and do_sample is True:
raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)
Member Author (comment on lines -1247 to -1250):
Moved to the `if is_group_beam_gen_mode:` block.


# 7. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
@@ -1279,10 +1353,11 @@

# 9. go into different generation modes
if is_greedy_gen_mode:
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
Member Author (comment on lines -1282 to -1285):
Redundant with the new checks.

self._validate_generation_inputs(
generation_inputs=generation_inputs,
generation_method_name="greedy_search",
supporting_objects=[self._get_logits_processor, self._get_stopping_criteria],
)

# 10. run greedy search
return self.greedy_search(
@@ -1298,6 +1373,17 @@
)

elif is_sample_gen_mode:
self._validate_generation_inputs(
generation_inputs=generation_inputs,
generation_method_name="sample",
supporting_objects=[
self._get_logits_processor,
self._get_stopping_criteria,
self._get_logits_warper,
self._expand_inputs_for_generation,
],
)

# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k,
@@ -1311,7 +1397,7 @@
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids,
expand_size=num_return_sequences,
num_return_sequences=num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
@@ -1331,6 +1417,17 @@
)

elif is_beam_gen_mode:
self._validate_generation_inputs(
generation_inputs=generation_inputs,
generation_method_name="beam_search",
supporting_objects=[
self._get_logits_processor,
self._get_stopping_criteria,
BeamSearchScorer,
self._expand_inputs_for_generation,
],
)

if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

@@ -1348,7 +1445,10 @@
)
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
input_ids,
num_return_sequences=num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run beam search
return self.beam_search(
@@ -1365,6 +1465,18 @@
)

elif is_beam_sample_gen_mode:
self._validate_generation_inputs(
generation_inputs=generation_inputs,
generation_method_name="beam_sample",
supporting_objects=[
self._get_logits_processor,
self._get_stopping_criteria,
self._get_logits_warper,
BeamSearchScorer,
self._expand_inputs_for_generation,
],
)

# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k,
@@ -1389,7 +1501,7 @@
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids,
expand_size=num_beams * num_return_sequences,
num_return_sequences=num_beams * num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
@@ -1410,6 +1522,23 @@
)

elif is_group_beam_gen_mode:
self._validate_generation_inputs(
generation_inputs=generation_inputs,
generation_method_name="group_beam_search",
supporting_objects=[
self._get_logits_processor,
self._get_stopping_criteria,
BeamSearchScorer,
self._expand_inputs_for_generation,
],
)

if do_sample is True:
raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to"
" `False`."
)

if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

@@ -1432,7 +1561,10 @@
)
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
input_ids,
num_return_sequences=num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run beam search
return self.group_beam_search(
@@ -1449,6 +1581,17 @@
)

elif is_constraint_gen_mode:
self._validate_generation_inputs(
generation_inputs=generation_inputs,
generation_method_name="constrained_beam_search",
supporting_objects=[
self._get_logits_processor,
self._get_stopping_criteria,
ConstrainedBeamSearchScorer,
self._expand_inputs_for_generation,
],
)

if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

@@ -1461,9 +1604,6 @@
if do_sample:
raise ValueError("`do_sample` needs to be false for constrained generation.")

if num_beam_groups is not None and num_beam_groups > 1:
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

Member Author (comment on lines -1464 to -1466):
Redundant check.

final_constraints = []
if constraints is not None:
final_constraints = constraints
@@ -1513,7 +1653,10 @@ def typeerror():
)
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
input_ids,
num_return_sequences=num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run beam search
return self.constrained_beam_search(
25 changes: 24 additions & 1 deletion tests/generation/test_generation_utils.py
@@ -2052,7 +2052,7 @@ def test_max_new_tokens_decoder_only(self):

# max_new_tokens and max_length serve the same purpose and should not be used together.
with self.assertWarns(UserWarning):
gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
Member Author:
GPT-2 is decoder-only, so the prompt must be passed as `input_ids` rather than `decoder_input_ids`.


def test_encoder_decoder_generate_with_inputs_embeds(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -2699,3 +2699,26 @@ def test_constrained_beam_search_mixin_type_checks(self):

with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[[-1]]])

def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")

encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

# typos are quickly detected (the correct argument is `do_sample`)
with self.assertRaises(ValueError):
model.generate(
input_ids,
do_samples=True,
)

# arbitrary arguments that will not be used anywhere are also not accepted
with self.assertRaises(ValueError):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)

# valid arguments that are not used by the generation submethod (greedy search in this case) also raise an exception
with self.assertRaises(ValueError):
model.generate(input_ids, temperature=2.0)
@@ -360,7 +360,7 @@ def test_generate_fp16(self):
if torch_device == "cuda":
model.half()
model.generate(**input_dict)
model.generate(**input_dict, do_sample=True, early_stopping=False, num_return_sequences=3)
model.generate(**input_dict, do_sample=True, num_return_sequences=3)
Member Author:
`greedy_search` doesn't accept `early_stopping`.


@slow
def test_batched_forward_original_full(self):