Generate: validate model_kwargs (and catch typos in generate arguments) #18261

Merged · 8 commits · Aug 12, 2022
Changes from 2 commits
31 changes: 31 additions & 0 deletions src/transformers/generation_utils.py
@@ -841,6 +841,34 @@ def compute_transition_beam_scores(

return transition_scores

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
if self.config.is_encoder_decoder:
for key in ["decoder_input_ids"]:
model_kwargs.pop(key, None)

# Transfo_XL does not have "attention_mask" as an argument, and it is harmless (it is being passed in the
# tests, though, hence this ad hoc exception)
if "transfoxl" in str(self).lower():
model_kwargs.pop("attention_mask", None)

unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
model_args |= set(inspect.signature(self.forward).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)

if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)

@torch.no_grad()
def generate(
self,
@@ -1119,6 +1147,9 @@ def generate(
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# 0. Validate model kwargs
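# (a copy is passed because `_validate_model_kwargs` pops the keys it handles itself,
# e.g. `decoder_input_ids`, so the caller's `model_kwargs` stays intact)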
self._validate_model_kwargs(model_kwargs.copy())

# 1. Set generation parameters if not already defined
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
num_beams = num_beams if num_beams is not None else self.config.num_beams
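For intuition, here is a minimal standalone sketch (not part of the PR) of the validation pattern used above: collect the parameter names a callable accepts via `inspect.signature`, then flag any non-`None` kwargs that match none of them. `validate_kwargs` and `fake_generate` are hypothetical names used only for illustration.

```python
import inspect
from typing import Any, Dict


def validate_kwargs(func, kwargs: Dict[str, Any]):
    # Parameter names the callable accepts, per its signature
    accepted = set(inspect.signature(func).parameters)
    unused = [k for k, v in kwargs.items() if v is not None and k not in accepted]
    if unused:
        raise ValueError(f"The following kwargs are not used: {unused}")


def fake_generate(input_ids=None, do_sample=False):  # stand-in for a model method
    pass


validate_kwargs(fake_generate, {"do_sample": True})   # passes silently
validate_kwargs(fake_generate, {"do_samples": True})  # raises ValueError (typo caught)
```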
21 changes: 20 additions & 1 deletion tests/generation/test_generation_utils.py
@@ -2052,7 +2052,7 @@ def test_max_new_tokens_decoder_only(self):

# max_new_tokens and max_length serve the same purpose and should not be used together.
with self.assertWarns(UserWarning):
-            gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
+            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
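# (this change follows from the new validation: gpt2 is decoder-only, so
# `decoder_input_ids` is not an accepted model kwarg and would now raise a ValueError)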

def test_encoder_decoder_generate_with_inputs_embeds(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -2699,3 +2699,22 @@ def test_constrained_beam_search_mixin_type_checks(self):

with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[[-1]]])

def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")

encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

# typos are quickly detected (the correct argument is `do_sample`)
with self.assertRaises(ValueError):
model.generate(
input_ids,
do_samples=True,
)

# arbitrary arguments that will not be used anywhere are also not accepted
with self.assertRaises(ValueError):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)
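As a usage note, here is a sketch of what a caller sees after this change, reusing the tiny checkpoint from the test above (the exact error text follows the f-string in `_validate_model_kwargs`):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

try:
    model.generate(input_ids, do_samples=True)  # typo: the correct argument is `do_sample`
except ValueError as err:
    print(err)  # lists ['do_samples'] as a `model_kwargs` entry not used by the model
```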