From 1142e04768d395fed1628f67930385c8eb0a2733 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 22 Jul 2022 15:26:06 +0000 Subject: [PATCH 1/6] validate generate model_kwargs --- src/transformers/generation_utils.py | 31 +++++++++++++++++++++++ tests/generation/test_generation_utils.py | 21 ++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 1792545e454761..4a88737bbafe5d 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -841,6 +841,34 @@ def compute_transition_beam_scores( return transition_scores + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + """Validates model kwargs for generation. Generate argument typos will also be caught here.""" + # Excludes arguments that are handled before calling the any model function + if self.config.is_encoder_decoder: + for key in ["decoder_input_ids"]: + model_kwargs.pop(key, None) + + # Transfo_XL does not use have "attention_mask" as an argument, and it is harmless (it is being passed in the + # tests, through) + if "transfoxl" in str(self).lower(): + model_kwargs.pop("attention_mask", None) + + unused_model_args = [] + model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) + # `kwargs`` if often used to handle optional forward pass inputs like `attention_mask`. If + # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;) + if "kwargs" in model_args: + model_args |= set(inspect.signature(self.forward).parameters) + for key, value in model_kwargs.items(): + if value is not None and key not in model_args: + unused_model_args.append(key) + + if unused_model_args: + raise ValueError( + f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" + " generate arguments will also show up in this list)" + ) + @torch.no_grad() def generate( self, @@ -1119,6 +1147,9 @@ def generate( >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" + # 0. Validate model kwargs + self._validate_model_kwargs(model_kwargs.copy()) + # 1. Set generation parameters if not already defined bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id num_beams = num_beams if num_beams is not None else self.config.num_beams diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index ce12d631bf39cd..9ca3395291391b 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -2052,7 +2052,7 @@ def test_max_new_tokens_decoder_only(self): # max_new_tokens and max_length serve the same purpose and should not be used together. with self.assertWarns(UserWarning): - gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20) + gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20) def test_encoder_decoder_generate_with_inputs_embeds(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" @@ -2699,3 +2699,22 @@ def test_constrained_beam_search_mixin_type_checks(self): with self.assertRaises(ValueError): model.generate(input_ids, force_words_ids=[[[-1]]]) + + def test_validate_generation_inputs(self): + tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random") + model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random") + + encoder_input_str = "Hello world" + input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + # typos are quickly detected (the correct argument is `do_sample`) + with self.assertRaises(ValueError): + model.generate( + input_ids, + do_samples=True, + ) + + # arbitrary arguments that will not be used anywhere are also not accepted + with self.assertRaises(ValueError): + fake_model_kwargs = {"foo": "bar"} + model.generate(input_ids, **fake_model_kwargs) From 0a75fe27b5dd0ba5aaa2d8ecf6f323f350b87115 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 22 Jul 2022 15:33:01 +0000 Subject: [PATCH 2/6] better comments --- src/transformers/generation_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 4a88737bbafe5d..a5e7cff9423aaa 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -843,13 +843,13 @@ def compute_transition_beam_scores( def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" - # Excludes arguments that are handled before calling the any model function + # Excludes arguments that are handled before calling any model function if self.config.is_encoder_decoder: for key in ["decoder_input_ids"]: model_kwargs.pop(key, None) # Transfo_XL does not use have "attention_mask" as an argument, and it is harmless (it is being passed in the - # tests, through) + # tests, through, hence this ad hoc exception) if "transfoxl" in str(self).lower(): model_kwargs.pop("attention_mask", None) From 855e1f493e9cef297f92b7a2495a0e52cdb4ccd1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 11 Aug 2022 14:12:10 +0000 Subject: [PATCH 3/6] add PR suggestions --- src/transformers/generation_utils.py | 7 +------ tests/generation/test_generation_utils.py | 16 ++++------------ 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 444619dd21237a..4e08844f0afd75 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -848,14 +848,9 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): for key in ["decoder_input_ids"]: model_kwargs.pop(key, None) - # Transfo_XL does not use have "attention_mask" as an argument, and it is harmless (it is being passed in the - # tests, through, hence this ad hoc exception) - if "transfoxl" in str(self).lower(): - model_kwargs.pop("attention_mask", None) - unused_model_args = [] model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) - # `kwargs`` if often used to handle optional forward pass inputs like `attention_mask`. If + # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;) if "kwargs" in model_args: model_args |= set(inspect.signature(self.forward).parameters) diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index 901a1a28951531..d343546b9dae7f 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -1044,12 +1044,7 @@ def test_generate_without_input_ids(self): model = model_class(config).to(torch_device) model.eval() - output_ids_generate = model.generate( - do_sample=False, - max_length=max_length, - remove_invalid_values=True, - ) - + output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True) self.assertIsNotNone(output_ids_generate) def test_group_beam_search_generate(self): @@ -2708,13 +2703,10 @@ def test_validate_generation_inputs(self): input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids # typos are quickly detected (the correct argument is `do_sample`) - with self.assertRaises(ValueError): - model.generate( - input_ids, - do_samples=True, - ) + with self.assertRaisesRegex(ValueError, "do_samples"): + model.generate(input_ids, do_samples=True) # arbitrary arguments that will not be used anywhere are also not accepted - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, "foo"): fake_model_kwargs = {"foo": "bar"} model.generate(input_ids, **fake_model_kwargs) From 0749966678368ede962185bc169ee0fb5722e6c2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 11 Aug 2022 14:53:24 +0000 Subject: [PATCH 4/6] fix transfoxl generation tests (no attention mask) --- src/transformers/benchmark/benchmark_utils.py | 5 ----- src/transformers/generation_flax_utils.py | 1 - src/transformers/generation_tf_utils.py | 5 ----- src/transformers/generation_utils.py | 5 ----- src/transformers/modelcard.py | 3 +-- src/transformers/models/auto/auto_factory.py | 1 - .../models/flaubert/tokenization_flaubert.py | 1 - .../models/fsmt/tokenization_fsmt.py | 1 - .../models/perceiver/modeling_perceiver.py | 1 - .../models/tapex/tokenization_tapex.py | 1 - .../models/transfo_xl/modeling_transfo_xl.py | 1 - .../models/xlm/tokenization_xlm.py | 1 - src/transformers/testing_utils.py | 1 - src/transformers/trainer_pt_utils.py | 1 - src/transformers/trainer_utils.py | 1 - src/transformers/utils/notebook.py | 2 -- tests/generation/test_generation_utils.py | 18 ++++++++++++------ 17 files changed, 13 insertions(+), 36 deletions(-) diff --git a/src/transformers/benchmark/benchmark_utils.py b/src/transformers/benchmark/benchmark_utils.py index 36fe5eb116cbef..79740805807185 100644 --- a/src/transformers/benchmark/benchmark_utils.py +++ b/src/transformers/benchmark/benchmark_utils.py @@ -79,7 +79,6 @@ def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: b measurements it is important that the function is executed in a separate process Args: - - `func`: (`callable`): function() -> ... generic function which will be executed in its own separate process - `do_multi_processing`: (`bool`) Whether to run function on separate process or not """ @@ -210,7 +209,6 @@ def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5, device_i https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239 Args: - - `function`: (`callable`): function() -> ... function without any arguments to measure for which to measure the peak memory @@ -228,7 +226,6 @@ def get_cpu_memory(process_id: int) -> int: measures current cpu memory usage of a given `process_id` Args: - - `process_id`: (`int`) process_id for which to measure memory Returns @@ -336,7 +333,6 @@ def start_memory_tracing( https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info Args: - - `modules_to_trace`: (None, string, list/tuple of string) if None, all events are recorded if string or list of strings: only events from the listed module/sub-module will be recorded (e.g. 'fairseq' or 'transformers.models.gpt2.modeling_gpt2') @@ -483,7 +479,6 @@ def stop_memory_tracing( Stop memory tracing cleanly and return a summary of the memory trace if a trace is given. Args: - `memory_trace` (optional output of start_memory_tracing, default: None): memory trace to convert in summary `ignore_released_memory` (boolean, default: None): diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index 2f80c7fcf27e96..fd26a605c48bac 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -208,7 +208,6 @@ def generate( post](https://huggingface.co/blog/how-to-generate). Parameters: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. max_length (`int`, *optional*, defaults to `model.config.max_length`): diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index a3d26b789c646e..6c8da54835ac92 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -418,7 +418,6 @@ def generate( post](https://huggingface.co/blog/how-to-generate). Parameters: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length, feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the @@ -1336,7 +1335,6 @@ def _generate( post](https://huggingface.co/blog/how-to-generate). Parameters: - input_ids (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*): The sequence used as a prompt for the generation. If `None` the method initializes it with `bos_token_id` and a batch size of 1. @@ -2070,7 +2068,6 @@ def greedy_search( Generates sequences for models with a language modeling head using greedy decoding. Parameters: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`TFLogitsProcessorList`, *optional*): @@ -2323,7 +2320,6 @@ def sample( Generates sequences for models with a language modeling head using multinomial sampling. Parameters: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`TFLogitsProcessorList`, *optional*): @@ -2600,7 +2596,6 @@ def beam_search( Generates sequences for models with a language modeling head using beam search with multinomial sampling. Parameters: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. max_length (`int`, *optional*, defaults to 20): diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 4e08844f0afd75..9f425f9e2c5c65 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1581,7 +1581,6 @@ def greedy_search( used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`LogitsProcessorList`, *optional*): @@ -1815,7 +1814,6 @@ def sample( can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`LogitsProcessorList`, *optional*): @@ -2072,7 +2070,6 @@ def beam_search( can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): @@ -2381,7 +2378,6 @@ def beam_sample( sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): @@ -2698,7 +2694,6 @@ def group_beam_search( decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index c5d07e11473778..37b9fa48daa969 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -86,8 +86,7 @@ class ModelCard: Note: A model card can be loaded and saved to disk. - Parameters: - """ + Parameters:""" def __init__(self, **kwargs): warnings.warn( diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index b7d8f66c339dd4..0771f9c8a44b20 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -544,7 +544,6 @@ class _LazyAutoMapping(OrderedDict): " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed. Args: - - config_mapping: The map model type to config class - model_mapping: The map model type to model (or tokenizer) class """ diff --git a/src/transformers/models/flaubert/tokenization_flaubert.py b/src/transformers/models/flaubert/tokenization_flaubert.py index 5d5ad2a657d1bc..911ef37dac5046 100644 --- a/src/transformers/models/flaubert/tokenization_flaubert.py +++ b/src/transformers/models/flaubert/tokenization_flaubert.py @@ -130,7 +130,6 @@ def _tokenize(self, text, bypass_tokenizer=False): - Install with `pip install sacremoses` Args: - - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False) (bool). If True, we only apply BPE. diff --git a/src/transformers/models/fsmt/tokenization_fsmt.py b/src/transformers/models/fsmt/tokenization_fsmt.py index 34272e53cf0fcb..66d9819785483c 100644 --- a/src/transformers/models/fsmt/tokenization_fsmt.py +++ b/src/transformers/models/fsmt/tokenization_fsmt.py @@ -354,7 +354,6 @@ def _tokenize(self, text, lang="en", bypass_tokenizer=False): - Install with `pip install sacremoses` Args: - - lang: ISO language code (default = 'en') (string). Languages should belong of the model supported languages. However, we don't enforce it. - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False) diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index b3a0beea3d3ca4..d069182f06c3c7 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -1960,7 +1960,6 @@ def build_position_encoding( Builds the position encoding. Args: - - out_channels: refers to the number of channels of the position encodings. - project_pos_dim: if specified, will project the position encodings to this dimension. diff --git a/src/transformers/models/tapex/tokenization_tapex.py b/src/transformers/models/tapex/tokenization_tapex.py index 7c0725ffe7c108..555bf9fd2c6b9a 100644 --- a/src/transformers/models/tapex/tokenization_tapex.py +++ b/src/transformers/models/tapex/tokenization_tapex.py @@ -1398,7 +1398,6 @@ def truncate_table_rows( ): """ Args: - table_content: {"header": xxx, "rows": xxx, "id" (Optionally): xxx} diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py index 75793466c7a8d1..257c45af03bbc0 100644 --- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py @@ -523,7 +523,6 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, layer: O weights embeddings afterwards if the model class has a *tie_weights()* method. Arguments: - new_num_tokens: (*optional*) int: New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or None: does nothing and diff --git a/src/transformers/models/xlm/tokenization_xlm.py b/src/transformers/models/xlm/tokenization_xlm.py index bd7b58eb053b0e..8bb021c5b96987 100644 --- a/src/transformers/models/xlm/tokenization_xlm.py +++ b/src/transformers/models/xlm/tokenization_xlm.py @@ -791,7 +791,6 @@ def _tokenize(self, text, lang="en", bypass_tokenizer=False): externally, and set `bypass_tokenizer=True` to bypass the tokenizer. Args: - - lang: ISO language code (default = 'en') (string). Languages should belong of the model supported languages. However, we don't enforce it. - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 80f7bf9c863c87..559e8a0d1a014d 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1285,7 +1285,6 @@ def pytest_terminal_summary_main(tr, id): there. Args: - - tr: `terminalreporter` passed from `conftest.py` - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index e1ad471b07a9e0..57103b50d5a039 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -377,7 +377,6 @@ class DistributedTensorGatherer: For some reason, that's not going to roll their boat. This class is there to solve that problem. Args: - world_size (`int`): The number of processes used in the distributed training. num_samples (`int`): diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 579e5d1dc24ce4..a298fc1de5719e 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -337,7 +337,6 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None): should be run immediately after the operation to be measured has completed. Args: - - split: name to prefix metric (like train, eval, test...) - start_time: operation start time - num_samples: number of samples processed diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index 8d81d76c4fd166..636cf785ea94ea 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -120,7 +120,6 @@ def update(self, value: int, force_update: bool = False, comment: str = None): The main method to update the progress bar to `value`. Args: - value (`int`): The value to use. Must be between 0 and `total`. force_update (`bool`, *optional*, defaults to `False`): @@ -204,7 +203,6 @@ class NotebookTrainingTracker(NotebookProgressBar): An object tracking the updates of an ongoing training with progress bars and a nice table reporting metrics. Args: - num_steps (`int`): The number of steps during training. column_names (`List[str]`, *optional*): The list of column names for the metrics table (will be inferred from the first call to [`~utils.notebook.NotebookTrainingTracker.write_line`] if not set). diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index d343546b9dae7f..e87c20d140dab0 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -75,21 +75,25 @@ class GenerationTesterMixin: def _get_input_ids_and_config(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - input_ids = inputs_dict[self.input_name] - attention_mask = torch.ones_like(input_ids, dtype=torch.long) # cut to half length & take max batch_size 3 max_batch_size = 2 sequence_length = input_ids.shape[-1] // 2 input_ids = input_ids[:max_batch_size, :sequence_length] - attention_mask = attention_mask[:max_batch_size, :sequence_length] # generate max 3 tokens max_length = input_ids.shape[-1] + 3 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` config.pad_token_id = config.eos_token_id + + # TransfoXL has no attention mask + if "transfoxl" in config.__class__.__name__.lower(): + attention_mask = None + else: + attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length] + return config, input_ids, attention_mask, max_length @staticmethod @@ -437,9 +441,9 @@ def _beam_sample_generate( return_dict_in_generate=False, ): torch.manual_seed(0) + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, - attention_mask=attention_mask, do_sample=True, max_length=max_length, output_scores=output_scores, @@ -449,6 +453,7 @@ def _beam_sample_generate( remove_invalid_values=True, **beam_kwargs, **logits_warper_kwargs, + **model_kwargs, ) # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences` kwargs = {} @@ -462,7 +467,7 @@ def _beam_sample_generate( output_hidden_states=output_hidden_states, ) kwargs["encoder_outputs"] = encoder_outputs - else: + elif attention_mask is not None: attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0) # prevent flaky generation test failures @@ -471,11 +476,11 @@ def _beam_sample_generate( torch.manual_seed(0) with torch.no_grad(): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_beam_sample = model.beam_sample( input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0), beam_scorer, max_length=max_length, - attention_mask=attention_mask, logits_warper=logits_warper, logits_processor=logits_processor, output_scores=output_scores, @@ -483,6 +488,7 @@ def _beam_sample_generate( output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **kwargs, + **model_kwargs, ) return output_generate, output_beam_sample From 97b6f9dc265cce1cb7fe722ff6ab649f75bd7f76 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 11 Aug 2022 15:31:31 +0000 Subject: [PATCH 5/6] generate tests -- not all models have an attn mask --- tests/generation/test_generation_utils.py | 67 ++++++++++++----------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index e87c20d140dab0..d90825141777dd 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -256,10 +256,9 @@ def _greedy_generate( ) kwargs = {} - + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, - attention_mask=attention_mask, do_sample=False, num_beams=1, max_length=max_length, @@ -269,6 +268,7 @@ def _greedy_generate( return_dict_in_generate=return_dict_in_generate, remove_invalid_values=True, **logits_process_kwargs, + **model_kwargs, ) if model.config.is_encoder_decoder: @@ -282,16 +282,17 @@ def _greedy_generate( kwargs["encoder_outputs"] = encoder_outputs with torch.no_grad(): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_greedy = model.greedy_search( input_ids, max_length=max_length, - attention_mask=attention_mask, logits_processor=logits_processor, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, **kwargs, + **model_kwargs, ) return output_greedy, output_generate @@ -312,13 +313,13 @@ def _sample_generate( return_dict_in_generate=False, ): torch.manual_seed(0) + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, do_sample=True, num_beams=1, max_length=max_length, num_return_sequences=num_return_sequences, - attention_mask=attention_mask, output_scores=output_scores, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -331,7 +332,7 @@ def _sample_generate( torch.manual_seed(0) kwargs = {} if model.config.is_encoder_decoder: - encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( model, input_ids, attention_mask, @@ -340,18 +341,16 @@ def _sample_generate( output_hidden_states=output_hidden_states, ) kwargs["encoder_outputs"] = encoder_outputs - input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0) - else: - attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0) - input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0) + elif attention_mask is not None: + attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0) # prevent flaky generation test failures logits_processor.append(InfNanRemoveLogitsProcessor()) with torch.no_grad(): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_sample = model.sample( - input_ids_clone, - attention_mask=attention_mask_clone, + input_ids.repeat_interleave(num_return_sequences, dim=0), max_length=max_length, logits_processor=logits_processor, logits_warper=logits_warper, @@ -360,6 +359,7 @@ def _sample_generate( output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **kwargs, + **model_kwargs, ) return output_sample, output_generate @@ -378,9 +378,9 @@ def _beam_search_generate( output_hidden_states=False, return_dict_in_generate=False, ): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, - attention_mask=attention_mask, do_sample=False, max_length=max_length, output_scores=output_scores, @@ -390,12 +390,13 @@ def _beam_search_generate( remove_invalid_values=True, **beam_kwargs, **logits_process_kwargs, + **model_kwargs, ) # beam_search does not automatically interleave `batch_size` dim for `num_beams` kwargs = {} if model.config.is_encoder_decoder: - encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( model, input_ids, attention_mask, @@ -404,23 +405,25 @@ def _beam_search_generate( output_hidden_states=output_hidden_states, ) kwargs["encoder_outputs"] = encoder_outputs - input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) + input_ids = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) else: - attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) + if attention_mask is not None: + attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) with torch.no_grad(): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_beam_search = model.beam_search( input_ids_clone, beam_scorer, max_length=max_length, - attention_mask=attention_mask_clone, logits_processor=logits_processor, output_scores=output_scores, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **kwargs, + **model_kwargs, ) return output_generate, output_beam_search @@ -508,9 +511,9 @@ def _group_beam_search_generate( output_hidden_states=False, return_dict_in_generate=False, ): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, - attention_mask=attention_mask, do_sample=False, max_length=max_length, output_scores=output_scores, @@ -520,12 +523,13 @@ def _group_beam_search_generate( remove_invalid_values=True, **beam_kwargs, **logits_process_kwargs, + **model_kwargs, ) # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` kwargs = {} if model.config.is_encoder_decoder: - encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( model, input_ids, attention_mask, @@ -534,23 +538,22 @@ def _group_beam_search_generate( output_hidden_states=output_hidden_states, ) kwargs["encoder_outputs"] = encoder_outputs - input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) - else: - attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + elif attention_mask is not None: + attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) with torch.no_grad(): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_group_beam_search = model.group_beam_search( - input_ids_clone, + input_ids.repeat_interleave(beam_scorer.num_beams, dim=0), beam_scorer, max_length=max_length, - attention_mask=attention_mask_clone, logits_processor=logits_processor, output_scores=output_scores, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **kwargs, + **model_kwargs, ) return output_generate, output_group_beam_search @@ -570,9 +573,9 @@ def _constrained_beam_search_generate( output_hidden_states=False, return_dict_in_generate=False, ): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, - attention_mask=attention_mask, do_sample=False, max_length=max_length, output_scores=output_scores, @@ -583,12 +586,13 @@ def _constrained_beam_search_generate( constraints=constraints, **beam_kwargs, **logits_process_kwargs, + **model_kwargs, ) # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` kwargs = {} if model.config.is_encoder_decoder: - encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( model, input_ids, attention_mask, @@ -597,23 +601,22 @@ def _constrained_beam_search_generate( output_hidden_states=output_hidden_states, ) kwargs["encoder_outputs"] = encoder_outputs - input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0) - else: - attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0) - input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0) + elif attention_mask is not None: + attention_mask = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0) with torch.no_grad(): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_group_beam_search = model.constrained_beam_search( - input_ids_clone, + input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0), constrained_beam_scorer, max_length=max_length, - attention_mask=attention_mask_clone, logits_processor=logits_processor, output_scores=output_scores, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **kwargs, + **model_kwargs, ) return output_generate, output_group_beam_search From 92f7072c74922f6525d922090d41c7d00c19f071 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 11 Aug 2022 15:46:55 +0000 Subject: [PATCH 6/6] fix a few more tests --- tests/generation/test_generation_utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index d90825141777dd..ba13669368d228 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -405,16 +405,13 @@ def _beam_search_generate( output_hidden_states=output_hidden_states, ) kwargs["encoder_outputs"] = encoder_outputs - input_ids = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) - else: - if attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + elif attention_mask is not None: + attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) with torch.no_grad(): model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_beam_search = model.beam_search( - input_ids_clone, + input_ids.repeat_interleave(beam_scorer.num_beams, dim=0), beam_scorer, max_length=max_length, logits_processor=logits_processor,