diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 637b723c88de20..b5b042e718c1c3 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -841,6 +841,29 @@ def compute_transition_beam_scores(
 
         return transition_scores
 
+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+        # Excludes arguments that are handled before calling any model function
+        if self.config.is_encoder_decoder:
+            for key in ["decoder_input_ids"]:
+                model_kwargs.pop(key, None)
+
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.forward).parameters)
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
+
     @torch.no_grad()
     def generate(
         self,
@@ -1120,6 +1143,9 @@ def generate(
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
+        # 0. Validate model kwargs
+        self._validate_model_kwargs(model_kwargs.copy())
+
         # 1. Set generation parameters if not already defined
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         num_beams = num_beams if num_beams is not None else self.config.num_beams
diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py
index 56227403ae60b9..ba13669368d228 100644
--- a/tests/generation/test_generation_utils.py
+++ b/tests/generation/test_generation_utils.py
@@ -75,21 +75,25 @@ class GenerationTesterMixin:
 
     def _get_input_ids_and_config(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
         input_ids = inputs_dict[self.input_name]
-        attention_mask = torch.ones_like(input_ids, dtype=torch.long)
 
         # cut to half length & take max batch_size 3
         max_batch_size = 2
         sequence_length = input_ids.shape[-1] // 2
         input_ids = input_ids[:max_batch_size, :sequence_length]
-        attention_mask = attention_mask[:max_batch_size, :sequence_length]
 
         # generate max 3 tokens
         max_length = input_ids.shape[-1] + 3
         if config.eos_token_id is not None and config.pad_token_id is None:
             # hack to allow generate for models such as GPT2 as is done in `generate()`
             config.pad_token_id = config.eos_token_id
+
+        # TransfoXL has no attention mask
+        if "transfoxl" in config.__class__.__name__.lower():
+            attention_mask = None
+        else:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length]
+
         return config, input_ids, attention_mask, max_length
 
     @staticmethod
@@ -252,10 +256,9 @@ def _greedy_generate(
         )
 
         kwargs = {}
-
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             num_beams=1,
             max_length=max_length,
@@ -265,6 +268,7 @@ def _greedy_generate(
             return_dict_in_generate=return_dict_in_generate,
             remove_invalid_values=True,
             **logits_process_kwargs,
+            **model_kwargs,
         )
 
         if model.config.is_encoder_decoder:
@@ -278,16 +282,17 @@ def _greedy_generate(
             kwargs["encoder_outputs"] = encoder_outputs
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_greedy = model.greedy_search(
                 input_ids,
                 max_length=max_length,
-                attention_mask=attention_mask,
                 logits_processor=logits_processor,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 output_scores=output_scores,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_greedy, output_generate
@@ -308,13 +313,13 @@ def _sample_generate(
         return_dict_in_generate=False,
     ):
         torch.manual_seed(0)
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
             do_sample=True,
             num_beams=1,
             max_length=max_length,
             num_return_sequences=num_return_sequences,
-            attention_mask=attention_mask,
             output_scores=output_scores,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -327,7 +332,7 @@ def _sample_generate(
         torch.manual_seed(0)
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -336,18 +341,16 @@ def _sample_generate(
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
 
         # prevent flaky generation test failures
         logits_processor.append(InfNanRemoveLogitsProcessor())
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_sample = model.sample(
-                input_ids_clone,
-                attention_mask=attention_mask_clone,
+                input_ids.repeat_interleave(num_return_sequences, dim=0),
                 max_length=max_length,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
@@ -356,6 +359,7 @@ def _sample_generate(
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_sample, output_generate
@@ -374,9 +378,9 @@ def _beam_search_generate(
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             max_length=max_length,
             output_scores=output_scores,
@@ -386,12 +390,13 @@ def _beam_search_generate(
             remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
+            **model_kwargs,
         )
 
         # beam_search does not automatically interleave `batch_size` dim for `num_beams`
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -400,23 +405,22 @@ def _beam_search_generate(
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_beam_search = model.beam_search(
-                input_ids_clone,
+                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask_clone,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_generate, output_beam_search
@@ -437,9 +441,9 @@ def _beam_sample_generate(
         return_dict_in_generate=False,
     ):
         torch.manual_seed(0)
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=True,
             max_length=max_length,
             output_scores=output_scores,
@@ -449,6 +453,7 @@ def _beam_sample_generate(
             remove_invalid_values=True,
             **beam_kwargs,
             **logits_warper_kwargs,
+            **model_kwargs,
         )
 
         # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
         kwargs = {}
@@ -462,7 +467,7 @@ def _beam_sample_generate(
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-        else:
+        elif attention_mask is not None:
             attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
 
         # prevent flaky generation test failures
         logits_processor.append(InfNanRemoveLogitsProcessor())
 
@@ -471,11 +476,11 @@ def _beam_sample_generate(
         torch.manual_seed(0)
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_beam_sample = model.beam_sample(
                 input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
                 beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask,
                 logits_warper=logits_warper,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
@@ -483,6 +488,7 @@ def _beam_sample_generate(
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_generate, output_beam_sample
@@ -502,9 +508,9 @@ def _group_beam_search_generate(
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             max_length=max_length,
             output_scores=output_scores,
@@ -514,12 +520,13 @@ def _group_beam_search_generate(
             remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
+            **model_kwargs,
         )
 
         # group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -528,23 +535,22 @@ def _group_beam_search_generate(
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_group_beam_search = model.group_beam_search(
-                input_ids_clone,
+                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask_clone,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_generate, output_group_beam_search
@@ -564,9 +570,9 @@ def _constrained_beam_search_generate(
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
        output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             max_length=max_length,
             output_scores=output_scores,
@@ -577,12 +583,13 @@ def _constrained_beam_search_generate(
             constraints=constraints,
             **beam_kwargs,
             **logits_process_kwargs,
+            **model_kwargs,
         )
 
         # group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -591,23 +598,22 @@ def _constrained_beam_search_generate(
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_group_beam_search = model.constrained_beam_search(
-                input_ids_clone,
+                input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0),
                 constrained_beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask_clone,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_generate, output_group_beam_search
@@ -1044,12 +1050,7 @@ def test_generate_without_input_ids(self):
             model = model_class(config).to(torch_device)
             model.eval()
 
-            output_ids_generate = model.generate(
-                do_sample=False,
-                max_length=max_length,
-                remove_invalid_values=True,
-            )
-
+            output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
             self.assertIsNotNone(output_ids_generate)
 
     def test_group_beam_search_generate(self):
@@ -2052,7 +2053,7 @@ def test_max_new_tokens_decoder_only(self):
 
         # max_new_tokens and max_length serve the same purpose and must not be used together.
         with self.assertRaises(ValueError):
-            gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
+            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
 
     def test_encoder_decoder_generate_with_inputs_embeds(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -2699,3 +2700,19 @@ def test_constrained_beam_search_mixin_type_checks(self):
 
         with self.assertRaises(ValueError):
             model.generate(input_ids, force_words_ids=[[[-1]]])
+
+    def test_validate_generation_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
+        model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
+
+        encoder_input_str = "Hello world"
+        input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
+
+        # typos are quickly detected (the correct argument is `do_sample`)
+        with self.assertRaisesRegex(ValueError, "do_samples"):
+            model.generate(input_ids, do_samples=True)
+
+        # arbitrary arguments that will not be used anywhere are also not accepted
+        with self.assertRaisesRegex(ValueError, "foo"):
+            fake_model_kwargs = {"foo": "bar"}
+            model.generate(input_ids, **fake_model_kwargs)
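
Note for reviewers (not part of the patch): `_validate_model_kwargs` leans entirely on `inspect.signature` to decide which kwargs a model can consume. Below is a minimal, self-contained sketch of that technique. The free functions `prepare_inputs_for_generation`, `forward`, and `validate_kwargs` are hypothetical stand-ins, not the real `GenerationMixin` methods:

```python
import inspect
from typing import Any, Dict


# Hypothetical stand-ins for the model methods the patch inspects.
def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
    return {"input_ids": input_ids, "past": past, **kwargs}


def forward(input_ids, attention_mask=None, head_mask=None):
    ...


def validate_kwargs(model_kwargs: Dict[str, Any]) -> None:
    # Accept anything named in `prepare_inputs_for_generation`'s signature...
    model_args = set(inspect.signature(prepare_inputs_for_generation).parameters)
    # ...and, because that signature takes `**kwargs`, fall back to the union
    # with `forward`'s parameters, mirroring the looser check in the patch.
    if "kwargs" in model_args:
        model_args |= set(inspect.signature(forward).parameters)
    unused = [key for key, value in model_kwargs.items() if value is not None and key not in model_args]
    if unused:
        raise ValueError(f"The following `model_kwargs` are not used by the model: {unused}")


validate_kwargs({"attention_mask": [[1, 1]]})  # accepted: `forward` takes `attention_mask`
validate_kwargs({"do_samples": True})  # raises ValueError, catching the `do_sample` typo
```

The same reasoning explains why `generate` passes `model_kwargs.copy()`: validation pops keys such as `decoder_input_ids` that are consumed before any model call, and the copy keeps that mutation from leaking into the kwargs used for the actual forward passes.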