
Generate: validate arguments #18218

Closed
wants to merge 6 commits

Conversation

gante
Member

@gante gante commented Jul 20, 2022

What does this PR do?

NOTE: this PR is very experimental, feel free to trash it in the review process :)

A common cause of issues with generate is that it doesn't behave as expected: arguments can be silently ignored by the selected generation submethod (greedy_search, sample, ...). Typos also often fly under the radar, as the method accepts **model_kwargs, which are in turn passed to models that also accept **kwargs.

This PR adds argument validation to generate in two separate steps:

  1. model_kwargs are verified as soon as the method is called. Only arguments that the model actually uses in prepare_inputs_for_generation or in its forward pass are accepted, which means typos are caught immediately. The exception enumerates all arguments that failed this check, so the user can correct them (a sketch of this check follows the list).
  2. Before calling the appropriate generate submethod (picked from the arguments), generate checks that all passed arguments will actually be used. If the user passes an argument that is not used by that particular submethod, an exception is thrown indicating the triggered submethod and the unaccepted arguments, so the user can fix either side (switch submethods or correct the arguments).
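
For reference, a minimal sketch of what the first check boils down to (validate_model_kwargs is a made-up standalone name for illustration; the actual implementation lives inside generate and handles extra edge cases, such as models whose forward accepts **kwargs):

import inspect

def validate_model_kwargs(model, model_kwargs):
    # Gather every argument name the model can actually consume, either in
    # `prepare_inputs_for_generation` or in its forward pass
    accepted_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)
    accepted_args |= set(inspect.signature(model.forward).parameters)

    # Anything left over is either a typo or an argument the model ignores
    unused_model_args = [key for key in model_kwargs if key not in accepted_args]
    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )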

Although I think the checks are super useful, the code around them is not the prettiest. The first check has some logic for edge cases, and the second check requires passing the list of methods that will be called before the submethod in question. The PR is heavily commented in GH, feel free to cast your judgment!

P.S.: (seemingly) unrelated accelerate tests are failing in run_examples_torch

Related issues

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jul 20, 2022

The documentation is not available anymore as the PR was closed or merged.

@@ -572,14 +572,14 @@ def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_to
     @staticmethod
     def _expand_inputs_for_generation(
         input_ids: torch.LongTensor,
-        num_return_sequences: int = 1,
+        expand_size: int = 1,
Member Author

Renamed this arg: it is a private method, the new name is more readable, and it is useful for name matching (as you'll see below)

Comment on lines +846 to +849
# Excludes arguments that are handled before calling any model function
if self.config.is_encoder_decoder:
    for key in ["decoder_input_ids"]:
        model_kwargs.pop(key, None)
Member Author

If decoder_input_ids is present, it will be converted to input_ids

Comment on lines +851 to +854
# Transfo_XL does not have "attention_mask" as an argument, and it is harmless (it is being passed in the
# tests, though)
if "transfoxl" in str(self).lower():
    model_kwargs.pop("attention_mask", None)
Member Author

Ugly ad hoc exception, but the alternative would be to rewrite most GenerationMixin tests for this particular model (most pass attention_mask=attention_mask when calling generate)

Comment on lines +866 to +870
if unused_model_args:
    raise ValueError(
        f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
        " generate arguments will also show up in this list)"
    )
Member Author

Here is an example of the output for

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

prompt = tokenizer(["hello world"], return_tensors="pt")
model.generate(**prompt, do_samples=True, foo="bar")  # "do_samples" is a typo and "foo" is not a model argument

(Screenshot: the resulting ValueError, listing do_samples and foo as unused model_kwargs.)

Comment on lines +875 to +876
generation_method_name: str,
supporting_objects: List[Callable],
Member Author

Each generation submethod has a set of preparation functions/classes (hence objects), and some input arguments are consumed there -- we will need their signatures to do the correct detection.
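
Roughly, the detection gathers argument names along these lines (a simplified sketch; per the signature above, the actual helper takes the method name as a string and resolves it internally):

import inspect
from typing import Callable, List

def get_generation_method_args(generation_method: Callable, supporting_objects: List[Callable]) -> set:
    # Arguments accepted by the submethod itself (e.g. `sample`)...
    generation_method_args = set(inspect.signature(generation_method).parameters)
    # ...plus arguments consumed by the preparation functions/classes that run
    # before the submethod is called (`inspect.signature` on a class returns
    # its `__init__` signature, so classes work here too)
    for supporting_object in supporting_objects:
        generation_method_args |= set(inspect.signature(supporting_object).parameters)
    return generation_method_args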

Comment on lines +904 to +906
if "do_early_stopping" in generation_method_args:
generation_method_args.remove("do_early_stopping")
generation_method_args.add("early_stopping")
Member Author

early_stopping is consumed by a class that has do_early_stopping as an argument. Since the class is public, I can't rename it :( At most I could add a new argument doing the same thing, but it's probably not worth it.

Comment on lines +912 to +916
raise ValueError(
    f"From the generation arguments, `{generation_method_name}` was triggered. The following arguments are"
    f" not used by `{generation_method_name}`: {unused_args}. Please remove them from the generation"
    " arguments or check the documentation for more information."
)
Member Author

Here is an example of the output for

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

prompt = tokenizer(["hello world"], return_tensors="pt")
model.generate(**prompt, num_return_sequences=2, temperature=2.0)  # triggers greedy_search, which uses neither argument

(Screenshot: the resulting ValueError, naming greedy_search and the unused arguments.)

Comment on lines -1247 to -1250
if is_group_beam_gen_mode and do_sample is True:
    raise ValueError(
        "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
    )
Member Author

moved to the if is_group_beam_gen_mode: block

Comment on lines -1282 to -1285
if num_return_sequences > 1:
    raise ValueError(
        f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
    )
Member Author

redundant with the new checks

@@ -2052,7 +2052,7 @@ def test_max_new_tokens_decoder_only(self):

        # max_new_tokens and max_length serve the same purpose and should not be used together.
        with self.assertWarns(UserWarning):
-            gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
+            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
Member Author

gpt2 is decoder-only

@@ -360,7 +360,7 @@ def test_generate_fp16(self):
        if torch_device == "cuda":
            model.half()
        model.generate(**input_dict)
-        model.generate(**input_dict, do_sample=True, early_stopping=False, num_return_sequences=3)
+        model.generate(**input_dict, do_sample=True, num_return_sequences=3)
Member Author

greedy_search doesn't accept early_stopping

Comment on lines -1464 to -1466
if num_beam_groups is not None and num_beam_groups > 1:
    raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

Member Author

redundant check

@gante gante marked this pull request as ready for review July 21, 2022 13:55
Collaborator

@sgugger sgugger left a comment

First, I think it's a tremendous idea to have better argument validation in generate. This PR tackles that problem in two ways:

  1. validating model kwargs (and thus catching typos)
  2. validating the arguments passed are actually used by the generation algorithm picked.

1 is easier and I have no problem with that part of the PR; maybe it should be done in its own PR first, before diving deeper into 2 :-)

For 2, the way you chose feels very very magical, with lots of ad-hoc code that is going to be hard to maintain. I wonder if it wouldn't be better to just centralize all inputs passed, like you did in generation_inputs, then have each of the private methods of generate return its regular result as well as the unused inputs; at the end of the function you can then inspect the unused inputs to check they are empty.
I also think this case warrants a warning more than an error, by the way.
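
A rough sketch of that contract (all names here are hypothetical, not the actual transformers API):

import warnings

def _prepare_sampling_inputs(generation_inputs: dict):
    # Consume the arguments this preparation step understands...
    temperature = generation_inputs.pop("temperature", 1.0)
    # ...and return the regular result alongside the still-unused inputs
    return {"temperature": temperature}, generation_inputs

def generate(**generation_inputs):
    sampling_config, generation_inputs = _prepare_sampling_inputs(generation_inputs)
    # ... every other preparation step and private submethod follows the same contract ...
    if generation_inputs:
        # Anything left at the end was never consumed by any step
        warnings.warn(f"Unused generation arguments: {list(generation_inputs)}")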

@gante
Member Author

gante commented Jul 22, 2022

For 2, the way you chose feels very very magical with lots of ad-hoc code that is going to be hard to maintain.

Yeah, I agree, that was the number 1 reason why I left so many comments and caveats. It works but would be annoying to maintain.

(@sgugger) If I got it right, the suggestion was to pop used arguments from generation_inputs as we call functions, correct? Something like consume_arguments(generation_inputs, <function that was just called>) after most calls, with a small validation function at the end of generate?

Meanwhile, I'm going to do as suggested, and move the model kwargs validation to its own PR :)

@sgugger
Collaborator

sgugger commented Jul 22, 2022

Something like consume_arguments(generation_inputs, <function that was just called>) after most calls, with a small validation function at the end of generate

No, something more like result, generation_inputs = <function to call>(generation_inputs)

@gante
Member Author

gante commented Jul 22, 2022

Closing in favor of two PRs:
