From d55ea06de0e36f093070f9fd86bf9e9847db9c51 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 20 Jul 2022 17:03:43 +0000
Subject: [PATCH 1/6] validate generation arguments

---
 src/transformers/generation_utils.py      | 168 +++++++++++++++++++---
 tests/generation/test_generation_utils.py |  23 +++
 2 files changed, 170 insertions(+), 21 deletions(-)

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 1792545e454761..25c3b86267eeee 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -572,14 +572,14 @@ def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_to
     @staticmethod
     def _expand_inputs_for_generation(
         input_ids: torch.LongTensor,
-        expand_size: int = 1,
+        num_return_sequences: int = 1,
         is_encoder_decoder: bool = False,
         attention_mask: Optional[torch.LongTensor] = None,
         encoder_outputs: Optional[ModelOutput] = None,
         **model_kwargs,
     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
         expanded_return_idx = (
-            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
+            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_return_sequences).view(-1).to(input_ids.device)
         )
         input_ids = input_ids.index_select(0, expanded_return_idx)
 
@@ -841,6 +841,66 @@ def compute_transition_beam_scores(
 
         return transition_scores
 
+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation."""
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.forward).parameters)
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
+
+    def _validate_generation_inputs(
+        self,
+        generation_inputs: Dict[str, Any],
+        generation_method_name: str,
+        supporting_objects: List[Callable],
+    ):
+        """
+        Validates the generation inputs for the generation submethod. If there are arguments that were passed but will
+        not be used by the generation method or its supporting objects, an informative exception will be thrown. A
+        supporting object is a class or function that precedes the generation submethod itself.
+        """
+        # Excludes arguments that are handled before calling the submethods
+        unused_keys = ["self"]
+        keys_processed_before_method = [
+            "inputs",
+            "bos_token_id",
+            "max_new_tokens",
+            "decoder_start_token_id",
+            "do_sample",
+            "model_kwargs",
+        ]
+        keys_passed_to_model_kwargs = ["use_cache", "output_attentions", "output_hidden_states"]
+        for key in unused_keys + keys_processed_before_method + keys_passed_to_model_kwargs:
+            generation_inputs.pop(key)
+
+        # Validates that the remaining arguments are used by the generation submethod
+        unused_args = []
+        generation_method = getattr(self, generation_method_name)
+        generation_method_args = set(inspect.signature(generation_method).parameters)
+        for supporting_object in supporting_objects:
+            generation_method_args |= set(inspect.signature(supporting_object).parameters)
+        for key, value in generation_inputs.items():
+            if value is not None and key not in generation_method_args:
+                unused_args.append(key)
+
+        if unused_args:
+            raise ValueError(
+                f"From the generation arguments, `{generation_method_name}` was triggered. The following arguments are"
+                f" not used by `{generation_method_name}`: {unused_args}. Please remove them from the generation"
+                " arguments or check the documentation for more information."
+            )
+
     @torch.no_grad()
     def generate(
         self,
@@ -1119,6 +1179,10 @@ def generate(
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
+        # 0. Store generation inputs for posterior submethod validation and validate model kwargs
+        generation_inputs = locals().copy()
+        self._validate_model_kwargs(model_kwargs)
+
         # 1. Set generation parameters if not already defined
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         num_beams = num_beams if num_beams is not None else self.config.num_beams
@@ -1244,10 +1308,6 @@ def generate(
         if num_beam_groups > num_beams:
             raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
 
-        if is_group_beam_gen_mode and do_sample is True:
-            raise ValueError(
-                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
-            )
 
         # 7. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
@@ -1279,10 +1339,11 @@ def generate(
 
         # 9. go into different generation modes
         if is_greedy_gen_mode:
-            if num_return_sequences > 1:
-                raise ValueError(
-                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
-                )
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="greedy_search",
+                supporting_objects=[self._get_logits_processor, self._get_stopping_criteria],
+            )
 
             # 10. run greedy search
             return self.greedy_search(
@@ -1298,6 +1359,17 @@ def generate(
             )
 
         elif is_sample_gen_mode:
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="sample",
+                supporting_objects=[
+                    self._get_logits_processor,
+                    self._get_stopping_criteria,
+                    self._get_logits_warper,
+                    self._expand_inputs_for_generation,
+                ],
+            )
+
             # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
                 top_k=top_k,
@@ -1311,7 +1383,7 @@ def generate(
             # 11. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids,
-                expand_size=num_return_sequences,
+                num_return_sequences=num_return_sequences,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 **model_kwargs,
             )
@@ -1331,6 +1403,17 @@ def generate(
             )
 
         elif is_beam_gen_mode:
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="beam_search",
+                supporting_objects=[
+                    self._get_logits_processor,
+                    self._get_stopping_criteria,
+                    BeamSearchScorer,
+                    self._expand_inputs_for_generation,
+                ],
+            )
+
             if num_return_sequences > num_beams:
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
 
@@ -1348,7 +1431,10 @@ def generate(
             )
             # 11. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                input_ids,
+                num_return_sequences=num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
             )
             # 12. run beam search
             return self.beam_search(
@@ -1365,6 +1451,18 @@ def generate(
             )
 
         elif is_beam_sample_gen_mode:
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="beam_sample",
+                supporting_objects=[
+                    self._get_logits_processor,
+                    self._get_stopping_criteria,
+                    self._get_logits_warper,
+                    BeamSearchScorer,
+                    self._expand_inputs_for_generation,
+                ],
+            )
+
             # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
                 top_k=top_k,
@@ -1389,7 +1487,7 @@ def generate(
             # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids,
-                expand_size=num_beams * num_return_sequences,
+                num_return_sequences=num_beams * num_return_sequences,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 **model_kwargs,
             )
@@ -1410,6 +1508,23 @@ def generate(
             )
 
         elif is_group_beam_gen_mode:
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="group_beam_search",
+                supporting_objects=[
+                    self._get_logits_processor,
+                    self._get_stopping_criteria,
+                    BeamSearchScorer,
+                    self._expand_inputs_for_generation,
+                ],
+            )
+
+            if do_sample is True:
+                raise ValueError(
+                    "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to"
+                    " `False`."
+                )
+
             if num_return_sequences > num_beams:
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
 
@@ -1432,7 +1547,10 @@ def generate(
             )
             # 11. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                input_ids,
+                num_return_sequences=num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
             )
             # 12. run beam search
             return self.group_beam_search(
@@ -1449,6 +1567,17 @@ def generate(
             )
 
         elif is_constraint_gen_mode:
+            self._validate_generation_inputs(
+                generation_inputs=generation_inputs,
+                generation_method_name="constrained_beam_search",
+                supporting_objects=[
+                    self._get_logits_processor,
+                    self._get_stopping_criteria,
+                    ConstrainedBeamSearchScorer,
+                    self._expand_inputs_for_generation,
+                ],
+            )
+
             if num_return_sequences > num_beams:
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
 
@@ -1458,12 +1587,6 @@ def generate(
             if num_beams <= 1:
                 raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
 
-            if do_sample:
-                raise ValueError("`do_sample` needs to be false for constrained generation.")
-
-            if num_beam_groups is not None and num_beam_groups > 1:
-                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
-
             final_constraints = []
             if constraints is not None:
                 final_constraints = constraints
@@ -1513,7 +1636,10 @@ def typeerror():
             )
             # 11. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                input_ids,
+                num_return_sequences=num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
             )
             # 12. run beam search
             return self.constrained_beam_search(
diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py
index ce12d631bf39cd..40c05a8ca1f502 100644
--- a/tests/generation/test_generation_utils.py
+++ b/tests/generation/test_generation_utils.py
@@ -2699,3 +2699,26 @@ def test_constrained_beam_search_mixin_type_checks(self):
 
         with self.assertRaises(ValueError):
             model.generate(input_ids, force_words_ids=[[[-1]]])
+
+    def test_validate_generation_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
+        model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
+
+        encoder_input_str = "Hello world"
+        input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
+
+        # typos are quickly detected (the correct argument is `do_sample`)
+        with self.assertRaises(ValueError):
+            model.generate(
+                input_ids,
+                do_samples=True,
+            )
+
+        # arbitrary arguments that will not be used anywhere are also not accepted
+        with self.assertRaises(ValueError):
+            fake_model_kwargs = {"foo": "bar"}
+            model.generate(input_ids, **fake_model_kwargs)
+
+        # valid args that are not used by the generation submethod (greedy search in this case) also raise an exception
+        with self.assertRaises(ValueError):
+            model.generate(input_ids, temperature=2.0)

From 53101c5a6ed050d7d9c28e091fbdb8075f6bdcf2 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 20 Jul 2022 17:44:36 +0000
Subject: [PATCH 2/6] Add some fixes

---
 src/transformers/generation_utils.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 25c3b86267eeee..919617f3467815 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -843,6 +843,11 @@ def compute_transition_beam_scores(
 
     def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation."""
+        # Excludes arguments that are handled before calling any model function
+        if self.config.is_encoder_decoder:
["decoder_input_ids"]: + model_kwargs.pop(key) + unused_model_args = [] model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) # `kwargs`` if often used to handle optional forward pass inputs like `attention_mask`. If @@ -890,6 +895,10 @@ def _validate_generation_inputs( generation_method_args = set(inspect.signature(generation_method).parameters) for supporting_object in supporting_objects: generation_method_args |= set(inspect.signature(supporting_object).parameters) + # Ad hoc replacement of supporting object argument to match the corresponding generation argument + if "do_early_stopping" in generation_method_args: + generation_method_args.remove("do_early_stopping") + generation_method_args.add("early_stopping") for key, value in generation_inputs.items(): if value is not None and key not in generation_method_args: unused_args.append(key) @@ -1181,7 +1190,7 @@ def generate( ```""" # 0. Store generation inputs for posterior submethod validation and validate model kwargs generation_inputs = locals().copy() - self._validate_model_kwargs(model_kwargs) + self._validate_model_kwargs(model_kwargs.copy()) # 1. Set generation parameters if not already defined bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id From db88b11d8d0ea07569e25ba972108c2e438a8b19 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 20 Jul 2022 17:45:59 +0000 Subject: [PATCH 3/6] improve docstring --- src/transformers/generation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 919617f3467815..5b08298ac213f6 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -842,7 +842,7 @@ def compute_transition_beam_scores( return transition_scores def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): - """Validates model kwargs for generation.""" + """Validates model kwargs for generation. 
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         # Excludes arguments that are handled before calling any model function
         if self.config.is_encoder_decoder:
             for key in ["decoder_input_ids"]:
                 model_kwargs.pop(key)

From 8893aa2d23df1e6fbbf44001b2759544623fb7d7 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 20 Jul 2022 17:57:26 +0000
Subject: [PATCH 4/6] add default to pop

---
 src/transformers/generation_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 5b08298ac213f6..1b93c5cb6d8ee1 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -846,7 +846,7 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         # Excludes arguments that are handled before calling any model function
         if self.config.is_encoder_decoder:
             for key in ["decoder_input_ids"]:
-                model_kwargs.pop(key)
+                model_kwargs.pop(key, None)
 
         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)

From 2cdf84b960eff18a38cc506ab4a0c26d5d579032 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 21 Jul 2022 11:46:03 +0000
Subject: [PATCH 5/6] add exceptions

---
 src/transformers/generation_utils.py                        | 5 +++++
 tests/generation/test_generation_utils.py                   | 2 +-
 .../models/bigbird_pegasus/test_modeling_bigbird_pegasus.py | 2 +-
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 1b93c5cb6d8ee1..2f5a4ca6433c7a 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -848,6 +848,11 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
             for key in ["decoder_input_ids"]:
                 model_kwargs.pop(key, None)
 
+        # Transfo_XL does not have "attention_mask" as an argument, and it is harmless (it is being passed in the
+        # tests, though)
+        if "transfoxl" in str(self).lower():
+            model_kwargs.pop("attention_mask", None)
+
         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py
index 40c05a8ca1f502..53d9d78541dc81 100644
--- a/tests/generation/test_generation_utils.py
+++ b/tests/generation/test_generation_utils.py
@@ -2052,7 +2052,7 @@ def test_max_new_tokens_decoder_only(self):
 
         # max_new_tokens and max_length serve the same purpose and should not be used together.
         with self.assertWarns(UserWarning):
-            gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
+            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
 
     def test_encoder_decoder_generate_with_inputs_embeds(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
index d4e7e8f4ae422a..d6aa7b505ff477 100644
--- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
+++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
@@ -360,7 +360,7 @@ def test_generate_fp16(self):
         if torch_device == "cuda":
             model.half()
         model.generate(**input_dict)
-        model.generate(**input_dict, do_sample=True, early_stopping=False, num_return_sequences=3)
+        model.generate(**input_dict, do_sample=True, num_return_sequences=3)
 
     @slow
     def test_batched_forward_original_full(self):

From 017e77fa13db5cf4c393ed7492995a8b74425a20 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 21 Jul 2022 13:36:58 +0000
Subject: [PATCH 6/6] Add removed check

---
 src/transformers/generation_utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 2f5a4ca6433c7a..3ddce6d102c34a 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -1601,6 +1601,9 @@ def generate(
             if num_beams <= 1:
                 raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
 
+            if do_sample:
+                raise ValueError("`do_sample` needs to be false for constrained generation.")
+
             final_constraints = []
             if constraints is not None:
                 final_constraints = constraints
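The mechanism the series relies on is plain `inspect.signature` introspection over the model's own callables. The snippet below is a minimal standalone sketch of that idea, outside the patches themselves; `TinyModel` and its parameter names are invented for illustration and are not part of the transformers API.

```python
import inspect
from typing import Any, Dict


class TinyModel:
    """Stand-in class: only the signature of `prepare_inputs_for_generation` matters here."""

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None):
        return {"input_ids": input_ids, "past": past, "attention_mask": attention_mask, "use_cache": use_cache}

    def validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        # Collect the argument names the model can actually consume, then flag everything else.
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        unused = [key for key, value in model_kwargs.items() if value is not None and key not in model_args]
        if unused:
            raise ValueError(f"The following `model_kwargs` are not used by the model: {unused}")


model = TinyModel()
model.validate_model_kwargs({"attention_mask": [[1, 1]]})  # accepted: matches a real parameter name
try:
    model.validate_model_kwargs({"do_samples": True})  # unknown key (here a typo of `do_sample`) -> rejected
except ValueError as err:
    print(err)
```

Passing a copy of the kwargs, as patch 2 does with `model_kwargs.copy()`, keeps the check free of side effects on the dictionary the caller still needs.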