Skip to content

Commit

Permalink
Generate: validate model_kwargs (and catch typos in generate arguments) (huggingface#18261)
Browse files Browse the repository at this point in the history

* validate generate model_kwargs

* generate tests -- not all models have an attn mask
  • Loading branch information
gante authored and oneraghavan committed Sep 26, 2022
1 parent 4c5b6d2 commit 87fd26f
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 48 deletions.
26 changes: 26 additions & 0 deletions src/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,29 @@ def compute_transition_beam_scores(

return transition_scores

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
if self.config.is_encoder_decoder:
for key in ["decoder_input_ids"]:
model_kwargs.pop(key, None)

unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
model_args |= set(inspect.signature(self.forward).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)

if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)

@torch.no_grad()
def generate(
self,
Expand Down Expand Up @@ -1120,6 +1143,9 @@ def generate(
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# 0. Validate model kwargs
self._validate_model_kwargs(model_kwargs.copy())

# 1. Set generation parameters if not already defined
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
num_beams = num_beams if num_beams is not None else self.config.num_beams
Expand Down
Loading

0 comments on commit 87fd26f

Please sign in to comment.